36 template <typename GridwiseGemm,
37 bool HasMainKBlockLoop,
42 #if CK_USE_LAUNCH_BOUNDS
49 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
51 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
53 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
54 karg.p_sorted_token_ids,
55 karg.p_sorted_expert_ids,
57 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
58 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
71 template <typename GridwiseGemm,
72 bool HasMainKBlockLoop,
77 #if CK_USE_LAUNCH_BOUNDS
84 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
85 __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
87 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
89 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
90 karg.p_sorted_token_ids,
91 karg.p_sorted_expert_ids,
93 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
94 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
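Both kernel entry points above follow the same pattern: each split-K slice is selected by blockIdx.z, turned into element offsets by GridwiseGemm::SplitKBatchOffset, and applied to the A/B base pointers before GridwiseGemm::Run / Run_2Lds is invoked. The following host-side sketch only illustrates that offset arithmetic under simplifying assumptions (row-major A, a made-up KRead); the real SplitKBatchOffset constructor, partially visible around lines 722-741, also accounts for layouts, strides and trimming of the final batch.

// Illustrative sketch, not the member implementation from this header.
#include <cstdint>
#include <iostream>

int main()
{
    const int32_t KRead  = 512; // hypothetical per-slice K extent (karg.KRead)
    const int32_t KBatch = 4;   // hypothetical karg.KBatch

    for(int32_t k_id = 0; k_id < KBatch; ++k_id)
    {
        // Row-major A: the k_id-th slice starts KRead columns further along K.
        const int32_t a_k_split_offset = k_id * KRead;
        std::cout << "blockIdx.z = " << k_id << " -> A element offset " << a_k_split_offset << '\n';
    }
    return 0;
}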
108 template <typename ALayout,
114 typename AccDataType,
115 typename CShuffleDataType,
118 typename AElementwiseOperation,
119 typename BElementwiseOperation,
120 typename CElementwiseOperation,
132 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
133 typename ABlockTransferThreadClusterArrangeOrder,
134 typename ABlockTransferSrcAccessOrder,
135 index_t ABlockTransferSrcVectorDim,
136 index_t ABlockTransferSrcScalarPerVector,
137 index_t ABlockTransferDstScalarPerVector_AK1,
138 bool AThreadTransferSrcResetCoordinateAfterRun,
140 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
141 typename BBlockTransferThreadClusterArrangeOrder,
142 typename BBlockTransferSrcAccessOrder,
143 index_t BBlockTransferSrcVectorDim,
144 index_t BBlockTransferSrcScalarPerVector,
145 index_t BBlockTransferDstScalarPerVector_BK1,
146 bool BThreadTransferSrcResetCoordinateAfterRun,
148 index_t CShuffleMXdlPerWavePerShuffle,
149 index_t CShuffleNXdlPerWavePerShuffle,
150 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
151 typename CDEShuffleBlockTransferScalarPerVectors,
154 index_t ActivationOperation = 0,
155 bool NSwizzle = false,
156 bool IsInputGemm = true,
157 bool MulRoutedWeight = true,
158 bool PerTokenQuant = false,
160 typename ComputeTypeA = CDataType,
161 typename ComputeTypeB = ComputeTypeA,
162 typename LDSTypeA = ADataType,
163 typename LDSTypeB = BDataType>
176 CDEShuffleBlockTransferScalarPerVectors{}[I0];
217 return static_cast<const DDataType*>(nullptr);
244 const index_t gridx = NSwizzle ? nblock * mblock : nblock;
245 const index_t gridy = NSwizzle ? 1 : mblock;
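CalculateGridSize (lines 244-245) folds the M and N block tiles into a single gridDim.x axis when NSwizzle is enabled and keeps them on separate axes otherwise. A small sketch with assumed block counts:

// Sketch of the visible grid-shape branch; mblock/nblock values are made up.
#include <iostream>

int main()
{
    const int mblock = 12, nblock = 8;
    const bool NSwizzle = true; // flip to false for the plain 2-D launch

    const int gridx = NSwizzle ? nblock * mblock : nblock; // 96 when swizzled, 8 otherwise
    const int gridy = NSwizzle ? 1 : mblock;               // 1 when swizzled, 12 otherwise
    std::cout << "grid = (" << gridx << ", " << gridy << ")\n";
    return 0;
}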
276 auto K_t = K_Batch * KPerBlock;
277 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
282 auto K_t = K_Batch * KPerBlock;
283 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
288 auto K_t = K_Batch * KPerBlock;
289 return (K + K_t - 1) / K_t * KPerBlock;
295 auto K_t = K_Batch * KReadVec;
296 return (K + K_t - 1) / K_t * KReadVec;
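All four helpers above use the same ceil-divide pattern against the split-K batch count. A worked example with assumed sizes (K = 4096, KPerBlock = 256, AK1 = 8, K_Batch = 3):

// Worked example of the CalculateAK0Padded / CalculateKPadded formulas shown above.
#include <iostream>

int main()
{
    const int K = 4096, KPerBlock = 256, AK1Value = 8, K_Batch = 3;

    const int K_t  = K_Batch * KPerBlock;                          // 768
    const int AK0  = (K + K_t - 1) / K_t * (KPerBlock / AK1Value); // ceil(4096/768) * 32 = 192
    const int KPad = (K + K_t - 1) / K_t * KPerBlock;              // 6 * 256 = 1536 per split-K batch

    std::cout << "AK0Padded = " << AK0 << ", KPadded = " << KPad << '\n';
    return 0;
}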
309 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
325 IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
327 const auto a_grid_desc_mraw_kraw = [&]() {
328 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
332 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
340 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
341 GemmSpec == GemmSpecialization::MNKPadding)
344 const auto a_grid_desc_m_k =
358 return a_grid_desc_ak0_m_ak1;
360 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
361 GemmSpec == GemmSpecialization::MNPadding)
365 a_grid_desc_mraw_kraw,
371 return a_grid_desc_ak0_m_ak1;
373 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
374 GemmSpec == GemmSpecialization::NKPadding)
378 a_grid_desc_mraw_kraw,
390 return a_grid_desc_ak0_m_ak1;
396 a_grid_desc_mraw_kraw,
402 return a_grid_desc_ak0_m_ak1;
408 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
409 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
413 make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
419 const auto b_grid_desc_nraw_kraw = [&]() {
433 GemmSpec != GemmSpecialization::Default),
434 "pk_i4_t does not support padding");
436 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
437 GemmSpec == GemmSpecialization::MNKPadding)
440 const auto b_grid_desc_n_k =
454 return b_grid_desc_bk0_n_bk1;
456 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
457 GemmSpec == GemmSpecialization::MNPadding)
461 b_grid_desc_nraw_kraw,
467 return b_grid_desc_bk0_n_bk1;
469 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
470 GemmSpec == GemmSpecialization::MKPadding)
474 b_grid_desc_nraw_kraw,
486 return b_grid_desc_bk0_n_bk1;
492 b_grid_desc_nraw_kraw,
498 return b_grid_desc_bk0_n_bk1;
502 template <typename ABlockDesc_AK0_M_AK1>
503 __host__ __device__ static constexpr auto
506 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
508 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
511 template <typename BBlockDesc_BK0_N_BK1>
512 __host__ __device__ static constexpr auto
515 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
518 template <typename ELayout>
520 IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
522 const auto c_grid_desc_mraw_nraw = [&]() {
541 template <typename DLayout>
542 __host__ __device__ static auto
545 const auto c_grid_desc_mraw_nraw = [&]() {
570 return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
575 template <typename DsGridDesc>
577 const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
582 ds_grid_desc_m_n[i], MBlock, NBlock);
596 std::array<index_t, NumDTensor> StrideDs_,
622 std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
623 << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
626 << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
627 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
628 << "NBlock: " << NBlock << "}" << std::endl;
655 const index_t* p_sorted_expert_ids_,
656 const index_t* p_max_token_id_,
657 const ADataType* p_a_grid_,
658 const BDataType* p_b_grid_,
659 std::array<const void*, NumDTensor> p_ds_grid_,
660 CDataType* p_c_grid_,
668 std::array<index_t, NumDTensor> StrideDs_,
671 AElementwiseOperation a_element_op_,
672 BElementwiseOperation b_element_op_,
673 CElementwiseOperation c_element_op_)
701 p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
722 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
726 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
731 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
735 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
741 if(k_id < karg.KBatch - 1)
757 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
758 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
761 if constexpr(ABlockLdsExtraM)
771 constexpr auto a_lds_block_desc =
783 return a_lds_block_desc_permuted;
790 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
791 constexpr auto M1 = MPerBlock / M0;
793 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
794 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
795 constexpr auto KThreadRead = WaveSize / MPerXdl;
796 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
798 constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
800 : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
801 constexpr auto KThreadReadPerm =
802 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
803 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
807 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
809 : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
811 : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
817 Number<kfold * M0 / mpair>{},
836 a_lds_block_desc_permuted,
858 a_lds_block_desc_unmerged,
861 Number<KThreadWrite / kfold / KThreadReadPerm>{},
870 return a_lds_block_desc_ak0_m_ak1;
883 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
885 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
892 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
910 ABlockTransferSrcScalarPerVector,
911 BBlockTransferSrcScalarPerVector,
930 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
933 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
936 constexpr auto c_block_size =
937 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
940 c_block_size * sizeof(CShuffleDataType));
946 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
947 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
948 "Invalid tuning param!");
956 if(!(karg.M % MPerBlock == 0))
959 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
960 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
974 if(!(karg.N % NPerBlock == 0))
977 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
978 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
992 auto K_t = karg.KBatch * KPerBlock;
993 if(!(karg.K % K_t == 0))
996 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
997 << karg.K << " " << __FILE__ << ":" << __LINE__
998 << ", in function: " << __func__ << std::endl;
1007 auto K_t = karg.KBatch * KReadVec;
1009 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1017 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1020 std::cout << "Arg K (" << karg.K
1021 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1022 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1023 << __LINE__ << ", in function: " << __func__ << std::endl;
1031 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1034 std::cout << "Arg M (" << karg.M
1035 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1036 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1037 << __LINE__ << ", in function: " << __func__ << std::endl;
1046 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1049 std::cout << "Arg N (" << karg.N
1050 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1051 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1052 << __LINE__ << ", in function: " << __func__ << std::endl;
1060 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1063 std::cout << "Arg K (" << karg.K
1064 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1065 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1066 << __LINE__ << ", in function: " << __func__ << std::endl;
1078 std::cout << "Arg N (" << karg.N
1079 << ") value is not a multiple of "
1080 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1082 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1093 std::cout << "Arg M (" << karg.M
1094 << ") value is not a multiple of "
1095 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1097 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1106 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1108 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1119 const index_t num_loop = K / KPerBlock;
1121 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1126 const index_t num_loop = K / KPerBlock;
1128 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
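CheckValidity (line 1106) and the two helpers above reduce the K extent to a main-loop count before the pipeline is selected. A hedged numeric example, reusing the sizes from the earlier padding sketch; the PrefetchStages value is an assumption, not taken from this header:

// Sketch of the visible loop-count arithmetic; PrefetchStages is a placeholder constant.
#include <iostream>

int main()
{
    const int K = 4096, KPerBlock = 256, AK1Value = 8;
    const int AK0            = 192; // from the padded-K example above
    const int PrefetchStages = 2;   // assumed pipeline prefetch depth

    const int num_k_loop = AK0 / (KPerBlock / AK1Value); // 192 / 32 = 6
    const int num_loop   = K / KPerBlock;                // 16, fed to the tail-number helpers

    std::cout << "num_k_loop = " << num_k_loop
              << ", passes prefetch check: " << (num_k_loop > PrefetchStages)
              << ", num_loop = " << num_loop << '\n';
    return 0;
}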
1131 template <typename CGridDesc>
1133 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1142 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1150 template <bool HasMainKBlockLoop,
1154 const index_t* p_sorted_expert_ids,
1155 const index_t* p_max_token_id,
1156 const ADataType* p_a_grid,
1157 const BDataType* p_b_grid,
1159 CDataType* p_c_grid,
1162 AElementwiseOperation a_element_op,
1163 BElementwiseOperation b_element_op,
1164 CElementwiseOperation c_element_op)
1176 const auto b_grid_desc_bpreshuffled =
1178 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1184 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1187 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1189 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1190 if(expert_block_id * MPerBlock >= max_token_id)
1193 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1194 const auto block_mn = [&]() -> std::pair<int, int> {
1195 if constexpr(NSwizzle)
1197 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1199 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1200 const index_t expert_swizzle =
1201 ecnt > 0 ? ecnt : 1;
1202 const index_t bid_new = blockIdx.x - prefix_block;
1203 const index_t nid = __builtin_amdgcn_readfirstlane(
1204 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1206 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1211 return {blockIdx.x, blockIdx.y};
1215 const index_t block_n_id = block_mn.first;
1216 const index_t block_m_id = block_mn.second;
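With NSwizzle enabled, the lambda above decodes the flattened blockIdx.x into an (n, m) block pair so that consecutive blocks sweep groups of 8 N-tiles within one expert before moving on. A standalone sketch of that index arithmetic; the expert counts and the first block id owned by the expert (prefix_block) are made-up values:

// Sketch of the swizzled block-id decode shown above (inputs are assumed).
#include <iostream>

int main()
{
    const int ecnt_prefix    = 4;  // M-block prefix of earlier experts (p_max_token_id[1 + expert_id])
    const int ecnt           = 3;  // M-blocks owned by this expert
    const int expert_swizzle = ecnt > 0 ? ecnt : 1;
    const int prefix_block   = 64; // first flattened block id of this expert (assumed)

    for(int bid = prefix_block; bid < prefix_block + 24; ++bid)
    {
        const int bid_new = bid - prefix_block;
        const int nid     = bid_new % 8 + bid_new / (8 * expert_swizzle) * 8;
        const int mid     = ecnt_prefix + bid_new / 8 % expert_swizzle;
        std::cout << "block " << bid << " -> (n = " << nid << ", m = " << mid << ")\n";
    }
    return 0;
}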
1218 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1221 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1222 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1223 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1224 constexpr auto AKThreads = AK0Threads * AK1Threads;
1225 constexpr auto AMRepeats = MPerBlock / AMThreads;
1226 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1228 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1232 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1233 index_t token_offset = fused_token & 0xffffff;
1234 if constexpr(!IsInputGemm)
1236 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1238 gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1240 const IndexType expert_stride =
1241 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1242 const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
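The gather setup above packs two fields into every p_sorted_token_ids entry: the low 24 bits carry the source token row and the high 8 bits the top-k slot, while each expert's weight tensor is expert_stride = N * K elements away from the previous one (doubled when IsInputGemm holds both gate and up projections). A hedged sketch of that decode with made-up values (BPackedSize is assumed to be 1):

// Sketch of the fused-token decode and expert offset math visible above.
#include <cstdint>
#include <iostream>

int main()
{
    const int32_t fused_token = (2u << 24) | 1234u; // top-k slot 2, token id 1234
    const int32_t TopK = 8, K = 4096, N = 7168;     // assumed problem sizes
    const bool    IsInputGemm = true;

    int64_t token_offset = fused_token & 0xffffff;
    if(!IsInputGemm) // the down-projection GEMM reads the expanded (token, slot) rows
        token_offset = token_offset * TopK + (fused_token >> 24);

    const int64_t a_gather_offset = token_offset * K;                       // row offset into A
    const int64_t expert_stride   = int64_t(N) * K * (IsInputGemm ? 2 : 1); // gate + up doubles it
    const int64_t expert_offset   = 3 * expert_stride;                      // expert_id = 3 (assumed)

    std::cout << a_gather_offset << ' ' << expert_stride << ' ' << expert_offset << '\n';
    return 0;
}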
1244 const index_t n_block_data_idx_on_grid =
1245 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1247 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1248 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1249 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1250 p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1260 AElementwiseOperation,
1264 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1265 ABlockTransferThreadClusterArrangeOrder,
1268 decltype(a_grid_desc_ak0_m_ak1),
1269 decltype(a_block_desc_ak0_m_ak1),
1270 ABlockTransferSrcAccessOrder,
1272 ABlockTransferSrcVectorDim,
1274 ABlockTransferSrcScalarPerVector,
1275 ABlockTransferDstScalarPerVector_AK1,
1278 AThreadTransferSrcResetCoordinateAfterRun,
1282 BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1285 a_block_desc_ak0_m_ak1,
1292 auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1293 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1298 decltype(b_grid_desc_bpreshuffled),
1299 decltype(b_block_desc_bk0_n_bk1),
1303 BBlockTransferSrcScalarPerVector,
1304 BThreadTransferSrcResetCoordinateAfterRun,
1305 true>(b_grid_desc_bpreshuffled,
1313 auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1314 static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1320 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1322 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1323 decltype(c_thread_buf) c_thread_buf_up;
1327 c_thread_buf.num_of_v_,
1328 c_thread_buf.s_per_v,
1332 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1333 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1335 if constexpr(IsInputGemm)
1337 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1338 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1339 p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1343 decltype(b_grid_desc_bpreshuffled),
1344 decltype(b_block_desc_bk0_n_bk1),
1348 BBlockTransferSrcScalarPerVector,
1349 BThreadTransferSrcResetCoordinateAfterRun,
1350 true>(b_grid_desc_bpreshuffled,
1356 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1357 a_grid_desc_ak0_m_ak1,
1358 a_block_desc_ak0_m_ak1,
1362 a_block_slice_copy_step,
1363 b_grid_desc_bpreshuffled,
1365 b_blockwise_copy_up,
1369 b_block_slice_copy_step,
1372 num_k_block_main_loop);
1376 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1377 a_grid_desc_ak0_m_ak1,
1378 a_block_desc_ak0_m_ak1,
1382 a_block_slice_copy_step,
1383 b_grid_desc_bpreshuffled,
1387 b_block_slice_copy_step,
1389 num_k_block_main_loop);
1394 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1395 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1398 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1401 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1402 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1406 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1407 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1409 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1410 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1411 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1412 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1413 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1414 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1415 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1416 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1419 const float* p_sorted_weights_0 = p_ds_grid[I0];
1420 const float* p_scale_b = p_ds_grid[I1];
1422 static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
1423 static_assert(M4 == 4 || M4 == 8);
1427 if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
1429 if constexpr(PerTokenQuant)
1431 constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
1432 p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
1437 p_scale_b += expert_id;
1443 const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
1446 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1447 m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1448 if constexpr(PerTokenQuant)
1451 *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
1452 p_sorted_token_ids + m_pos);
1454 if constexpr(MulRoutedWeight)
1456 topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
1457 p_ds_grid[I2] + m_pos);
1460 float scale_a = [&]() {
1461 if constexpr(PerTokenQuant)
1464 scale_token_ids.template AsType<index_t>()[m4];
1465 const index_t token_offset = fused_token & 0xffffff;
1467 ? p_sorted_weights_0[IsInputGemm
1477 return p_sorted_weights_0[0];
1481 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1484 if constexpr(IsInputGemm)
1488 const float scale_up =
1489 p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
1491 float gate = scale_a * scale_b * c_thread_buf[cidx];
1492 float up = scale_a * scale_up * c_thread_buf_up[cidx];
1493 if constexpr(MulRoutedWeight)
1495 gate = gate * topk_weights.template AsType<float>()[m4];
1496 up = up * topk_weights.template AsType<float>()[m4];
1504 c_thread_buf_fp32(cidx) = gate * up;
1508 const float scale_up =
1509 p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
1511 float gate = scale_a * scale_b * c_thread_buf[cidx];
1512 float up = scale_a * scale_up * c_thread_buf_up[cidx];
1513 if constexpr(MulRoutedWeight)
1515 gate = gate * topk_weights.template AsType<float>()[m4];
1516 up = up * topk_weights.template AsType<float>()[m4];
1524 c_thread_buf_fp32(cidx) = gate * up;
1529 c_thread_buf_fp32(cidx) =
1530 scale_a * scale_b * c_thread_buf[cidx];
1531 if constexpr(MulRoutedWeight)
1533 c_thread_buf_fp32(cidx) =
1534 c_thread_buf_fp32(cidx) *
1535 topk_weights.template AsType<float>()[m4];
1549 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1550 m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1551 if constexpr(MulRoutedWeight)
1553 topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
1554 p_ds_grid[I2] + m_pos);
1558 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1562 if constexpr(IsInputGemm)
1566 float gate = c_thread_buf[cidx];
1567 float up = c_thread_buf_up[cidx];
1568 if constexpr(MulRoutedWeight)
1570 gate = gate * topk_weights.template AsType<float>()[m4];
1571 up = up * topk_weights.template AsType<float>()[m4];
1574 c_thread_buf_fp32(cidx) = gate * up;
1578 float gate = c_thread_buf[cidx];
1579 float up = c_thread_buf_up[cidx];
1580 if constexpr(MulRoutedWeight)
1582 gate = gate * topk_weights.template AsType<float>()[m4];
1583 up = up * topk_weights.template AsType<float>()[m4];
1586 c_thread_buf_fp32(cidx) = gate * up;
1591 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1592 if constexpr(MulRoutedWeight)
1594 c_thread_buf_fp32(cidx) =
1595 topk_weights.template AsType<float>()[m4] *
1596 c_thread_buf_fp32[cidx];
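Every branch of the epilogue above collapses an accumulator pair to a single value: the quantized paths scale by the per-token A scale and the per-channel (or per-tensor) B scales, the routed top-k weight is folded in when MulRoutedWeight is set, and for IsInputGemm the gate and up results are multiplied together (the elided lines in between presumably apply the silu_and_mul / gelu_and_mul activation to the gate term first). A scalar sketch of that combine step with assumed values:

// Scalar sketch of the visible epilogue combine; scales, weights and the skipped
// activation are placeholders, not the exact code elided from this listing.
#include <iostream>

int main()
{
    const float acc_gate = 0.75f, acc_up = -1.25f; // one c_thread_buf / c_thread_buf_up pair
    const float scale_a = 0.5f, scale_b = 0.25f, scale_up = 0.125f;
    const float topk_weight = 0.6f;
    const bool  MulRoutedWeight = true;

    float gate = scale_a * scale_b * acc_gate;
    float up   = scale_a * scale_up * acc_up;
    if(MulRoutedWeight)
    {
        gate *= topk_weight;
        up   *= topk_weight;
    }
    std::cout << "c = " << gate * up << '\n'; // value written back via c_thread_buf_fp32
    return 0;
}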
1605 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1608 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1609 static_cast<CShuffleDataType*>(p_shared),
1610 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1613 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1633 const auto c_thread_mtx_on_block =
1634 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1636 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1637 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1639 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1645 const auto m_thread_data_on_block_idx =
1646 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1649 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1655 const auto n_thread_data_on_block_idx =
1656 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1660 auto c_thread_copy_vgpr_to_lds =
1663 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1664 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1666 Sequence<CShuffleMXdlPerWavePerShuffle,
1667 CShuffleNXdlPerWavePerShuffle,
1680 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1683 m_thread_data_on_block_idx[I1],
1684 n_thread_data_on_block_idx[I1],
1685 m_thread_data_on_block_idx[I2],
1686 m_thread_data_on_block_idx[I3],
1687 m_thread_data_on_block_idx[I4],
1688 n_thread_data_on_block_idx[I2]),
1691 using EDataType = CDataType;
1696 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1702 return make_dynamic_buffer<AddressSpaceEnum::Global>(
1703 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1709 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1711 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1716 tie(c_shuffle_block_buf),
1718 { return ds_grid_buf[i]; },
1722 const auto idx_c_ds_block_begin =
1732 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1733 c_grid_desc_mblock_mperblock_nblock_nperblock;
1735 using CDEBlockTransferCluster =
1736 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1737 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1738 constexpr index_t scatter_weight_idx = 3;
1743 decltype(c_ds_desc_refs),
1744 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1745 CElementwiseOperation,
1749 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1751 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
1752 CDEBlockTransferCluster,
1758 CDEShuffleBlockTransferScalarPerVectors,
1770 idx_c_ds_block_begin,
1771 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1775 auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1776 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1777 constexpr auto sfc_c_vgpr =
1780 Sequence<CShuffleMXdlPerWavePerShuffle,
1781 CShuffleNXdlPerWavePerShuffle,
1789 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1792 constexpr auto sfc_cde_block =
1796 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1798 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1800 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1801 constexpr auto EMThreads =
1802 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1803 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1804 constexpr auto ENThreads =
1805 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1810 auto dstidx = sfc_cde_block.GetIndex(access_id);
1812 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1814 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1815 IndexType token_offset = fused_token & 0xffffff;
1816 if constexpr(IsInputGemm)
1818 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1820 scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
1826 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1827 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1829 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1830 c_shuffle_block_buf);
1836 cde_block_copy_lds_and_global.Run(
1839 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1843 if constexpr(access_id < num_access - 1)
1845 constexpr auto cde_lds_and_global_step =
1846 sfc_cde_block.GetForwardStep(access_id);
1850 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1851 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1855 cde_block_copy_lds_and_global.MoveDstSliceWindow(
1856 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1858 cde_lds_and_global_step);
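The store stage mirrors the gather stage: for each shuffle access it decodes p_sorted_token_ids again, this time expanding by TopK when IsInputGemm and scaling by N, so the LDS-to-global copy scatters complete C rows per token. A sketch of that offset computation for a single row with assumed values:

// Sketch of the scatter-offset math used by the C epilogue above (inputs assumed).
#include <cstdint>
#include <iostream>

int main()
{
    const int32_t fused_token = (1u << 24) | 77u; // token 77, top-k slot 1
    const int32_t TopK = 8, N = 7168;
    const bool    IsInputGemm = true;

    int64_t token_offset = fused_token & 0xffffff;
    if(IsInputGemm) // the gate/up GEMM writes the expanded (token, slot) rows
        token_offset = token_offset * TopK + (fused_token >> 24);

    const int64_t scatter_offset = token_offset * N; // row offset into the C grid
    std::cout << scatter_offset << '\n';
    return 0;
}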
1864 template <bool HasMainKBlockLoop,
1868 const index_t* p_sorted_expert_ids,
1869 const index_t* p_max_token_id,
1870 const ADataType* p_a_grid,
1871 const BDataType* p_b_grid,
1873 CDataType* p_c_grid,
1877 AElementwiseOperation a_element_op,
1878 BElementwiseOperation b_element_op,
1879 CElementwiseOperation c_element_op)
1891 const auto b_grid_desc_bpreshuffled =
1893 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1899 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1902 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1904 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1905 if(expert_block_id * MPerBlock >= max_token_id)
1908 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1909 const auto block_mn = [&]() -> std::pair<int, int> {
1910 if constexpr(NSwizzle)
1912 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1914 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1915 const index_t expert_swizzle =
1916 ecnt > 0 ? ecnt : 1;
1917 const index_t bid_new = blockIdx.x - prefix_block;
1918 const index_t nid = __builtin_amdgcn_readfirstlane(
1919 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1921 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1926 return {blockIdx.x, blockIdx.y};
1930 const index_t block_n_id = block_mn.first;
1931 const index_t block_m_id = block_mn.second;
1933 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1936 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1937 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1938 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1939 constexpr auto AKThreads = AK0Threads * AK1Threads;
1940 constexpr auto AMRepeats = MPerBlock / AMThreads;
1941 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1943 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1947 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1948 index_t token_offset = fused_token & 0xffffff;
1949 if constexpr(!IsInputGemm)
1951 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1953 gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1955 const IndexType expert_stride =
1956 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1957 const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
1959 const index_t n_block_data_idx_on_grid =
1960 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1962 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1963 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1964 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1965 p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1976 AElementwiseOperation,
1980 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1981 ABlockTransferThreadClusterArrangeOrder,
1984 decltype(a_grid_desc_ak0_m_ak1),
1985 decltype(a_block_desc_ak0_m_ak1),
1986 ABlockTransferSrcAccessOrder,
1988 ABlockTransferSrcVectorDim,
1990 ABlockTransferSrcScalarPerVector,
1991 ABlockTransferDstScalarPerVector_AK1,
1994 AThreadTransferSrcResetCoordinateAfterRun,
1998 2>(a_grid_desc_ak0_m_ak1,
2001 a_block_desc_ak0_m_ak1,
2008 auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2009 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2010 auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2011 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2012 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2017 decltype(b_grid_desc_bpreshuffled),
2018 decltype(b_block_desc_bk0_n_bk1),
2022 BBlockTransferSrcScalarPerVector,
2023 BThreadTransferSrcResetCoordinateAfterRun,
2024 true>(b_grid_desc_bpreshuffled,
2032 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2033 static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2034 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2035 static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2036 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
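Run_2Lds differs from Run mainly in double buffering: A tiles alternate between p_shared and p_shared1 (and B between two VGPR buffers), so the pipeline can prefetch the next K block while the current one is consumed by the MFMAs. A toy illustration of that ping-pong rotation, entirely detached from the CK buffer types:

// Toy ping-pong rotation; buffer names and the loop body are placeholders only.
#include <array>
#include <iostream>

int main()
{
    const std::array<const char*, 2> a_block_bufs = {"p_shared", "p_shared1"};
    const int num_k_block_main_loop = 6; // assumed loop count

    for(int i = 0; i < num_k_block_main_loop; ++i)
    {
        const char* compute_buf  = a_block_bufs[i % 2];       // MFMAs read this LDS buffer
        const char* prefetch_buf = a_block_bufs[(i + 1) % 2]; // global->LDS writes the other one
        std::cout << "iter " << i << ": compute from " << compute_buf
                  << ", prefetch into " << prefetch_buf << '\n';
    }
    return 0;
}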
2042 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2044 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2045 decltype(c_thread_buf) c_thread_buf_up;
2049 c_thread_buf.num_of_v_,
2050 c_thread_buf.s_per_v,
2054 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2055 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2058 if constexpr(IsInputGemm)
2060 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2061 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2062 p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2066 decltype(b_grid_desc_bpreshuffled),
2067 decltype(b_block_desc_bk0_n_bk1),
2071 BBlockTransferSrcScalarPerVector,
2072 BThreadTransferSrcResetCoordinateAfterRun,
2073 true>(b_grid_desc_bpreshuffled,
2078 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2079 a_grid_desc_ak0_m_ak1,
2080 a_block_desc_ak0_m_ak1,
2084 a_block_slice_copy_step,
2085 b_grid_desc_bpreshuffled,
2087 b_blockwise_copy_up,
2091 b_block_slice_copy_step,
2094 num_k_block_main_loop);
2099 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2100 a_grid_desc_ak0_m_ak1,
2101 a_block_desc_ak0_m_ak1,
2105 a_block_slice_copy_step,
2106 b_grid_desc_bpreshuffled,
2110 b_block_slice_copy_step,
2112 num_k_block_main_loop);
2117 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2118 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2121 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2124 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2125 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2129 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2130 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2132 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2133 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2134 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2135 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2136 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2137 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2138 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2139 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2142 const float* p_sorted_weights_0 = p_ds_grid[I0];
2143 const float* p_scale_b = p_ds_grid[I1];
2145 static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
2146 static_assert(M4 == 4 || M4 == 8);
2150 if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
2152 if constexpr(PerTokenQuant)
2154 constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
2155 p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
2160 p_scale_b += expert_id;
2166 const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
2169 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2170 m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2171 if constexpr(PerTokenQuant)
2174 *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
2175 p_sorted_token_ids + m_pos);
2177 if constexpr(MulRoutedWeight)
2179 topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
2180 p_ds_grid[I2] + m_pos);
2183 float scale_a = [&]() {
2184 if constexpr(PerTokenQuant)
2187 scale_token_ids.template AsType<index_t>()[m4];
2188 const index_t token_offset = fused_token & 0xffffff;
2190 ? p_sorted_weights_0[IsInputGemm
2200 return p_sorted_weights_0[0];
2204 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2207 if constexpr(IsInputGemm)
2211 const float scale_up =
2212 p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
2214 float gate = scale_a * scale_b * c_thread_buf[cidx];
2215 float up = scale_a * scale_up * c_thread_buf_up[cidx];
2216 if constexpr(MulRoutedWeight)
2218 gate = gate * topk_weights.template AsType<float>()[m4];
2219 up = up * topk_weights.template AsType<float>()[m4];
2227 c_thread_buf_fp32(cidx) = gate * up;
2231 const float scale_up =
2232 p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
2234 float gate = scale_a * scale_b * c_thread_buf[cidx];
2235 float up = scale_a * scale_up * c_thread_buf_up[cidx];
2236 if constexpr(MulRoutedWeight)
2238 gate = gate * topk_weights.template AsType<float>()[m4];
2239 up = up * topk_weights.template AsType<float>()[m4];
2247 c_thread_buf_fp32(cidx) = gate * up;
2252 c_thread_buf_fp32(cidx) =
2253 scale_a * scale_b * c_thread_buf[cidx];
2254 if constexpr(MulRoutedWeight)
2256 c_thread_buf_fp32(cidx) =
2257 c_thread_buf_fp32(cidx) *
2258 topk_weights.template AsType<float>()[m4];
2272 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2273 m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2274 if constexpr(MulRoutedWeight)
2276 topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
2277 p_ds_grid[I2] + m_pos);
2281 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2285 if constexpr(IsInputGemm)
2289 float gate = c_thread_buf[cidx];
2290 float up = c_thread_buf_up[cidx];
2291 if constexpr(MulRoutedWeight)
2293 gate = gate * topk_weights.template AsType<float>()[m4];
2294 up = up * topk_weights.template AsType<float>()[m4];
2297 c_thread_buf_fp32(cidx) = gate * up;
2301 float gate = c_thread_buf[cidx];
2302 float up = c_thread_buf_up[cidx];
2303 if constexpr(MulRoutedWeight)
2305 gate = gate * topk_weights.template AsType<float>()[m4];
2306 up = up * topk_weights.template AsType<float>()[m4];
2309 c_thread_buf_fp32(cidx) = gate * up;
2314 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2315 if constexpr(MulRoutedWeight)
2317 c_thread_buf_fp32(cidx) =
2318 topk_weights.template AsType<float>()[m4] *
2319 c_thread_buf_fp32[cidx];
2328 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2331 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2332 static_cast<CShuffleDataType*>(p_shared),
2333 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2336 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2356 const auto c_thread_mtx_on_block =
2357 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2359 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2360 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2362 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2368 const auto m_thread_data_on_block_idx =
2369 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2372 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2378 const auto n_thread_data_on_block_idx =
2379 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2383 auto c_thread_copy_vgpr_to_lds =
2386 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2387 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2389 Sequence<CShuffleMXdlPerWavePerShuffle,
2390 CShuffleNXdlPerWavePerShuffle,
2403 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2406 m_thread_data_on_block_idx[I1],
2407 n_thread_data_on_block_idx[I1],
2408 m_thread_data_on_block_idx[I2],
2409 m_thread_data_on_block_idx[I3],
2410 m_thread_data_on_block_idx[I4],
2411 n_thread_data_on_block_idx[I2]),
2414 using EDataType = CDataType;
2419 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2425 return make_dynamic_buffer<AddressSpaceEnum::Global>(
2426 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2432 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2434 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2439 tie(c_shuffle_block_buf),
2441 { return ds_grid_buf[i]; },
2445 const auto idx_c_ds_block_begin =
2455 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2456 c_grid_desc_mblock_mperblock_nblock_nperblock;
2458 using CDEBlockTransferCluster =
2459 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2460 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2461 constexpr index_t scatter_weight_idx = 3;
2466 decltype(c_ds_desc_refs),
2467 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2468 CElementwiseOperation,
2472 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2474 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
2475 CDEBlockTransferCluster,
2481 CDEShuffleBlockTransferScalarPerVectors,
2493 idx_c_ds_block_begin,
2494 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2498 auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2499 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2500 constexpr auto sfc_c_vgpr =
2503 Sequence<CShuffleMXdlPerWavePerShuffle,
2504 CShuffleNXdlPerWavePerShuffle,
2512 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2515 constexpr auto sfc_cde_block =
2519 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2521 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2523 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2524 constexpr auto EMThreads =
2525 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2526 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2527 constexpr auto ENThreads =
2528 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2533 auto dstidx = sfc_cde_block.GetIndex(access_id);
2535 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2537 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2538 IndexType token_offset = fused_token & 0xffffff;
2539 if constexpr(IsInputGemm)
2541 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2543 scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2549 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2550 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2552 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2553 c_shuffle_block_buf);
2559 cde_block_copy_lds_and_global.Run(
2562 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2566 if constexpr(access_id < num_access - 1)
2568 constexpr auto cde_lds_and_global_step =
2569 sfc_cde_block.GetForwardStep(access_id);
2573 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2574 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2578 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2579 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2581 cde_lds_and_global_step);