template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          // ...
#if CK_USE_LAUNCH_BOUNDS
// ...
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
    {
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_sorted_token_ids,
            karg.p_sorted_expert_ids,
            // ...
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            // ...
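// Illustrative sketch (not part of the original source): the shape of the
// offsets that SplitKBatchOffset feeds into the call above. Each z-block
// (k_id = blockIdx.z) works on a karg.KRead-wide slice of K, so the grid
// pointers advance by whole K slices; the exact stride factor depends on the
// RowMajor/ColumnMajor branches that appear later in this listing.
struct SplitKBatchOffsetSketch
{
    long a_k_split_offset; // added to karg.p_a_grid
    long b_k_split_offset; // added to karg.p_b_grid

    __host__ __device__ SplitKBatchOffsetSketch(long KRead, int k_id)
        : a_k_split_offset{k_id * KRead}, // row-major A: K is the contiguous dim
          b_k_split_offset{k_id * KRead}  // preshuffled B: K is the contiguous dim
    {
    }
};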
 
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          // ...
#if CK_USE_LAUNCH_BOUNDS
// ...
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
    {
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
        __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_sorted_token_ids,
            karg.p_sorted_expert_ids,
            // ...
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            // ...
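// Sketch (not part of the original source): the _2Lds variant allocates two LDS
// tiles so the pipeline can ping-pong: while the MFMA stage reads one buffer,
// the global->LDS copy of the next K slice fills the other. Minimal selection
// logic, assuming an even/odd schedule over k-loop iterations:
__device__ inline char* lds_write_buffer_sketch(char* p_shared, char* p_shared1, int k_iter)
{
    return (k_iter % 2 == 0) ? p_shared : p_shared1; // the read buffer is the other one
}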
 
template <typename ALayout,
          // ...
          typename AccDataType,
          typename CShuffleDataType,
          // ...
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          // ...
          index_t ActivationOperation                 = 0,
          bool NSwizzle                               = false,
          bool IsInputGemm                            = true,
          bool MulRoutedWeight                        = true,
          // ...
          typename ComputeTypeA                       = CDataType,
          typename ComputeTypeB                       = ComputeTypeA,
          typename LDSTypeA                           = ADataType,
          typename LDSTypeB                           = BDataType>
 
    static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
        CDEShuffleBlockTransferScalarPerVectors{}[I0];
    // ...
                return static_cast<const DDataType*>(nullptr);
 
        const index_t gridx = NSwizzle ? nblock * mblock : nblock;
        const index_t gridy = NSwizzle ? 1 : mblock;
        // ...
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
        // ...
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
        // ...
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;
        // ...
        auto K_t = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
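// Worked example (illustrative, not in the original source): with K = 4096,
// K_Batch = 2, KPerBlock = 256 and AK1Value = 8, each helper rounds K up to a
// multiple of K_Batch * KPerBlock = 512 before splitting:
//   AK0     = ceil(4096 / 512) * (256 / 8) = 8 * 32 = 256 per split-K batch
//   KPadded = ceil(4096 / 512) * 256       = 2048 per split-K batch
static_assert((4096 + 512 - 1) / 512 * (256 / 8) == 256, "AK0 example");
static_assert((4096 + 512 - 1) / 512 * 256 == 2048, "KPadded example");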
 
    template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
    // ...
        IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
 
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            // ...
        }();

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            const auto a_grid_desc_m_k =
            // ...
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // ...
                a_grid_desc_mraw_kraw,
            // ...
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // ...
                a_grid_desc_mraw_kraw,
            // ...
            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // ...
                a_grid_desc_mraw_kraw,
            // ...
            return a_grid_desc_ak0_m_ak1;
        }
 
        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);
        // ...
            make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
 
        const auto b_grid_desc_nraw_kraw = [&]() {
        // ...
        }();

        static_assert(!(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> &&
                        GemmSpec != GemmSpecialization::Default),
                      "pk_i4_t does not support padding");

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            const auto b_grid_desc_n_k =
            // ...
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // ...
                b_grid_desc_nraw_kraw,
            // ...
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // ...
                b_grid_desc_nraw_kraw,
            // ...
            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // ...
                b_grid_desc_nraw_kraw,
            // ...
            return b_grid_desc_bk0_n_bk1;
        }
 
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    // ...
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    // ...
        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
    }
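    // Worked example (illustrative, not in the original source): with
    // MPerBlock = 128, MXdlPerWave = 2 and MPerXdl = 32, the block carries
    // MWaves = 128 / (2 * 32) = 2 waves along M; NWave follows from the same
    // arithmetic on the N-side parameters.
    static_assert(128 / (2 * 32) == 2, "MWaves example");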
 
    template <typename ELayout>
    // ...
        IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
        // ...

    template <typename DLayout>
    __host__ __device__ static auto
    // ...
        const auto c_grid_desc_mraw_nraw = [&]() {
        // ...
                return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
        // ...

    template <typename DsGridDesc>
    // ...
        const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        // ...
                    ds_grid_desc_m_n[i], MBlock, NBlock);
        // ...
                                    std::array<index_t, NumDTensor> StrideDs_,
 
            std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
                      << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
                      // ...
                      << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
                      << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
                      << "NBlock: " << NBlock << "}" << std::endl;
 
                          const index_t* p_sorted_expert_ids_,
                          const index_t* p_max_token_id_,
                          const ADataType* p_a_grid_,
                          const BDataType* p_b_grid_,
                          std::array<const void*, NumDTensor> p_ds_grid_,
                          CDataType* p_c_grid_,
                          // ...
                          std::array<index_t, NumDTensor> StrideDs_,
                          // ...
                          AElementwiseOperation a_element_op_,
                          BElementwiseOperation b_element_op_,
                          CElementwiseOperation c_element_op_)
        // ...
                p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
 
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            // ...
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            // ...
            if(k_id < karg.KBatch - 1)
 
        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);

        if constexpr(ABlockLdsExtraM)
        {
            // ...
            constexpr auto a_lds_block_desc =
            // ...
            return a_lds_block_desc_permuted;
        }
        else
        {
            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = WaveSize / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
            // ...
                           Number<kfold * M0 / mpair>{},
            // ...
                a_lds_block_desc_permuted,
            // ...
                a_lds_block_desc_unmerged,
            // ...
                                          Number<KThreadWrite / kfold / KThreadReadPerm>{},
            // ...
            return a_lds_block_desc_ak0_m_ak1;
        }
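        // Worked example (illustrative, not in the original source): kfold packs
        // multiple K0 slices into one 128-byte LDS span to reduce bank conflicts.
        // Assuming AK1Number = 8, M0 = 4 and sizeof(LDSTypeA) = 2:
        //   AK1 * M0 * sizeof(LDSTypeA) = 64 <= 128, so kfold = 128 / 64 = 2.
        static_assert(128 / (8 * 4 * 2) == 2, "kfold example");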
 
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        // ...
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        // ...
        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
        // ...
                                ABlockTransferSrcScalarPerVector,
                                BBlockTransferSrcScalarPerVector,
        // ...
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        // ...
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
        // ...
                         c_block_size * sizeof(CShuffleDataType));
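    // Sketch (not part of the original source): the LDS budget is the larger of
    // the A staging tile and the C-shuffle tile, since the epilogue reuses the
    // same p_shared allocation (see the CShuffleDataType cast in the epilogue
    // below).
    constexpr long shared_mem_bytes_sketch(long a_block_bytes, long c_shuffle_bytes)
    {
        return a_block_bytes > c_shuffle_bytes ? a_block_bytes : c_shuffle_bytes;
    }
    static_assert(shared_mem_bytes_sketch(32768, 16384) == 32768, "LDS sized by the larger tile");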
 
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");
 
            if(!(karg.M % MPerBlock == 0))
            {
                std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                // ...
            }
            // ...
            if(!(karg.N % NPerBlock == 0))
            {
                std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                // ...
            }
            // ...
            auto K_t = karg.KBatch * KPerBlock;
            if(!(karg.K % K_t == 0))
            {
                std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                          << karg.K << " " << __FILE__ << ":" << __LINE__
                          << ", in function: " << __func__ << std::endl;
            // ...
            auto K_t = karg.KBatch * KReadVec;
            // ...
            if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
            // ...
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
            {
                std::cout << "Arg K (" << karg.K
                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
            // ...
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
            {
                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
            // ...
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
            {
                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
            // ...
            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
            {
                std::cout << "Arg K (" << karg.K
                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
            // ...
                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          // ...
                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
            // ...
                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          // ...
                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
 
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
        // ...
        if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
        // ...
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
    }
    // ...
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
    }

    template <typename CGridDesc>
    // ...
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        // ...
        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }
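    // Sketch (not part of the original source): how the two loop helpers above
    // are typically consumed when picking template instantiations; the loop
    // count is K / KPerBlock and the pipeline needs more iterations than its
    // prefetch depth to form a "hot" main loop. Stand-in arithmetic:
    constexpr bool has_main_k_block_loop_sketch(int K, int k_per_block, int prefetch_stages)
    {
        return (K / k_per_block) > prefetch_stages;
    }
    static_assert(has_main_k_block_loop_sketch(4096, 256, 2), "16 k-loops > 2 prefetch stages");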
 
    template <bool HasMainKBlockLoop,
              // ...
    __device__ static void Run(const index_t* p_sorted_token_ids,
                               const index_t* p_sorted_expert_ids,
                               const index_t* p_max_token_id,
                               const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               // ...
                               CDataType* p_c_grid,
                               // ...
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CElementwiseOperation c_element_op)
 
        const auto b_grid_desc_bpreshuffled =
        // ...
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
        // ...
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        // ...
        const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        // ...
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
            return;
        const index_t expert_id =
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
            {
                const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
                // ...
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                // ...
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
                // ...
            }
            else
            {
                return {blockIdx.x, blockIdx.y};
            }
        }();
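        // Worked example (illustrative, not in the original source): the swizzle
        // maps a flat block id to (nid, mid) so consecutive blocks walk 8-wide N
        // strips inside one expert's range of M blocks. Assuming
        // expert_swizzle = 3 and ecnt_prefix = 0, bid_new = 10 gives:
        //   nid = 10 % 8 + 10 / (8 * 3) * 8 = 2
        //   mid = 0 + 10 / 8 % 3            = 1
        static_assert(10 % 8 + 10 / (8 * 3) * 8 == 2, "nid example");
        static_assert(10 / 8 % 3 == 1, "mid example");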
 
        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
        const index_t token0 =
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);

        // ...
        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;

        if(token_pos >= max_token_id || token0 >= problem.NumTokens)
            return;
        // ...
            const index_t fused_token = p_sorted_token_ids[token_pos + m0];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
        // ...
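        // Sketch (not part of the original source): as the masks above imply, a
        // sorted token id packs the token index in the low 24 bits and the top-k
        // slot in the high 8 bits. Minimal decode:
        //   constexpr unsigned fused_token_index_sketch(unsigned fused) { return fused & 0xffffff; }
        //   constexpr unsigned fused_topk_slot_sketch(unsigned fused)   { return fused >> 24; }
        // For token 5 routed through its top-k slot 2:
        static_assert((((2u << 24) | 5u) & 0xffffff) == 5, "token index example");
        static_assert((((2u << 24) | 5u) >> 24) == 2, "top-k slot example");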
 
        const index_t expert_stride =
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
            // ...
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            // ...
            b_grid_desc_bpreshuffled.GetElementSpaceSize());

        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + expert_id * expert_scale_stride,
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
        // ...
            AElementwiseOperation,
            // ...
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            // ...
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            // ...
            ABlockTransferSrcVectorDim,
            // ...
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            // ...
            AThreadTransferSrcResetCoordinateAfterRun,
            // ...
            BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                                // ...
                                                a_block_desc_ak0_m_ak1,
                                                // ...

        auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        // ...
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            // ...
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(b_grid_desc_bpreshuffled,
                  // ...

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        // ...
        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            // ...

        constexpr index_t ScaleSliceSizeM = MXdlPerWave;
        // ...
        constexpr index_t MWaves   = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWaves   = NPerBlock / (NXdlPerWave * NPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
        // ...
        const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;

        if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
            return;
        // ...
            const index_t fused_token =
                p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
            index_t token_offset = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            scale_gather_offsets(m0) =
            // ...

        auto a_scale_thread_copy =
            // ...
                                                    decltype(a_scale_grid_desc_am_ak),
                                                    decltype(a_scale_thread_desc),
            // ...
        auto b_scale_thread_copy =
            // ...
                                             decltype(b_scale_grid_desc_bn_ak),
                                             decltype(b_scale_thread_desc),
            // ...
                b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));

        constexpr auto a_scale_thread_slice_copy_step =
            // ...
        constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
 
        if constexpr(IsInputGemm)
        {
            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
                // ...
                b_grid_desc_bpreshuffled.GetElementSpaceSize());
            // ...
                decltype(b_grid_desc_bpreshuffled),
                decltype(b_block_desc_bk0_n_bk1),
                // ...
                BBlockTransferSrcScalarPerVector,
                BThreadTransferSrcResetCoordinateAfterRun,
                true>(b_grid_desc_bpreshuffled,
                      // ...
            // ...
                p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride,
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());
            auto b_scale_thread_copy_up =
                // ...
                                                 decltype(b_scale_grid_desc_bn_ak),
                                                 decltype(b_scale_thread_desc),
                // ...
                    b_scale_grid_desc_bn_ak,
                    // ...

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                // ...
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                // ...
                b_blockwise_copy_up,
                // ...
                b_block_slice_copy_step,
                // ...
                c_scale_thread_desc,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_desc,
                a_scale_thread_copy,
                // ...
                a_scale_thread_slice_copy_step,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_desc,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                // ...
                b_scale_grid_buf_up,
                b_scale_thread_slice_copy_step,
                // ...
                num_k_block_main_loop);
        }
        else
        {
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                // ...
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                // ...
                b_block_slice_copy_step,
                // ...
                c_scale_thread_desc,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_desc,
                a_scale_thread_copy,
                // ...
                a_scale_thread_slice_copy_step,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_desc,
                b_scale_thread_copy,
                // ...
                b_scale_thread_slice_copy_step,
                // ...
                num_k_block_main_loop);
        }
 
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          // ...
            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            // ...
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
            // ...
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
            constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
            constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);

            static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
            static_assert(M0 * M1 * M2 == MPerBlock);
            static_assert(N4 == 4 || N4 == 8);
            // ...
                    if constexpr(MulRoutedWeight)
                    {
                        const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
                        topk_weight         = p_ds_grid[I0][m_pos];
                    }
                    // ...
                                blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                                // ...
                            if constexpr(IsInputGemm)
                            {
                                // ...
                                    float gate = c_thread_buf[cidx];
                                    float up   = c_thread_buf_up[cidx];
                                    if constexpr(MulRoutedWeight)
                                    {
                                        gate = gate * topk_weight;
                                        up   = up * topk_weight;
                                    }
                                    // ...
                                    c_thread_buf(cidx) = gate * up;
                                // ...
                                    float gate = c_thread_buf[cidx];
                                    float up   = c_thread_buf_up[cidx];
                                    if constexpr(MulRoutedWeight)
                                    {
                                        gate = gate * topk_weight;
                                        up   = up * topk_weight;
                                    }
                                    // ...
                                    c_thread_buf(cidx) = gate * up;
                                // ...
                            }
                            else
                            {
                                if constexpr(MulRoutedWeight)
                                {
                                    c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
                                }
                            }
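            // Sketch (not part of the original source): for the gate+up ("input")
            // GEMM the fused weight produces a gate and an up value per output
            // element; the epilogue activates the gate (the elided region selects
            // the function via the ActivationOperation parameter; SiLU is assumed
            // here) and multiplies by up, optionally pre-scaled by the routed
            // top-k weight:
            //   __device__ float gate_up_epilogue_sketch(float gate, float up, float w)
            //   {
            //       gate = gate * w;                               // MulRoutedWeight path
            //       up   = up * w;
            //       float act = gate / (1.0f + __expf(-gate));     // assumed SiLU
            //       return act * up;
            //   }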
 
            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                // ...
            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
            // ...
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                // ...
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
                // ...
            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
                    // ...
            const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
                // ...
            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
                    // ...
            auto c_thread_copy_vgpr_to_lds =
                // ...
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
                                                   // ...
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            // ...
                    c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                    // ...
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     n_thread_data_on_block_idx[I2],
                                     n_thread_data_on_block_idx[I3],
                                     n_thread_data_on_block_idx[I4]),
                    // ...
            using EDataType = CDataType;
            // ...
            const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
                // ...
                    const DDataType* ptr_ = p_ds_grid[i];
                    // ...
                    return make_dynamic_buffer<AddressSpaceEnum::Global>(
                        ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
                // ...
                tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                // ...
                             { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                // ...
                tie(c_shuffle_block_buf),
                // ...
                             { return ds_grid_buf[i]; },
                // ...
            const auto idx_c_ds_block_begin =
                // ...
            const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
                c_grid_desc_mblock_mperblock_nblock_nperblock;

            using CDEBlockTransferCluster =
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
            const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
            constexpr index_t scatter_weight_idx  = IsInputGemm ? 1 : 1;
            // ...
                   decltype(c_ds_desc_refs),
                   decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
                   CElementwiseOperation,
                   // ...
                            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                            // ...
                            CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
                   CDEBlockTransferCluster,
                   // ...
                   CDEShuffleBlockTransferScalarPerVectors,
                   // ...
                     idx_c_ds_block_begin,
                     tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                     // ...
            auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto sfc_c_vgpr =
                // ...
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           // ...
            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
            // ...
            constexpr auto sfc_cde_block =
                // ...
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           // ...
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
            constexpr auto EMThreads =
                CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
            constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
            constexpr auto ENThreads =
                CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
            // ...
                auto dstidx = sfc_cde_block.GetIndex(access_id);
                // ...
                    block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
                // ...
                    const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                    index_t token_offset      = fused_token & 0xffffff;
                    if constexpr(IsInputGemm)
                    {
                        token_offset = token_offset * problem.TopK + (fused_token >> 24);
                    }
                    scatter_offsets(m0) = token_offset * problem.N;
                // ...
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              // ...
                                              c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                              c_shuffle_block_buf);
                // ...
                cde_block_copy_lds_and_global.Run(
                    // ...
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    // ...
                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto cde_lds_and_global_step =
                        sfc_cde_block.GetForwardStep(access_id);
                    // ...
                        cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                            c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                    // ...
                    cde_block_copy_lds_and_global.MoveDstSliceWindow(
                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                        // ...
                        cde_lds_and_global_step);
                }
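            // Sketch (not part of the original source): the epilogue scatters each
            // C row back to its token slot via scatter_offsets = token_offset * N.
            // For the down ("output") GEMM several top-k experts can hit the same
            // token, so CGlobalMemoryDataOperation is expected to accumulate.
            // Scalar model:
            //   __device__ void scatter_store_sketch(float* c_grid, long row_offset,
            //                                        int n, float v, bool accumulate)
            //   {
            //       if(accumulate)
            //           atomicAdd(&c_grid[row_offset + n], v); // multiple experts per token
            //       else
            //           c_grid[row_offset + n] = v;
            //   }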
 
    template <bool HasMainKBlockLoop,
              // ...
    __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
                                    const index_t* p_sorted_expert_ids,
                                    const index_t* p_max_token_id,
                                    const ADataType* p_a_grid,
                                    const BDataType* p_b_grid,
                                    // ...
                                    CDataType* p_c_grid,
                                    // ...
                                    AElementwiseOperation a_element_op,
                                    BElementwiseOperation b_element_op,
                                    CElementwiseOperation c_element_op)
    {
        // ...
        const auto b_grid_desc_bpreshuffled =
        // ...
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
        // ...
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        // ...
        const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
            return;
        const index_t expert_id =
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
            {
                const index_t ecnt_prefix    = p_max_token_id[1 + expert_id];
                // ...
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                // ...
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
                // ...
            }
            else
            {
                return {blockIdx.x, blockIdx.y};
            }
        }();
        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
        const index_t token0 =
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);

        // ...
        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;

        if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
           // ...
            const index_t fused_token = p_sorted_token_ids[token_pos + m0];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
        // ...
        const index_t expert_stride =
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
 
        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
            // ...
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            // ...
            b_grid_desc_bpreshuffled.GetElementSpaceSize());

        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + expert_id * expert_scale_stride,
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
        // ...
            AElementwiseOperation,
            // ...
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            // ...
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            // ...
            ABlockTransferSrcVectorDim,
            // ...
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            // ...
            AThreadTransferSrcResetCoordinateAfterRun,
            // ...
            BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                                // ...
                                                a_block_desc_ak0_m_ak1,
                                                // ...

        auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
        // ...
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            // ...
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(b_grid_desc_bpreshuffled,
                  // ...

        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
 
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        // ...
        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            // ...

        constexpr index_t ScaleSliceSizeM = MXdlPerWave;
        // ...
        constexpr index_t MWaves   = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWaves   = NPerBlock / (NXdlPerWave * NPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
        // ...
        const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;

        if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
            return;
        // ...
            const index_t fused_token =
                p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
            index_t token_offset = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
                // ...

        auto a_scale_thread_copy =
            // ...
                                                    decltype(a_scale_grid_desc_am_ak),
                                                    decltype(a_scale_thread_desc),
            // ...
        auto b_scale_thread_copy =
            // ...
                                             decltype(b_scale_grid_desc_bn_ak),
                                             decltype(b_scale_thread_desc),
            // ...
                b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));

        constexpr auto a_scale_thread_slice_copy_step =
            // ...
        constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);

        if constexpr(IsInputGemm)
        {
            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
                // ...
                b_grid_desc_bpreshuffled.GetElementSpaceSize());
            // ...
                decltype(b_grid_desc_bpreshuffled),
                decltype(b_block_desc_bk0_n_bk1),
                // ...
                BBlockTransferSrcScalarPerVector,
                BThreadTransferSrcResetCoordinateAfterRun,
                true>(b_grid_desc_bpreshuffled,
                      // ...
            // ...
                p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());
            auto b_scale_thread_copy_up =
                // ...
                                                 decltype(b_scale_grid_desc_bn_ak),
                                                 decltype(b_scale_thread_desc),
                // ...
                    b_scale_grid_desc_bn_ak,
                    // ...

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                // ...
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                // ...
                b_blockwise_copy_up,
                // ...
                b_block_slice_copy_step,
                // ...
                c_scale_thread_desc,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_desc,
                a_scale_thread_copy,
                // ...
                a_scale_thread_slice_copy_step,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_desc,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                // ...
                b_scale_grid_buf_up,
                b_scale_thread_slice_copy_step,
                // ...
                num_k_block_main_loop);
        }
        else
        {
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                // ...
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                // ...
                b_block_slice_copy_step,
                // ...
                c_scale_thread_desc,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_desc,
                a_scale_thread_copy,
                // ...
                a_scale_thread_slice_copy_step,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_desc,
                b_scale_thread_copy,
                // ...
                b_scale_thread_slice_copy_step,
                // ...
                num_k_block_main_loop);
        }
 
 2310             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
 
 2311                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
 
 2314             constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
 2318             constexpr 
auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
 
 2319                 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
 
 2323             constexpr 
auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
 
 2324                 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
 
 2326             constexpr 
auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(
I0);
 
 2327             constexpr 
auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(
I1);
 
 2328             constexpr 
auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(
I2);
 
 2329             constexpr 
auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(
I3);
 
 2330             constexpr 
auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(
I4);
 
 2331             constexpr 
auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(
I5);
 
 2332             constexpr 
auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(
I6);
 
 2333             constexpr 
auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(
I7);
 
 2335             static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
 
 2336             static_assert(M0 * M1 * M2 == MPerBlock);
 
 2337             static_assert(N4 == 4 || N4 == 8);
 
 2344                     if constexpr(MulRoutedWeight)
 
 2346                         const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
 
 2347                         topk_weight         = p_ds_grid[
I0][m_pos];
 
 2352                                 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
 
 2355                             if constexpr(IsInputGemm) 
 
 2359                                     float gate = c_thread_buf[cidx];
 
 2360                                     float up   = c_thread_buf_up[cidx];
 
 2361                                     if constexpr(MulRoutedWeight)
 
 2363                                         gate = gate * topk_weight;
 
 2364                                         up   = up * topk_weight;
 
 2372                                     c_thread_buf(cidx) = gate * up;
 
 2376                                     float gate = c_thread_buf[cidx];
 
 2377                                     float up   = c_thread_buf_up[cidx];
 
 2378                                     if constexpr(MulRoutedWeight)
 
 2380                                         gate = gate * topk_weight;
 
 2381                                         up   = up * topk_weight;
 
 2389                                     c_thread_buf(cidx) = gate * up;
 
 2394                                 if constexpr(MulRoutedWeight)
 
 2396                                     c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
 
 2405             constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
 2408             auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 2409                 static_cast<CShuffleDataType*
>(p_shared),
 
 2410                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 2413                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 2433             const auto c_thread_mtx_on_block =
 
 2434                 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0, 
I0, 
I0, 
I0);
 
 2436             const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
 
 2437             const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
 
 2439             const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
 
 2445             const auto m_thread_data_on_block_idx =
 
 2446                 m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
 
 2449             const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
 
 2455             const auto n_thread_data_on_block_idx =
 
 2456                 n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
 
 2460             auto c_thread_copy_vgpr_to_lds =
 
 2463                                                    decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
 
 2464                                                    decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
 
 2466                                                    Sequence<CShuffleMXdlPerWavePerShuffle,
 
 2467                                                             CShuffleNXdlPerWavePerShuffle,
 
 2480                     c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
 
 2483                                      m_thread_data_on_block_idx[
I1],
 
 2484                                      n_thread_data_on_block_idx[
I1],
 
 2485                                      m_thread_data_on_block_idx[
I2],
 
 2486                                      n_thread_data_on_block_idx[
I2],
 
 2487                                      n_thread_data_on_block_idx[
I3],
 
 2488                                      n_thread_data_on_block_idx[
I4]),
 
 2491             using EDataType = CDataType;
 
 2496             const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
 
 2502                     return make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 2503                         p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
 
 2509                 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
 
 2511                              { 
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
 
 2516                 tie(c_shuffle_block_buf),
 
 2518                              { 
return ds_grid_buf[i]; },
 
 2522             const auto idx_c_ds_block_begin =
 
 2532             const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
 
 2533                 c_grid_desc_mblock_mperblock_nblock_nperblock;
 
 2535             using CDEBlockTransferCluster =
 
 2536                 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
 
 2537             const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
 
 2538             constexpr 
index_t scatter_weight_idx  = IsInputGemm ? 1 : 1; 
 
 2543                    decltype(c_ds_desc_refs),
 
 2544                    decltype(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
 
 2545                    CElementwiseOperation,
 
 2549                             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 2551                             CShuffleNXdlPerWavePerShuffle * 
NWave * NPerXdl>, 
 
 2552                    CDEBlockTransferCluster,
 
 2558                    CDEShuffleBlockTransferScalarPerVectors,
 
 2570                      idx_c_ds_block_begin,
 
 2571                      tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
 2575             auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 2576                 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 2578             constexpr auto sfc_c_vgpr =
 2581                                   Sequence<CShuffleMXdlPerWavePerShuffle,
 2582                                            CShuffleNXdlPerWavePerShuffle,
 2590             constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
 2593             constexpr auto sfc_cde_block =
 2597                                            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 2599                                            CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
 2601             static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
 2602             constexpr auto EMThreads =
 2603                 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
 2604             constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
 2605             constexpr auto ENThreads =
 2606                 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
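EMThreads, EMRepeats, and ENThreads split the shuffle tile among the block's threads along M and N. A worked instance under assumed numbers (the cluster lengths and tile size below are hypothetical, chosen only to make the arithmetic concrete):

    // Hypothetical parameters: CDE cluster <1, 32, 1, 8> over
    // (MBlock, MPerBlock, NBlock, NPerBlock) and a 64-row shuffle tile.
    constexpr int cluster[4] = {1, 32, 1, 8};
    constexpr int shuffle_m  = 64;                      // CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl
    constexpr int em_threads = cluster[0] * cluster[1]; // 32 threads cover M
    constexpr int em_repeats = shuffle_m / em_threads;  // each thread owns 2 rows
    constexpr int en_threads = cluster[2] * cluster[3]; // 8 threads cover N
    static_assert(em_threads == 32 && em_repeats == 2 && en_threads == 8, "");
    // A thread's first row within the block tile is then
    // threadIdx.x / en_threads * em_repeats, as used at line 2614.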
 
 2612                 auto dstidx = sfc_cde_block.GetIndex(access_id);
 2614                     block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
 2616                     const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
 2617                     index_t token_offset      = fused_token & 0xffffff;
 2618                     if constexpr(IsInputGemm)
 2620                         token_offset = token_offset * problem.TopK + (fused_token >> 24);
 2622                     scatter_offsets(m0) = token_offset * problem.N;
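The scatter offsets come from the packed entries of p_sorted_token_ids: the low 24 bits carry the token index and the high 8 bits carry the top-k slot. For the input GEMM each (token, slot) pair owns a distinct output row, hence token * TopK + slot; multiplying by N converts that row to an element offset. A standalone sketch of the same decoding (scatter_row is a hypothetical helper, not part of the kernel):

    #include <cstdint>

    // Mirrors lines 2616-2622: unpack a fused token id into an output row.
    inline int32_t scatter_row(int32_t fused_token, int32_t top_k, bool is_input_gemm)
    {
        const int32_t token = fused_token & 0xffffff; // low 24 bits: token index
        const int32_t slot  = fused_token >> 24;      // high 8 bits: top-k slot
        return is_input_gemm ? token * top_k + slot : token;
    }
    // scatter_offsets(m0) = scatter_row(...) * N, with N elements per output row.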
 
 2628                 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
 2629                                               sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
 2631                                               c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
 2632                                               c_shuffle_block_buf);
 2638                 cde_block_copy_lds_and_global.Run(
 2641                     tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 2645                 if constexpr(access_id < num_access - 1)
 2647                     constexpr auto cde_lds_and_global_step =
 2648                         sfc_cde_block.GetForwardStep(access_id);
 2652                         cde_block_copy_lds_and_global.MoveSrcSliceWindow(
 2653                             c_ds_desc_refs, i + I1, cde_lds_and_global_step);
 2657                     cde_block_copy_lds_and_global.MoveDstSliceWindow(
 2658                         tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 2660                         cde_lds_and_global_step);
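Taken together, the access loop has the following shape. Because the listing above elides several lines, this is a commented sketch of the ordering under that assumption, not the verbatim kernel body:

    // static_for<0, num_access, 1>{}([&](auto access_id) {
    //     // compute scatter_offsets for this access          (lines 2612-2622)
    //     block_sync_lds();                       // previous tile fully consumed
    //     c_thread_copy_vgpr_to_lds.Run(...);     // VGPRs -> LDS (lines 2628-2632)
    //     block_sync_lds();                       // shuffled tile visible to all
    //     cde_block_copy_lds_and_global.Run(...); // LDS -> global, scattered rows
    //     if constexpr(access_id < num_access - 1)
    //     {
    //         // step both slice windows along the space-filling curve
    //         // (MoveSrcSliceWindow / MoveDstSliceWindow, lines 2647-2660)
    //     }
    // });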
 