38 template <
typename GridwiseGemm,
 
   39           bool HasMainKBlockLoop,
 
   44 #if CK_USE_LAUNCH_BOUNDS 
   51     if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
 
   53         __shared__ 
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   55         auto splitk_batch_offset = 
typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
 
   57         GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
 
   58             karg.p_sorted_token_ids,
 
   59             karg.p_sorted_expert_ids,
 
                  // A data pointer, advanced by the per-split-K-batch A offset computed from
                  // blockIdx.z by SplitKBatchOffset.
   61             karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
 
                  // NOTE(review): the A *scale* pointer is advanced by the A *data* offset
                  // (a_k_split_offset). The Run_2Lds launch in the sibling kernel of this file
                  // uses splitk_batch_offset.a_scale_k_split_offset for the same argument.
                  // Confirm which offset is intended here; the two launches disagree.
   62             karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
 
                  // B data pointer, advanced by the per-split-K-batch B offset.
   63             karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
 
                  // NOTE(review): same concern as for the A scale above. The sibling Run_2Lds
                  // launch uses b_scale_k_split_offset, this one uses b_k_split_offset.
   64             karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
 
   78 template <
typename GridwiseGemm,
 
   79           bool HasMainKBlockLoop,
 
   84 #if CK_USE_LAUNCH_BOUNDS 
   91     if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
 
   93         __shared__ 
char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   94         __shared__ 
char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   96         auto splitk_batch_offset = 
typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
 
   98         GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
 
   99             karg.p_sorted_token_ids,
 
  100             karg.p_sorted_expert_ids,
 
  102             karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
 
  103             karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
 
  104             karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
 
  105             karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
 
  120 template <
typename ALayout,
 
  125           typename AScaleDataType,
 
  127           typename BScaleDataType,
 
  128           typename AccDataType,
 
  129           typename CShuffleDataType,
 
  132           typename AElementwiseOperation,
 
  133           typename BElementwiseOperation,
 
  134           typename CElementwiseOperation,
 
  147           typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
  148           typename ABlockTransferThreadClusterArrangeOrder,
 
  149           typename ABlockTransferSrcAccessOrder,
 
  150           index_t ABlockTransferSrcVectorDim,
 
  151           index_t ABlockTransferSrcScalarPerVector,
 
  152           index_t ABlockTransferDstScalarPerVector_AK1,
 
  153           bool AThreadTransferSrcResetCoordinateAfterRun,
 
  155           typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
  156           typename BBlockTransferThreadClusterArrangeOrder,
 
  157           typename BBlockTransferSrcAccessOrder,
 
  158           index_t BBlockTransferSrcVectorDim,
 
  159           index_t BBlockTransferSrcScalarPerVector,
 
  160           index_t BBlockTransferDstScalarPerVector_BK1,
 
  161           bool BThreadTransferSrcResetCoordinateAfterRun,
 
  163           index_t CShuffleMXdlPerWavePerShuffle,
 
  164           index_t CShuffleNXdlPerWavePerShuffle,
 
  165           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
  166           typename CDEShuffleBlockTransferScalarPerVectors,
 
  169           index_t ActivationOperation                 = 0,
 
  170           bool NSwizzle                               = 
false,
 
  171           bool IsInputGemm                            = 
true,
 
  172           bool MulRoutedWeight                        = 
true,
 
  174           typename ComputeTypeA                       = ADataType,
 
  175           typename ComputeTypeB                       = BDataType>
 
  193         CDEShuffleBlockTransferScalarPerVectors{}[
I0];
 
  240                   "A scale pack data type too large!");
 
  242                   "B scale pack data type too large!");
 
  250                 return static_cast<const DDataType*
>(
nullptr);
 
  263         const index_t gridx  = NSwizzle ? nblock * mblock : nblock;
              // With NSwizzle, the full nblock * mblock tile count is placed on grid x and
              // grid y collapses to 1; otherwise x carries nblock and y carries mblock.
 
  264         const index_t gridy  = NSwizzle ? 1 : mblock;
 
  295         auto K_t = K_Batch * KPerBlock;
              // K_t = K_Batch * KPerBlock. The return below rounds K up to a whole number
              // of K_t units (integer ceiling division, assuming positive operands) and
              // multiplies by KPerBlock / AK1Value.
 
  296         return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
 
              // Same rounding as above, scaled by KPerBlock / BK1Value instead.
  301         auto K_t = K_Batch * KPerBlock;
 
  302         return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
 
              // Same rounding, scaled by KPerBlock.
  307         auto K_t = K_Batch * KPerBlock;
 
  308         return (K + K_t - 1) / K_t * KPerBlock;
 
              // Here the rounding unit is K_Batch * KReadVec rather than K_Batch * KPerBlock
              // (KReadVec is defined on a line not visible in this excerpt).
              // NOTE(review): the same ceil-div expression is repeated four times in these
              // helpers; a shared helper would keep them in sync.
  314         auto K_t                = K_Batch * KReadVec;
 
  315         return (K + K_t - 1) / K_t * KReadVec;
 
  328     template <
index_t MNXdlPerWave,
 
  333               typename TileDesc_K0_MN_K1>
 
  376         IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
 
  378         const auto a_grid_desc_mraw_kraw = [&]() {
 
  379             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
 
  383             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
 
  391         if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
 
  392                      GemmSpec == GemmSpecialization::MNKPadding)
 
  395             const auto a_grid_desc_m_k =
 
  409             return a_grid_desc_ak0_m_ak1;
 
  411         else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
 
  412                           GemmSpec == GemmSpecialization::MNPadding)
 
  416                 a_grid_desc_mraw_kraw,
 
  422             return a_grid_desc_ak0_m_ak1;
 
  424         else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
 
  425                           GemmSpec == GemmSpecialization::NKPadding)
 
  429                 a_grid_desc_mraw_kraw,
 
  441             return a_grid_desc_ak0_m_ak1;
 
  447                 a_grid_desc_mraw_kraw,
 
  454                 a_grid_desc_ak0_m_ak1,
 
  462                 a_grid_desc_permuted,
 
  476         constexpr 
index_t MWave           = MPerBlock / (MXdlPerWave * MPerXdl);
 
  477         constexpr 
index_t WaveSize        = BlockSize / (MWave * 
NWave);
 
  486         const auto b_grid_desc_nraw_kraw = [&]() {
 
  500                         GemmSpec != GemmSpecialization::Default),
 
  501                       "pk_i4_t does not support padding");
 
  503                         GemmSpec != GemmSpecialization::Default),
 
  504                       "f4x2_pk_t does not support padding");
 
  506         if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
 
  507                      GemmSpec == GemmSpecialization::MNKPadding)
 
  510             const auto b_grid_desc_n_k =
 
  524             return b_grid_desc_bk0_n_bk1;
 
  526         else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
 
  527                           GemmSpec == GemmSpecialization::MNPadding)
 
  531                 b_grid_desc_nraw_kraw,
 
  537             return b_grid_desc_bk0_n_bk1;
 
  539         else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
 
  540                           GemmSpec == GemmSpecialization::MKPadding)
 
  544                 b_grid_desc_nraw_kraw,
 
  556             return b_grid_desc_bk0_n_bk1;
 
  562                 b_grid_desc_nraw_kraw,
 
  569                 b_grid_desc_bk0_n_bk1,
 
  577                 b_grid_desc_permuted,
 
  589     template <
typename ABlockDesc_AK0_M_AK1>
 
  590     __host__ __device__ 
static constexpr 
auto 
  593         constexpr 
index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
 
  595         return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl, true>(
 
  596             ABlockDesc_AK0_M_AK1{});
 
  599     template <
typename BBlockDesc_BK0_N_BK1>
 
  600     __host__ __device__ 
static constexpr 
auto 
  603         constexpr 
index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
 
  605         return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl, false>(
 
  606             BBlockDesc_BK0_N_BK1{});
 
  609     template <
typename ELayout>
 
  611         IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
 
  613         const auto c_grid_desc_mraw_nraw = [&]() {
 
  632     template <
typename DLayout>
 
  633     __host__ __device__ 
static auto 
  636         const auto c_grid_desc_mraw_nraw = [&]() {
 
  661                 return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
 
  666     template <
typename DsGr
idDesc>
 
  668         const DsGridDesc& ds_grid_desc_m_n, 
index_t MBlock, 
index_t NBlock)
 
  673                     ds_grid_desc_m_n[i], MBlock, NBlock);
 
  689                          std::array<index_t, NumDTensor> StrideDs_,
 
  717             std::cout << 
"problem {" << 
"NumTokens:" << 
NumTokens << 
", " << 
"TopK:" << 
TopK << 
", " 
  718                       << 
"M:" << 
M << 
", " << 
"N:" << 
N << 
", " << 
"K:" << 
K << 
", " 
  722                       << 
", " << 
"KRead:" << 
KRead << 
", " << 
"KP:" << 
KPadded << 
", " 
  723                       << 
"AK0:" << 
AK0 << 
", " << 
"BK0:" << 
BK0 << 
", " << 
"MBlock: " << 
MBlock 
  724                       << 
", " << 
"NBlock: " << 
NBlock << 
"}" << std::endl;
 
  753                           const index_t* p_sorted_expert_ids_,
 
  754                           const index_t* p_max_token_id_,
 
  755                           const ADataType* p_a_grid_,
 
  756                           const AScaleDataType* p_a_scale_grid_,
 
  757                           const BDataType* p_b_grid_,
 
  758                           const BScaleDataType* p_b_scale_grid_,
 
  759                           std::array<const void*, NumDTensor> p_ds_grid_,
 
  760                           CDataType* p_c_grid_,
 
  770                           std::array<index_t, NumDTensor> StrideDs_,
 
  773                           AElementwiseOperation a_element_op_,
 
  774                           BElementwiseOperation b_element_op_,
 
  775                           CElementwiseOperation c_element_op_)
 
  807                 p_ds_grid(i) = 
static_cast<const DDataType_*
>(p_ds_grid_[i]);
 
  830             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
 
  834             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
 
  839             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
 
  843             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
 
  857             if(k_id < karg.
KBatch - 1)
 
  875         constexpr 
index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
 
  876         constexpr 
index_t WaveSize = BlockSize / (MWave * 
NWave);
 
  890             constexpr 
auto a_lds_block_desc =
 
  902             return a_lds_block_desc_permuted;
 
  909             constexpr 
auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
 
  910             constexpr 
auto M1 = MPerBlock / M0;
 
  912             constexpr 
auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
 
  913             constexpr 
auto K0PerThreadWrite = 
AK0Number / KThreadWrite;
 
  914             constexpr 
auto KThreadRead      = WaveSize / MPerXdl;
 
  915             constexpr 
auto K0PerThreadRead  = 
AK0Number / KThreadRead;
 
  917             constexpr 
auto kfold = (
AK1Number * M0 * 
sizeof(ADataType) > 128)
 
  919                                        : 128 / (
AK1Number * M0 * 
sizeof(ADataType));
 
  920             constexpr 
auto KThreadReadPerm =
 
  921                 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
 
  922                     ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
 
  926             constexpr 
auto mpair = (
AK1Number * MPerXdl * 
sizeof(ADataType) > 128)
 
  928                                        : ((128 / (
AK1Number * MPerXdl * 
sizeof(ADataType))) > M0
 
  930                                               : 128 / (
AK1Number * MPerXdl * 
sizeof(ADataType)));
 
  936                            Number<kfold * M0 / mpair>{},
 
  955                 a_lds_block_desc_permuted,
 
  977                 a_lds_block_desc_unmerged,
 
  980                                           Number<KThreadWrite / kfold / KThreadReadPerm>{},
 
  989             return a_lds_block_desc_ak0_m_ak1;
 
 1005         constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
 1007         constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
 1014         return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
 
 1035                                 ABlockTransferSrcScalarPerVector,
 
 1036                                 BBlockTransferSrcScalarPerVector,
 
 1055             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
 
 1058         constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
 1061         constexpr 
auto c_block_size =
 
 1062             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
 
 1064         return math::max(a_block_space_size_aligned * 
sizeof(ADataType),
 
 1065                          c_block_size * 
sizeof(CShuffleDataType));
 
 1073         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
 
 1074                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
 
 1075                       "Invalid tuning param!");
 
 1077         static_assert(KPerBlock % (ScaleBlockSize / 
BPackedSize) == 0,
 
 1078                       "KPerBlock should be multiple of ScaleBlockSize");
 
 1086             if(!(karg.M % MPerBlock == 0))
 
 1090                     std::cout << 
"Arg M value is not a multiple of MPerBlock! M: " << karg.M << 
" " 
 1091                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1104             if(!(karg.N % NPerBlock == 0))
 
 1108                     std::cout << 
"Arg N value is not a multiple of NPerBlock! N: " << karg.N << 
" " 
 1109                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1121             auto K_t = karg.KBatch * KPerBlock;
 
 1122             if(!(karg.K % K_t == 0))
 
 1126                     std::cout << 
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " 
 1127                               << karg.K << 
" " << __FILE__ << 
":" << __LINE__
 
 1128                               << 
", in function: " << __func__ << std::endl;
 
 1136             auto K_t                = karg.KBatch * KReadVec;
 
 1138             if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
 
 1146             if(karg.K % ABlockTransferSrcScalarPerVector != 0)
 
 1150                     std::cout << 
"Arg K (" << karg.K
 
 1151                               << 
") value is not a multiple of ABlockTransferSrcScalarPerVector (" 
 1152                               << ABlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1153                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1160             if(karg.M % ABlockTransferSrcScalarPerVector != 0)
 
 1164                     std::cout << 
"Arg M (" << karg.M
 
 1165                               << 
") value is not a multiple of ABlockTransferSrcScalarPerVector (" 
 1166                               << ABlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1167                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1175             if(karg.N % BBlockTransferSrcScalarPerVector != 0)
 
 1179                     std::cout << 
"Arg N (" << karg.N
 
 1180                               << 
") value is not a multiple of BBlockTransferSrcScalarPerVector (" 
 1181                               << BBlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1182                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1189             if(karg.K % BBlockTransferSrcScalarPerVector != 0)
 
 1193                     std::cout << 
"Arg K (" << karg.K
 
 1194                               << 
") value is not a multiple of BBlockTransferSrcScalarPerVector (" 
 1195                               << BBlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1196                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1208                     std::cout << 
"Arg N (" << karg.N
 
 1209                               << 
") value is not a multiple of " 
 1210                                  "CShuffleBlockTransferScalarPerVector_NPerBlock (" 
 1212                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1224                     std::cout << 
"Arg M (" << karg.M
 
 1225                               << 
") value is not a multiple of " 
 1226                                  "CShuffleBlockTransferScalarPerVector_NPerBlock (" 
 1228                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1238         const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
              // num_k_loop = AK0 divided by (KPerBlock / AK1Value), using integer division.
 
              // Branch taken when the loop count does not exceed the pipeline's
              // PrefetchStages (the branch body is not visible in this excerpt).
 1240         if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
 
              // NOTE(review): K / KPerBlock truncates; any remainder of K modulo KPerBlock is
              // dropped here. Presumably callers pass a K that is already padded or validated
              // to be a multiple of KPerBlock; confirm against the call sites.
 1251         const index_t num_loop = K / KPerBlock;
 
 1253         return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
 
              // Same truncating loop count, used to query the pipeline's tail number.
 1258         const index_t num_loop = K / KPerBlock;
 
 1260         return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
 
 1263     template <
typename CGr
idDesc>
 
 1265         const CGridDesc& c_grid_desc_m_n, 
index_t MBlock, 
index_t NBlock)
 
 1274         return c_grid_desc_mblock_mperblock_nblock_nperblock;
 
 1283     template <
bool HasMainKBlockLoop,
 
 1286     __device__ 
static void Run(
const index_t* p_sorted_token_ids,
 
 1287                                const index_t* p_sorted_expert_ids,
 
 1288                                const index_t* p_max_token_id,
 
 1289                                const ADataType* p_a_grid,
 
 1290                                const AScaleDataType* p_a_scale_grid,
 
 1291                                const BDataType* p_b_grid,
 
 1292                                const BScaleDataType* p_b_scale_grid,
 
 1294                                CDataType* p_c_grid,
 
 1296                                const Problem& problem,
 
 1297                                AElementwiseOperation a_element_op,
 
 1298                                BElementwiseOperation b_element_op,
 
 1299                                CElementwiseOperation c_element_op)
 
 1305             IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
 
 1311         const auto b_grid_desc_bpreshuffled =
 
 1313         const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
 
 1314             IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
 
 1332         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 
 1334                 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
 
 1335         const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
 
 1337         const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
 
 1338         if(expert_block_id * MPerBlock >= max_token_id)
 
 1341             __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
 
 1343         const auto block_mn = [&]() -> std::pair<int, int> {
 
 1344             if constexpr(NSwizzle)
 
 1346                 const index_t ecnt_prefix  = p_max_token_id[1 + expert_id];
 
 1347                 const index_t prefix_block = ecnt_prefix * problem.NBlock;
 
 1348                 const index_t ecnt         = p_max_token_id[2 + expert_id] - ecnt_prefix;
 
 1349                 const index_t expert_swizzle =
 
 1350                     ecnt > 0 ? ecnt : 1; 
 
 1351                 const index_t bid_new = blockIdx.x - prefix_block;
 
 1352                 const index_t nid     = __builtin_amdgcn_readfirstlane(
 
 1353                     bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
 
 1355                     __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
 
 1360                 return {blockIdx.x, blockIdx.y};
 
 1364         const index_t block_n_id = block_mn.first;
 
 1365         const index_t block_m_id = block_mn.second;
 
 1367             __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
 
 1370         constexpr 
auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
 
 1371         constexpr 
auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
 
 1372         constexpr 
auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I2);
 
 1373         constexpr 
auto AKThreads  = AK0Threads * AK1Threads;
 
 1374         constexpr 
auto AMRepeats  = MPerBlock / AMThreads;
 
 1375         const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
 
 1377         if(token_pos >= max_token_id || token0 >= problem.NumTokens)
 
 1379         StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
              // One gather offset per m0 in [0, AMRepeats), filled by the loop below.
 
 1380         static_for<0, AMRepeats, 1>{}([&](
auto m0) {
 
                  // Each entry of p_sorted_token_ids is read as a packed value: the low 24 bits
                  // (mask 0xffffff) are used as the token offset, and the bits from bit 24
                  // upwards are used separately below.
                  // NOTE(review): only token_pos is range-checked in the visible code before
                  // this loop; confirm that token_pos + m0 stays within the sorted-id buffer.
 1381             const index_t fused_token = p_sorted_token_ids[token_pos + m0];
 
 1382             index_t token_offset      = fused_token & 0xffffff;
 
                  // When this is not the input GEMM, the offset becomes
                  // token_offset * TopK + (fused_token >> 24).
 1383             if constexpr(!IsInputGemm)
 
 1385                 token_offset = token_offset * problem.TopK + (fused_token >> 24);
 
                  // The offset is cast to IndexType before being multiplied by problem.K, then
                  // divided by APackedSize.
 1387             gather_offsets(m0) = 
static_cast<IndexType
>(token_offset) * problem.K / 
APackedSize;
 
 1390             __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
 
 1391         const index_t expert_scale_stride =
 
 1392             __builtin_amdgcn_readfirstlane(problem.N * (IsInputGemm ? 2 : 1) *
 
 1396         const index_t n_block_data_idx_on_grid =
 
 1397             __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
 
 1399         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1400             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1401         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1402             p_b_grid + expert_id * expert_stride / 
BPackedSize,
 
 1403             b_grid_desc_bpreshuffled.GetElementSpaceSize());
 
 1406         const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1407             p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
 
 1408         const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1409             p_b_scale_grid + expert_id * expert_scale_stride,
 
 1410             b_scale_grid_desc_bn_ak.GetElementSpaceSize());
 
 1419         auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
 
 1421             AElementwiseOperation,
 
 1424             Sequence<AK0Number, MPerBlock, AK1Number>,
 
 1425             ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
 1426             ABlockTransferThreadClusterArrangeOrder,
 
 1429             decltype(a_grid_desc_ak0_m_ak1),
 
 1430             decltype(a_block_desc_ak0_m_ak1),
 
 1431             ABlockTransferSrcAccessOrder,
 
 1433             ABlockTransferSrcVectorDim,
 
 1435             ABlockTransferSrcScalarPerVector,
 
 1436             ABlockTransferDstScalarPerVector_AK1,
 
 1439             AThreadTransferSrcResetCoordinateAfterRun,
 
 1443             BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
 
 1446                                                 a_block_desc_ak0_m_ak1,
 
 1453         auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
 
 1454             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
 1456         auto b_blockwise_copy =
 
 1457             ThreadwiseTensorSliceTransfer_v2<BDataType,
 
 1459                                              decltype(b_grid_desc_bpreshuffled),
 
 1460                                              decltype(b_block_desc_bk0_n_bk1),
 
 1465                                                       Number<BK1Value>{}>,
 
 1466                                              Sequence<1, 2, 0, 3>,
 
 1468                                              BBlockTransferSrcScalarPerVector,
 
 1469                                              BThreadTransferSrcResetCoordinateAfterRun,
 
 1471                 b_grid_desc_bpreshuffled,
 
 1479         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1481             a_block_desc_ak0_m_ak1.GetElementSpaceSize() / 
APackedSize);
 
 1487         static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
 
 1489         auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
 
 1490         decltype(c_thread_buf) c_thread_buf_up;
 
 1494                                   c_thread_buf.num_of_v_,
 
 1495                                   c_thread_buf.s_per_v,
 
 1499         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
 
 1500             (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
 
 1504         const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
 
 1505         const auto waveId_m = wave_idx[
I0];
 
 1506         const auto waveId_n = wave_idx[
I1];
 
 1508         static constexpr 
auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
 
 1510         auto thread_offset_shuffled =
 
 1513         auto a_thread_offset_m = waveId_m;
 
 1515         auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
 
 1518             decltype(a_scale_grid_desc_am_ak),
 
 1519             decltype(BlockwiseGemmPipe::a_scale_thread_desc),
 
 1525             true>(a_scale_grid_desc_am_ak,
 
 1531         auto b_thread_offset_n = waveId_n;
 
 1533         auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
 
 1536             decltype(b_scale_grid_desc_bn_ak),
 
 1537             decltype(BlockwiseGemmPipe::b_scale_thread_desc),
 
 1543             true>(b_scale_grid_desc_bn_ak,
 
 1548         if constexpr(IsInputGemm)
 
 1550             const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / 
BPackedSize;
 
 1551             const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1552                 p_b_grid_up + expert_id * expert_stride / 
BPackedSize,
 
 1553                 b_grid_desc_bpreshuffled.GetElementSpaceSize());
 
 1554             auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
 
 1557                 decltype(b_grid_desc_bpreshuffled),
 
 1558                 decltype(b_block_desc_bk0_n_bk1),
 
 1559                 Sequence<Number<NXdlPerWave>{}, 
I1, Number<KRepeat>{}, Number<BK1Value>{}>,
 
 1560                 Sequence<1, 2, 0, 3>,
 
 1562                 BBlockTransferSrcScalarPerVector,
 
 1563                 BThreadTransferSrcResetCoordinateAfterRun,
 
 1564                 true>(b_grid_desc_bpreshuffled,
 
 1569             const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
 
 1570             const auto b_scale_grid_buf_up          = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1571                 p_b_scale_grid_up + expert_id * expert_scale_stride,
 
 1572                 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
 
 1573             auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
 
 1576                 decltype(b_scale_grid_desc_bn_ak),
 
 1577                 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
 
 1584                 b_scale_grid_desc_bn_ak,
 
 1589             blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
 
 1590                 a_grid_desc_ak0_m_ak1,
 
 1591                 a_block_desc_ak0_m_ak1,
 
 1595                 a_block_slice_copy_step,
 
 1596                 b_grid_desc_bpreshuffled,
 
 1597                 b_block_desc_bk0_n_bk1,
 
 1599                 b_blockwise_copy_up,
 
 1603                 b_block_slice_copy_step,
 
 1606                 a_scale_grid_desc_am_ak,
 
 1607                 a_scale_thread_copy,
 
 1609                 b_scale_grid_desc_bn_ak,
 
 1610                 b_scale_thread_copy,
 
 1611                 b_scale_thread_copy_up,
 
 1613                 b_scale_grid_buf_up,
 
 1614                 num_k_block_main_loop);
 
 1618             blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
 
 1619                 a_grid_desc_ak0_m_ak1,
 
 1620                 a_block_desc_ak0_m_ak1,
 
 1624                 a_block_slice_copy_step,
 
 1625                 b_grid_desc_bpreshuffled,
 
 1626                 b_block_desc_bk0_n_bk1,
 
 1630                 b_block_slice_copy_step,
 
 1632                 a_scale_grid_desc_am_ak,
 
 1633                 a_scale_thread_copy,
 
 1635                 b_scale_grid_desc_bn_ak,
 
 1636                 b_scale_thread_copy,
 
 1638                 num_k_block_main_loop);
 
 1643             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
 
 1644                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
 
 1648             constexpr 
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
 
 1649                 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
 1653             constexpr 
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
 
 1654                 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
 1656             constexpr 
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
 
 1657             constexpr 
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
 
 1658             constexpr 
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
 
 1659             constexpr 
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
 
 1660             constexpr 
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
 
 1661             constexpr 
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
 
 1662             constexpr 
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
 
 1663             constexpr 
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
 
 1666             static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
 
 1667             static_assert(M4 == 4);
 
 1671             vector_type<float, 4> topk_weights; 
 
 1672             static_for<0, NXdlPerWave, 1>{}([&](
auto n0) {
 
 1673                 static_for<0, MXdlPerWave, 1>{}([&](
auto m0) { 
 
 1674                     static_for<0, M2, 1>{}([&](
auto m2) {      
 
 1675                         const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
 
 1676                                               m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
 
 1677                         if constexpr(MulRoutedWeight)
 
 1679                             topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
 
 1680                                 p_ds_grid[
I2] + m_pos);
 
 1682                         static_for<0, M4, 1>{}([&](
auto m4) { 
 
 1684                                 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
 
 1686                             constexpr 
auto cidx = Number<c_offset>{};
 
 1688                             if constexpr(IsInputGemm) 
 
 1692                                     float gate = c_thread_buf[cidx];
 
 1693                                     float up   = c_thread_buf_up[cidx];
 
 1694                                     if constexpr(MulRoutedWeight)
 
 1696                                         gate = gate * topk_weights.AsType<
float>()[m4];
 
 1697                                         up   = up * topk_weights.AsType<
float>()[m4];
 
 1699                                     tensor_operation::element_wise::Silu{}(gate, gate);
 
 1700                                     c_thread_buf_fp32(cidx) = gate * up;
 
 1704                                     float gate = c_thread_buf[cidx];
 
 1705                                     float up   = c_thread_buf_up[cidx];
 
 1706                                     if constexpr(MulRoutedWeight)
 
 1708                                         gate = gate * topk_weights.AsType<
float>()[m4];
 
 1709                                         up   = up * topk_weights.AsType<
float>()[m4];
 
 1711                                     tensor_operation::element_wise::Gelu{}(gate, gate);
 
 1712                                     c_thread_buf_fp32(cidx) = gate * up;
 
 1717                                 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
 
 1718                                 if constexpr(MulRoutedWeight)
 
 1720                                     c_thread_buf_fp32(cidx) =
 
 1721                                         topk_weights.AsType<
float>()[m4] * c_thread_buf_fp32[cidx];
 
 1729             constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
 1732             auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1733                 static_cast<CShuffleDataType*
>(p_shared),
 
 1734                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 1737                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 1741                         Number<CShuffleMXdlPerWavePerShuffle>{}, 
 
 1748                         Number<CShuffleNXdlPerWavePerShuffle>{}, 
 
 1751                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
 
 1753                     Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
 
 1757             const auto c_thread_mtx_on_block =
 
 1758                 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0, 
I0, 
I0, 
I0);
 
 1760             const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
 
 1761             const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
 
 1763             const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
 
 1769             const auto m_thread_data_on_block_idx =
 
 1770                 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
 
 1773             const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 
 1779             const auto n_thread_data_on_block_idx =
 
 1780                 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
 
 1784             auto c_thread_copy_vgpr_to_lds =
 
 1785                 ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
 
 1787                                                    decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 1788                                                    decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 1790                                                    Sequence<CShuffleMXdlPerWavePerShuffle,
 
 1791                                                             CShuffleNXdlPerWavePerShuffle,
 
 1798                                                    Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
 
 1804                     c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 1807                                      m_thread_data_on_block_idx[
I1],
 
 1808                                      n_thread_data_on_block_idx[
I1],
 
 1809                                      m_thread_data_on_block_idx[
I2],
 
 1810                                      m_thread_data_on_block_idx[
I3],
 
 1811                                      m_thread_data_on_block_idx[
I4],
 
 1812                                      n_thread_data_on_block_idx[
I2]),
 
 1815             using EDataType = CDataType;
 
 1818                 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
 
 1820             const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
 
 1822                     ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
 
 1826                     return make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1827                         p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
 
 1829                 Number<NumDTensor>{});
 
 1833                 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
 
 1835                              { 
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
 
 1836                              Number<NumDTensor>{}));
 
 1840                 tie(c_shuffle_block_buf),
 
 1842                              { 
return ds_grid_buf[i]; },
 
 1843                              Number<NumDTensor>{}));
 
 1846             const auto idx_c_ds_block_begin =
 
 1854                                      Number<NumDTensor>{}));
 
 1856             const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
 
 1857                 c_grid_desc_mblock_mperblock_nblock_nperblock;
 
 1859             using CDEBlockTransferCluster =
 
 1860                 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
 
 1861             const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
 
 1862             constexpr 
index_t scatter_weight_idx  = 1; 
 
 1863             auto cde_block_copy_lds_and_global    = ThreadGroupTensorSliceTransfer_v7r3_scatter<
 
 1867                    decltype(c_ds_desc_refs),
 
 1868                    decltype(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
 
 1869                    CElementwiseOperation,
 
 1870                    Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, 
 
 1873                             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 1875                             CShuffleNXdlPerWavePerShuffle * 
NWave * NPerXdl>, 
 
 1876                    CDEBlockTransferCluster,
 
 1877                    Sequence<0, 1, 2, 3>, 
 
 1878                    Sequence<0, 1, 2, 3>, 
 
 1879                    Sequence<0, 1, 2, 3>, 
 
 1882                    CDEShuffleBlockTransferScalarPerVectors,
 
 1894                      idx_c_ds_block_begin,
 
 1895                      tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
 1899             auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1900                 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 1901             constexpr 
auto sfc_c_vgpr =
 
 1902                 SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
 
 1903                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
 
 1904                                   Sequence<CShuffleMXdlPerWavePerShuffle,
 
 1905                                            CShuffleNXdlPerWavePerShuffle,
 
 1913             constexpr 
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
 
 1916             constexpr 
auto sfc_cde_block =
 
 1917                 SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
 
 1918                                   Sequence<0, 2, 1, 3>,
 
 1920                                            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 1922                                            CShuffleNXdlPerWavePerShuffle * 
NWave * NPerXdl>>{};
 
 1924             static_assert(num_access == sfc_cde_block.GetNumOfAccess(), 
"wrong!");
 
 1925             constexpr 
auto EMThreads =
 
 1926                 CDEBlockTransferCluster{}.At(
I0) * CDEBlockTransferCluster{}.At(
I1);
 
 1927             constexpr 
auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
 
 1928             constexpr 
auto ENThreads =
 
 1929                 CDEBlockTransferCluster{}.At(
I2) * CDEBlockTransferCluster{}.At(
I3);
 
 1930             static_for<0, num_access, 1>{}([&](
auto access_id) {
 
 1932                 StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
 
 1934                 auto dstidx = sfc_cde_block.GetIndex(access_id);
 
 1936                     block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(
I1);
 
 1937                 static_for<0, EMRepeats, 1>{}([&](
auto m0) {
 
 1938                     const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
 
 1939                     IndexType token_offset    = fused_token & 0xffffff;
 
 1940                     if constexpr(IsInputGemm)
 
 1942                         token_offset = token_offset * problem.TopK + (fused_token >> 24);
 
 1944                     scatter_offsets(m0) = 
static_cast<IndexType
>(token_offset) * problem.N;
 
 1950                 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 1951                                               sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
 
 1953                                               c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 1954                                               c_shuffle_block_buf);
 
 1960                 cde_block_copy_lds_and_global.Run(
 
 1963                     tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
 1967                 if constexpr(access_id < num_access - 1)
 
 1969                     constexpr 
auto cde_lds_and_global_step =
 
 1970                         sfc_cde_block.GetForwardStep(access_id);
 
 1973                     static_for<0, NumDTensor, 1>{}([&](
auto i) {
 
 1974                         cde_block_copy_lds_and_global.MoveSrcSliceWindow(
 
 1975                             c_ds_desc_refs, i + 
I1, cde_lds_and_global_step);
 
 1979                     cde_block_copy_lds_and_global.MoveDstSliceWindow(
 
 1980                         tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
 1982                         cde_lds_and_global_step);
 
 // Device-side body of a templated MoE GEMM routine that works with two LDS buffers
 // (p_shared_0 / p_shared_1). Presumably this is GridwiseGemm::Run_2Lds, which the kernels
 // at the top of the file call with those two buffers -- the line carrying the function
 // name is not present in this listing, so confirm against the real header.
 // NOTE(review): this listing embeds the original file's line numbers and omits many
 // lines (braces, comments, some declarations). The comments below describe only what the
 // visible fragments show; anything depending on an omitted line is marked as an assumption.
 // Visible phases: (1) block -> expert/tile mapping, (2) per-row gather offsets for A,
 // (3) global/LDS/VGPR buffer setup, (4) blockwise GEMM pipeline call, (5) epilogue that
 // combines accumulators and scatters C rows through LDS to global memory.
 1989     template <
bool HasMainKBlockLoop,
 
 1993                                     const index_t* p_sorted_expert_ids,
 
 1994                                     const index_t* p_max_token_id,
 
 1995                                     const ADataType* p_a_grid,
 
 1996                                     const AScaleDataType* p_a_scale_grid,
 
 1997                                     const BDataType* p_b_grid,
 
 1998                                     const BScaleDataType* p_b_scale_grid,
 
 2000                                     CDataType* p_c_grid,
 
 2004                                     AElementwiseOperation a_element_op,
 
 2005                                     BElementwiseOperation b_element_op,
 
 2006                                     CElementwiseOperation c_element_op)
 
 2019         const auto b_grid_desc_bpreshuffled =
 
 2021         const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
 
 2029         const auto Padded_Scale_M =
 
 2053         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 
              // p_max_token_id[0] is read through readfirstlane, i.e. the value of the first
              // active lane is used for the whole wavefront.
 2057         const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
 
              // With NSwizzle the M-block index is derived from blockIdx.x / NBlock, otherwise
              // it is blockIdx.y.
 2058         const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.
NBlock : blockIdx.y;
 
              // Guard for blocks whose first row lies at or beyond max_token_id; the guarded
              // statement is not visible in this listing (presumably an early return).
 2059         if(expert_block_id * MPerBlock >= max_token_id)
 
 2062             __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
 
              // block_mn yields a pair that is unpacked below as (block_n_id, block_m_id).
 2063         const auto block_mn = [&]() -> std::pair<int, int> {
 
 2064             if constexpr(NSwizzle)
 
                  // p_max_token_id[1 + expert_id] and p_max_token_id[2 + expert_id] are read as
                  // two consecutive prefix values; their difference is the count used for
                  // this expert.
 2066                 const index_t ecnt_prefix  = p_max_token_id[1 + expert_id];
 
 2068                 const index_t ecnt         = p_max_token_id[2 + expert_id] - ecnt_prefix;
 
                  // Clamp to at least 1: expert_swizzle is used as a divisor/modulus below.
 2069                 const index_t expert_swizzle =
 
 2070                     ecnt > 0 ? ecnt : 1; 
 
 2071                 const index_t bid_new = blockIdx.x - prefix_block;
 
                  // Remap the linear block id in groups of 8 along N.
 2072                 const index_t nid     = __builtin_amdgcn_readfirstlane(
 
 2073                     bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
 
 2075                     __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
 
                  // Non-swizzled path: the pair is taken directly from the block indices.
 2080                 return {blockIdx.x, blockIdx.y};
 
 2084         const index_t block_n_id = block_mn.first;
 
 2085         const index_t block_m_id = block_mn.second;
 
              // Low 24 bits of the first sorted-token entry of this M-block; presumably stored
              // in `token0`, which is tested further down (the declaration line is omitted).
 2087             __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
 
              // Thread-cluster shape of the A block transfer: index 0 -> AK0, 1 -> M, 2 -> AK1.
 2090         constexpr 
auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
 
 2091         constexpr 
auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
 
 2092         constexpr 
auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I2);
 
 2093         constexpr 
auto AKThreads  = AK0Threads * AK1Threads;
 
 2094         constexpr 
auto AMRepeats  = MPerBlock / AMThreads;
 
              // First row of the sorted-token list handled by this thread.
 2095         const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads;
 
              // The statement guarded by this check is not visible in this listing.
 2097         if(token_pos >= max_token_id || token0 >= problem.
NumTokens)
 
              // Decode a fused entry: the low 24 bits give the token index; when this is not
              // the input GEMM the index becomes token * TopK + (entry >> 24). The result is
              // scaled by K to form the gather offset for row m0.
 2101             const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
 
 2102             index_t token_offset      = fused_token & 0xffffff;
 
 2103             if constexpr(!IsInputGemm)
 
 2105                 token_offset = token_offset * problem.
TopK + (fused_token >> 24);
 
 2107             gather_offsets(m0) = 
static_cast<IndexType
>(token_offset) * problem.
K;
 
              // N * K elements per expert, doubled for the input GEMM; presumably assigned to
              // `expert_stride`, which is used below (the declaration line is omitted).
 2111             __builtin_amdgcn_readfirstlane(problem.
N * problem.
K * (IsInputGemm ? 2 : 1));
 
 2112         const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
 
 2113             problem.
N * (IsInputGemm ? 2 : 1) *
 
 2117         const index_t n_block_data_idx_on_grid =
 
 2118             __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / 
NXdlPack);
 
              // Global-memory views. A is used from its base pointer; B is offset by
              // expert_id * expert_stride.
 2121         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 2122             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
 
 2123         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 2124             p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
 
              // Scale views. The B-scale offset is expert_id * expert_scale_stride divided by
              // sizeof(BScaleDataType).
 2127         const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 2128             p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
 
 2129         const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 2130             p_b_scale_grid + (expert_id * expert_scale_stride) / 
sizeof(BScaleDataType),
 
 2131             b_scale_grid_desc_bn_ak.GetElementSpaceSize());
 
 2143             ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
 2144             ABlockTransferThreadClusterArrangeOrder,
 
 2147             decltype(a_grid_desc_ak0_m_ak1),
 
 2148             decltype(a_block_desc_ak0_m_ak1),
 
 2149             ABlockTransferSrcAccessOrder,
 
 2150             ABlockTransferSrcVectorDim,
 
 2152             ABlockTransferSrcScalarPerVector,
 
 2154             1>(a_grid_desc_ak0_m_ak1,
 
 2156                a_block_desc_ak0_m_ak1,
 
              // B tiles are held in two static buffers in the Vgpr address space (ping/pong).
 2162         auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
 
 2163             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
 2164         auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
 
 2165             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
 2166         auto b_block_bufs = 
make_tuple(b_block_buf_ping, b_block_buf_pong);
 
 2168         auto b_blockwise_copy =
 
 2171                                              decltype(b_grid_desc_bpreshuffled),
 
 2172                                              decltype(b_block_desc_bk0_n_bk1),
 
 2180                                              BBlockTransferSrcScalarPerVector,
 
 2181                                              BThreadTransferSrcResetCoordinateAfterRun,
 
 2183                 b_grid_desc_bpreshuffled,
 
              // A tiles use two LDS buffers: ping over p_shared_0, pong over p_shared_1.
 2192         auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 2193             static_cast<ADataType*
>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
 2194         auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 2195             static_cast<ADataType*
>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
 2196         auto a_block_bufs = 
make_tuple(a_block_buf_ping, a_block_buf_pong);
 
 2202         static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
 
              // Two accumulators of identical type: c_thread_buf and c_thread_buf_up. The
              // epilogue reads them as `gate` and `up` respectively.
 2204         auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
 
 2205         decltype(c_thread_buf) c_thread_buf_up;
 
 2209                                   c_thread_buf.num_of_v_,
 
 2210                                   c_thread_buf.s_per_v,
 
              // Main-loop trip count: (AK0 length * AK1 length) divided by a value on a line
              // that is omitted here (presumably the per-block K extent -- TODO confirm).
 2214         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
 
 2215             (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
 
 2219         const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
 
 2220         const auto waveId_m = wave_idx[
I0];
 
 2221         const auto waveId_n = wave_idx[
I1];
 
 2223         auto thread_offset_shuffled =
 
 2226         auto a_thread_offset_m = waveId_m;
 
 2229         const index_t token_scale_pos = block_m_id * MPerBlock;
 
              // Same form of check as for token_pos above; the guarded statement is omitted.
 2230         if(token_scale_pos >= max_token_id || token0 >= problem.
NumTokens)
 
 2236             decltype(a_scale_grid_desc_am_ak),
 
 2237             decltype(BlockwiseGemmPipe::a_scale_thread_desc),
 
 2243             true>(a_scale_grid_desc_am_ak,
 
 2249         auto b_thread_offset_n = waveId_n;
 
 2254             decltype(b_scale_grid_desc_bn_ak),
 
 2255             decltype(BlockwiseGemmPipe::b_scale_thread_desc),
 
 2261             true>(b_scale_grid_desc_bn_ak,
 
              // Input-GEMM path: a second ("up") B view starts expert_stride / 2 elements
              // after p_b_grid, with its own copy object and a matching "up" scale view, and
              // the pipeline is run with both the regular and the *_up operands.
 2266         if constexpr(IsInputGemm)
 
 2268             const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
 
 2269             const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 2270                 p_b_grid_up + expert_id * expert_stride,
 
 2271                 b_grid_desc_bpreshuffled.GetElementSpaceSize());
 
 2272             auto b_blockwise_copy_up =
 
 2275                                                  decltype(b_grid_desc_bpreshuffled),
 
 2276                                                  decltype(b_block_desc_bk0_n_bk1),
 
 2284                                                  BBlockTransferSrcScalarPerVector,
 
 2285                                                  BThreadTransferSrcResetCoordinateAfterRun,
 
 2287                     b_grid_desc_bpreshuffled,
 
 2293             const BScaleDataType* p_b_scale_grid_up =
 
 2294                 p_b_scale_grid + expert_scale_stride / 2 / 
sizeof(BScaleDataType);
 
 2295             const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 2296                 p_b_scale_grid_up + expert_id * expert_scale_stride / 
sizeof(BScaleDataType),
 
 2297                 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
 
 2302                 decltype(b_scale_grid_desc_bn_ak),
 
 2303                 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
 
 2310                 b_scale_grid_desc_bn_ak,
 
 2315             blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
 
 2317                 a_grid_desc_ak0_m_ak1,
 
 2318                 a_block_desc_ak0_m_ak1,
 
 2322                 a_block_slice_copy_step,
 
 2324                 b_grid_desc_bpreshuffled,
 
 2325                 b_block_desc_bk0_n_bk1,
 
 2327                 b_blockwise_copy_up,
 
 2331                 b_block_slice_copy_step,
 
 2336                 a_scale_grid_desc_am_ak,
 
 2337                 a_scale_thread_copy,
 
 2340                 b_scale_grid_desc_bn_ak,
 
 2341                 b_scale_thread_copy,
 
 2342                 b_scale_thread_copy_up,
 
 2344                 b_scale_grid_buf_up,
 
 2345                 num_k_block_main_loop);
 
              // Second pipeline call, without any *_up operands (presumably the else branch
              // of the IsInputGemm test; the branch line itself is omitted).
 2349             blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
 
 2350                 a_grid_desc_ak0_m_ak1, 
 
 2351                 a_block_desc_ak0_m_ak1,
 
 2355                 a_block_slice_copy_step,
 
 2356                 b_grid_desc_bpreshuffled, 
 
 2357                 b_block_desc_bk0_n_bk1,
 
 2361                 b_block_slice_copy_step,
 
 2363                 a_scale_grid_desc_am_ak, 
 
 2364                 a_scale_thread_copy,
 
 2366                 b_scale_grid_desc_bn_ak, 
 
 2367                 b_scale_thread_copy,
 
 2369                 num_k_block_main_loop);
 
              // Epilogue preconditions: the per-wave XDL repeat counts must be multiples of
              // the per-shuffle counts, which in turn must be multiples of MXdlPack/NXdlPack.
 2374             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
 
 2375                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
 
 2377             static_assert(CShuffleMXdlPerWavePerShuffle % 
MXdlPack == 0 &&
 
 2378                               CShuffleNXdlPerWavePerShuffle % 
NXdlPack == 0,
 
 2381             constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
              // NOTE: the variable names keep an 8-dimension suffix (m0_n0_m1_n1_m2_m3_m4_n2)
              // but the getters called here are the 10-dimension
              // M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 variants, and lengths I0..I9 are read below.
 2384             constexpr 
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
 
 2385                 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
 
 2389             constexpr 
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
 
 2390                 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
 
 2392             constexpr 
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
 
 2393             constexpr 
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
 
 2394             constexpr 
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
 
 2395             constexpr 
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
 
 2396             constexpr 
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
 
 2397             constexpr 
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
 
 2398             constexpr 
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
 
 2399             constexpr 
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
 
 2400             constexpr 
auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I8);
 
 2401             constexpr 
auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I9);
 
              // The six M sub-dimensions must tile MPerBlock exactly, and the innermost one
              // must be 4 (it is the width of the routed-weight vector load below).
 2405             static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
 
 2406             static_assert(M5 == 4);
 
                                  // Row position of this group of M5 accumulator rows within
                                  // the sorted-token list.
 2416                                 const index_t m_pos = block_m_id * MPerBlock +
 
 2417                                                       m0 * M2 * M1 * M3 * M4 * M5 +
 
 2418                                                       m1 * M2 * M3 * M4 * M5 +
 
 2419                                                       imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
 
                                  // With MulRoutedWeight, M5 consecutive floats are loaded as
                                  // one vector from p_ds_grid[I2] at m_pos.
 2420                                 if constexpr(MulRoutedWeight)
 
 2423                                         *c_style_pointer_cast<const vector_type<float, M5>*>(
 
 2424                                             p_ds_grid[
I2] + m_pos);
 
 2428                                         blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
 
 2429                                             make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
 
                                      // Input GEMM: both accumulators are optionally scaled by
                                      // the routed weight and their product gate * up is stored.
                                      // The activation applied to `gate` sits on lines omitted
                                      // from this listing (selected via ActivationOperation).
 2432                                     if constexpr(IsInputGemm) 
 
 2434                                         if constexpr(ActivationOperation ==
 
 2437                                             float gate = c_thread_buf[cidx];
 
 2438                                             float up   = c_thread_buf_up[cidx];
 
 2439                                             if constexpr(MulRoutedWeight)
 
 2441                                                 gate = gate * topk_weights.AsType<
float>()[m5];
 
 2442                                                 up   = up * topk_weights.AsType<
float>()[m5];
 
 2445                                             c_thread_buf_fp32(cidx) = gate * up;
 
 2449                                             float gate = c_thread_buf[cidx];
 
 2450                                             float up   = c_thread_buf_up[cidx];
 
 2451                                             if constexpr(MulRoutedWeight)
 
 2453                                                 gate = gate * topk_weights.AsType<
float>()[m5];
 
 2454                                                 up   = up * topk_weights.AsType<
float>()[m5];
 
 2457                                             c_thread_buf_fp32(cidx) = gate * up;
 
                                          // Otherwise the accumulator is copied as-is and, with
                                          // MulRoutedWeight, multiplied by the routed weight.
 2462                                         c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
 
 2463                                         if constexpr(MulRoutedWeight)
 
 2465                                             c_thread_buf_fp32(cidx) =
 
 2466                                                 topk_weights.AsType<
float>()[m5] *
 
 2467                                                 c_thread_buf_fp32[cidx];
 
 2477             constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
              // The C-shuffle staging buffer is built over p_shared_0, the same LDS pointer
              // that backs a_block_buf_ping above.
              // NOTE(review): any barrier separating the two uses is not visible here.
 2480             auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 2481                 static_cast<CShuffleDataType*
>(p_shared_0),
 
 2482                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 2485                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
              // Origin of this thread's accumulator tile within the block, split into M and N.
 2511             const auto c_thread_mtx_on_block =
 
 2512                 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0, 
I0, 
I0, 
I0);
 
 2514             const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
 
 2515             const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
 
 2517             const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
 
 2523             const auto m_thread_data_on_block_idx =
 
 2524                 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
 
 2527             const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 
 2533             const auto n_thread_data_on_block_idx =
 
 2534                 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
 
 2541                 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 2542                 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 2545                          CShuffleNXdlPerWavePerShuffle / 
NXdlPack,
 
 2554                 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
 
 2559                 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 2562                                        m_thread_data_on_block_idx[
I1],
 
 2563                                        n_thread_data_on_block_idx[
I1],
 
 2564                                        m_thread_data_on_block_idx[
I2],
 
 2565                                        n_thread_data_on_block_idx[
I2],
 
 2566                                        m_thread_data_on_block_idx[
I3],
 
 2567                                        m_thread_data_on_block_idx[
I4],
 
 2568                                        m_thread_data_on_block_idx[
I5],
 
 2569                                        n_thread_data_on_block_idx[
I3]),
 
 2572             using EDataType = CDataType;
 
 2577             const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
 
              // One global-memory view per D tensor, sized from its M x N descriptor.
 2583                     return make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 2584                         p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
 
              // Source tuples for the fused copy: the C-shuffle LDS descriptor/buffer comes
              // first, followed by the D descriptors/buffers.
 2590                 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
 
 2592                              { 
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
 
 2597                 tie(c_shuffle_block_buf),
 
 2599                              { 
return ds_grid_buf[i]; },
 
 2603             const auto idx_c_ds_block_begin =
 
              // E (the output) uses the C grid descriptor unchanged.
 2613             const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
 
 2614                 c_grid_desc_mblock_mperblock_nblock_nperblock;
 
 2616             using CDEBlockTransferCluster =
 
 2617                 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
 
 2618             const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
 
              // NOTE(review): the meaning of this index is defined by the scatter transfer
              // class, which is not visible here -- confirm before changing the value.
 2619             constexpr 
index_t scatter_weight_idx  = 3; 
 
 2624                    decltype(c_ds_desc_refs),
 
 2625                    decltype(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
 
 2626                    CElementwiseOperation,
 
 2631                             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 2633                             CShuffleNXdlPerWavePerShuffle * 
NWave * NPerXdl>, 
 
 2634                    CDEBlockTransferCluster,
 
 2640                    CDEShuffleBlockTransferScalarPerVectors,
 
 2652                      idx_c_ds_block_begin,
 
 2653                      tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
 2657             auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 2658                 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
              // Two space-filling curves, one over the thread's accumulators and one over the
              // block's output tile; they must have the same number of steps (asserted below).
 2660             constexpr 
auto sfc_c_vgpr =
 
 2671                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
 
 2673                                            CShuffleNXdlPerWavePerShuffle / 
NXdlPack,
 
 2683             constexpr 
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
 
 2686             constexpr 
auto sfc_cde_block =
 
 2690                                            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 2692                                            CShuffleNXdlPerWavePerShuffle * 
NWave * NPerXdl>>{};
 
 2694             static_assert(num_access == sfc_cde_block.GetNumOfAccess(), 
"wrong!");
 
              // Output-copy thread layout: EMThreads = cluster lengths 0 * 1,
              // ENThreads = cluster lengths 2 * 3; each thread covers EMRepeats rows per step.
 2695             constexpr 
auto EMThreads =
 
 2696                 CDEBlockTransferCluster{}.At(
I0) * CDEBlockTransferCluster{}.At(
I1);
 
 2697             constexpr 
auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
 
 2698             constexpr 
auto ENThreads =
 
 2699                 CDEBlockTransferCluster{}.At(
I2) * CDEBlockTransferCluster{}.At(
I3);
 
 2704                 auto dstidx = sfc_cde_block.GetIndex(access_id);
 
 2706                     block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(
I1);
 
                      // Decode the fused entry for each output row: the low 24 bits give the
                      // token index; for the input GEMM the index becomes
                      // token * TopK + (entry >> 24). The result is scaled by N to form the
                      // scatter offset. (The !IsInputGemm case applies the same remap to the
                      // gather offsets earlier.)
 2708                     const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
 
 2709                     IndexType token_offset    = fused_token & 0xffffff;
 
 2710                     if constexpr(IsInputGemm)
 
 2712                         token_offset = token_offset * problem.
TopK + (fused_token >> 24);
 
 2714                     scatter_offsets(m0) = 
static_cast<IndexType
>(token_offset) * problem.
N;
 
                  // Step 1 of each access: this thread's accumulators are written into the
                  // C-shuffle LDS buffer.
 2720                 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 2721                                               sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
 
 2723                                               c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 2724                                               c_shuffle_block_buf);
 
                  // Step 2: the block-level copy from the C/D sources to the E grid descriptor.
 2730                 cde_block_copy_lds_and_global.Run(
 
 2733                     tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
                  // Advance the D source windows and the destination window, except after the
                  // last access.
 2737                 if constexpr(access_id < num_access - 1)
 
 2739                     constexpr 
auto cde_lds_and_global_step =
 
 2740                         sfc_cde_block.GetForwardStep(access_id);
 
 2744                         cde_block_copy_lds_and_global.MoveSrcSliceWindow(
 
 2745                             c_ds_desc_refs, i + 
I1, cde_lds_and_global_step);
 
 2749                     cde_block_copy_lds_and_global.MoveDstSliceWindow(
 
 2750                         tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
 2752                         cde_lds_and_global_step);
 
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:30
 
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition: device_base.hpp:178
 
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
 
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
 
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
GemmSpecialization
Definition: gemm_specialization.hpp:11
 
ck_tile::element_wise::PassThrough PassThrough
Definition: grouped_convolution_utils.hpp:47
 
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
__device__ index_t get_warp_local_1d_id()
Definition: get_id.hpp:45
 
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:34
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm_bns.hpp:48
 
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:928
 
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
 
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
 
InMemoryDataOperationEnum
Definition: ck.hpp:277
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
Definition: blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp:36
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
 
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
 
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm.hpp:90
 
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
 
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
 
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
 
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:185
 
Activation
Definition: gridwise_moe_gemm.hpp:31
 
@ silu_and_mul
Definition: gridwise_moe_gemm.hpp:33
 
@ gelu_and_mul
Definition: gridwise_moe_gemm.hpp:32
 
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
 
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
 
__host__ constexpr __device__ auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:320
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
__host__ constexpr __device__ auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition: tuple_helper.hpp:42
 
constexpr bool is_same_v
Definition: type.hpp:283
 
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition: sequence.hpp:925
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
 
int32_t index_t
Definition: ck.hpp:299
 
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
 
integral_constant< index_t, N > Number
Definition: number.hpp:12
 
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:751
 
const index_t * p_sorted_expert_ids
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:812
 
const index_t * p_sorted_token_ids
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:811
 
const BScaleDataType * p_b_scale_grid
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:817
 
const BElementwiseOperation b_element_op
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:822
 
const BDataType * p_b_grid
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:816
 
const index_t * p_max_token_id
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:813
 
DsGridPointer p_ds_grid
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:818
 
const AElementwiseOperation a_element_op
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:821
 
CDataType * p_c_grid
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:819
 
const AScaleDataType * p_a_scale_grid
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:815
 
__host__ Argument(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *p_a_grid_, const AScaleDataType *p_a_scale_grid_, const BDataType *p_b_grid_, const BScaleDataType *p_b_scale_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:752
 
const CElementwiseOperation c_element_op
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:823
 
const ADataType * p_a_grid
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:814
 
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:679
 
index_t AK0
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:743
 
index_t NPadded
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:740
 
index_t KBatch
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:738
 
index_t BK0
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:744
 
index_t TopK
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:728
 
__host__ Problem(index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:680
 
index_t KRead
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:741
 
index_t K
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:731
 
index_t MPadded
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:739
 
index_t StrideC
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:737
 
index_t StrideScaleB
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:735
 
index_t NumTokens
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:727
 
std::array< index_t, NumDTensor > StrideDs
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:736
 
index_t MBlock
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:745
 
index_t StrideScaleA
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:733
 
index_t N
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:730
 
__host__ void Print() const
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:715
 
index_t StrideA
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:732
 
index_t StrideB
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:734
 
index_t M
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:729
 
index_t KPadded
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:742
 
index_t NBlock
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:746
 
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:827
 
index_t a_scale_k_split_offset
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:869
 
index_t b_scale_k_split_offset
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:870
 
index_t a_k_split_offset
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:867
 
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:828
 
index_t b_k_split_offset
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:868
 
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:177
 
static constexpr auto I6
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:187
 
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:305
 
static constexpr auto AK1Number
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:197
 
static constexpr auto AK0Number
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:195
 
static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:1047
 
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:255
 
static constexpr auto I1
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:182
 
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:288
 
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:259
 
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:655
 
static __device__ void Run_2Lds(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:1992
 
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:1264
 
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:667
 
static constexpr index_t NLane
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:228
 
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:993
 
static constexpr index_t SortedTileSize
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:234
 
static constexpr index_t scale_pack_size_b
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:238
 
static constexpr auto BK1Number
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:198
 
remove_cvref_t< decltype(BlockGemmMXBPreshufflePipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ScaleBlockSize, ADataType, AScaleDataType, BDataType, BScaleDataType, ComputeTypeA, AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, IsInputGemm >())> BlockwiseGemmPipe
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:1045
 
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:1256
 
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:274
 
static constexpr auto MakeDsGridPointer()
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:244
 
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:375
 
__host__ static __device__ auto CalculateBK0Shuffled(index_t K)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:283
 
static constexpr auto BK0Number
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:196
 
static constexpr auto I9
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:190
 
static constexpr auto I4
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:185
 
static constexpr auto I3
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:184
 
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:311
 
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:293
 
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:269
 
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:192
 
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:318
 
static constexpr index_t KRepeat
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:231
 
static constexpr auto NXdlPack
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:207
 
static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:1003
 
static constexpr auto KXdlPack
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:208
 
BDataType LDSTypeB
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:179
 
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:483
 
static constexpr auto I0
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:181
 
static constexpr auto I5
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:186
 
static constexpr index_t scale_pack_size_a
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:237
 
static constexpr auto lcm_AK1_BK1
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:200
 
static constexpr index_t KLane
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:229
 
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:873
 
static constexpr auto MXdlPack
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:206
 
static constexpr auto I8
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:189
 
__host__ static __device__ auto CalculateBN0Shuffled(index_t N)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:279
 
__host__ static constexpr __device__ auto MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1 &)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:601
 
__host__ static __device__ auto MakeCGridDescriptor_M_N(IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:610
 
static constexpr bool is_single_rate_mfma
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:201
 
static constexpr auto I2
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:183
 
__host__ static constexpr __device__ auto MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1 &)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:591
 
static constexpr index_t APackedSize
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:216
 
static constexpr index_t NumDTensor
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:204
 
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:323
 
static constexpr index_t BPackedSize
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:217
 
static constexpr index_t KPack
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:225
 
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:1249
 
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:1071
 
ADataType LDSTypeA
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:178
 
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:299
 
__host__ static __device__ auto MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:634
 
__host__ static constexpr __device__ auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:334
 
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:257
 
static constexpr auto I7
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:188
 
__host__ static __device__ auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:474
 
static constexpr index_t NWave
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:230
 
static constexpr auto is_scale_mfma
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:202
 
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition: xdlops_gemm.hpp:1208
 
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1757
 
Definition: sequence.hpp:43
 
Definition: tensor_space_filling_curve.hpp:20
 
Definition: static_buffer.hpp:75
 
Definition: thread_group_tensor_slice_transfer_gather_direct_load.hpp:57
 
Definition: thread_group_tensor_slice_transfer_v7r3_scatter.hpp:51
 
Definition: threadwise_tensor_slice_transfer.hpp:39
 
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition: threadwise_tensor_slice_transfer.hpp:234
 
Definition: tuple.hpp:117
 
Unsigned representation of a conventional biased Float32 exponent.
Definition: e8m0.hpp:26
 
Definition: data_type.hpp:42
 
Definition: integral_constant.hpp:20
 
Definition: data_type.hpp:187
 
Definition: functional2.hpp:33
 
Definition: device_base.hpp:197
 
Definition: unary_element_wise_operation.hpp:1041
 
Definition: unary_element_wise_operation.hpp:340
 
Definition: unary_element_wise_operation.hpp:1087
 
Definition: dtype_vector.hpp:10
 
#define CK_ENV(name)
Definition: env.hpp:129