template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          // ...
#if CK_USE_LAUNCH_BOUNDS
    // ...
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
    {
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_sorted_token_ids,
            karg.p_sorted_expert_ids,
            // ...
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            // ...
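
// Kernel entry for the fused MoE GEMM: each blockIdx.z handles one split-K batch. SplitKBatchOffset
// translates that split id into element offsets into the A and B grids, and a single LDS workspace,
// sized by GridwiseGemm::GetSharedMemoryNumberOfByte(), is shared between the main-loop A tile and
// the later C-shuffle stage. The remaining Run(...) arguments (Ds pointers, C pointer, the Problem
// descriptor and the elementwise ops) follow the same karg fields.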
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          // ...
#if CK_USE_LAUNCH_BOUNDS
    // ...
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
    {
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
        __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_sorted_token_ids,
            karg.p_sorted_expert_ids,
            // ...
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            // ...
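
// kernel_moe_gemm_2lds differs from kernel_moe_gemm only in that it reserves two LDS buffers
// (p_shared / p_shared1) so Run_2Lds can ping-pong the A tile between them while the pipeline
// prefetches the next K block.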
template <typename ALayout,
          // ...
          typename AccDataType,
          typename CShuffleDataType,
          // ...
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          // ...
          index_t ActivationOperation = 0,
          bool NSwizzle               = false,
          bool IsInputGemm            = true,
          bool MulRoutedWeight        = true,
          bool PerTokenQuant          = false,
          // ...
          typename ComputeTypeA = CDataType,
          typename ComputeTypeB = ComputeTypeA,
          typename LDSTypeA     = ADataType,
          typename LDSTypeB     = BDataType>
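
// Compile-time switches that shape the MoE path: NSwizzle flattens the (M, N) tile grid onto
// gridDim.x, IsInputGemm selects the gate/up ("input") GEMM that reads two N*K weight slices per
// expert, MulRoutedWeight folds the routing (top-k) weights into the output, PerTokenQuant picks
// per-token rather than per-tensor activation scales, and ActivationOperation chooses the fused
// activation (the Activation enum below lists gelu_and_mul and silu_and_mul).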
    static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock =
        CDEShuffleBlockTransferScalarPerVectors{}[I0];
    // ...
            return static_cast<const DDataType*>(nullptr);
    // ...
        const index_t gridx = NSwizzle ? nblock * mblock : nblock;
        const index_t gridy = NSwizzle ? 1 : mblock;
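
        // Grid-size sketch: without NSwizzle the launch is a 2-D grid (x = N tiles, y = M tiles);
        // with NSwizzle every (m, n) tile is flattened onto gridDim.x and gridDim.y collapses to 1.
        // gridDim.z is expected to carry KBatch, since blockIdx.z selects the split consumed by
        // SplitKBatchOffset.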
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
    // ...
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
    // ...
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;
    // ...
        auto K_t = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
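
    // Worked example for the split-K padding math above (illustrative numbers only): with K = 1000,
    // K_Batch = 2 and KPerBlock = 64, K_t = 128, so ceil(1000 / 128) = 8 block-loops per split and
    // KPadded = 8 * 64 = 512 elements of K per split (2 * 512 >= 1000 covers the whole K).
    // AK0 is the same loop count scaled by KPerBlock / AK1Value, i.e. the number of AK1-wide
    // chunks each split reads per row.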
    template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
    // ...

    __host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(
        IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            // ...
        }();

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // ...
            const auto a_grid_desc_m_k =
            // ...
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // ...
                a_grid_desc_mraw_kraw,
            // ...
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // ...
                a_grid_desc_mraw_kraw,
            // ...
            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // ...
                a_grid_desc_mraw_kraw,
            // ...
            return a_grid_desc_ak0_m_ak1;
        }
    }
        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);
        // ...
            make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
        const auto b_grid_desc_nraw_kraw = [&]() {
        // ...
        }();

        // ...
        static_assert(
            // ...
                GemmSpec != GemmSpecialization::Default),
            "pk_i4_t does not support padding");

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // ...
            const auto b_grid_desc_n_k =
            // ...
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // ...
                b_grid_desc_nraw_kraw,
            // ...
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // ...
                b_grid_desc_nraw_kraw,
            // ...
            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // ...
                b_grid_desc_nraw_kraw,
            // ...
            return b_grid_desc_bk0_n_bk1;
        }
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
    {
        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
    }

    template <typename ELayout>
    __host__ static __device__ auto MakeCGridDescriptor_M_N(
        IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
    template <typename DLayout>
    __host__ static __device__ auto
    MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
        // ...

    // ...
            return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
    // ...

    template <typename DsGridDesc>
    static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        // ...
                ds_grid_desc_m_n[i], MBlock, NBlock);
        std::array<index_t, NumDTensor> StrideDs_,
        // ...

        std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
                  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
                  // ...
                  << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
                  << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
                  << "NBlock: " << NBlock << "}" << std::endl;
        // ...
        __host__ Argument(const index_t* p_sorted_token_ids_,
                          const index_t* p_sorted_expert_ids_,
                          const index_t* p_max_token_id_,
                          const ADataType* p_a_grid_,
                          const BDataType* p_b_grid_,
                          std::array<const void*, NumDTensor> p_ds_grid_,
                          CDataType* p_c_grid_,
                          // ...
                          std::array<index_t, NumDTensor> StrideDs_,
                          // ...
                          AElementwiseOperation a_element_op_,
                          BElementwiseOperation b_element_op_,
                          CElementwiseOperation c_element_op_)
            // ...
                p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
        // ...
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            // ...
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            // ...
            if(k_id < karg.KBatch - 1)
        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);
        // ...
        if constexpr(ABlockLdsExtraM)
        {
            // ...
            constexpr auto a_lds_block_desc =
            // ...
            return a_lds_block_desc_permuted;
        }
        else
        {
            // ...
            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = WaveSize / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;
            // ...
            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
            // ...
                Number<kfold * M0 / mpair>{},
            // ...
                a_lds_block_desc_permuted,
            // ...
                a_lds_block_desc_unmerged,
            // ...
                Number<KThreadWrite / kfold / KThreadReadPerm>{},
            // ...
            return a_lds_block_desc_ak0_m_ak1;
        }
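
        // The else-branch above builds the bank-conflict-avoiding LDS layout used when no extra M
        // padding is added: the AK0 x M x AK1 tile is folded (kfold / KThreadReadPerm / mpair) and
        // run through an xor-based permutation (make_xor_with_modulo_transform) so that the LDS
        // lines written by the global->LDS copy are unlikely to collide with the lines read back by
        // the MFMA waves. The factors are derived purely from the thread-cluster shape, AK1,
        // MPerXdl and sizeof(LDSTypeA).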
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        // ...
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        // ...
        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
        ABlockTransferSrcScalarPerVector,
        BBlockTransferSrcScalarPerVector,
        // ...
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        // ...
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        // ...
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
        // ...
            c_block_size * sizeof(CShuffleDataType));
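
    // GetSharedMemoryNumberOfByte() sizes one LDS allocation that is reused twice: during the main
    // loop it holds the aligned A tile (the a_block_desc element count, presumably scaled by
    // sizeof(LDSTypeA)), and in the epilogue the same bytes back the CShuffle staging tile
    // (c_block_size * sizeof(CShuffleDataType)); the reported size therefore has to cover whichever
    // stage is larger.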
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");
        // ...
            if(!(karg.M % MPerBlock == 0))
            {
                // ...
                std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                // ...
            }
        // ...
            if(!(karg.N % NPerBlock == 0))
            {
                // ...
                std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                // ...
            }
        // ...
            auto K_t = karg.KBatch * KPerBlock;
            if(!(karg.K % K_t == 0))
            {
                // ...
                std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                          << karg.K << " " << __FILE__ << ":" << __LINE__
                          << ", in function: " << __func__ << std::endl;
                // ...
            }
        // ...
            auto K_t = karg.KBatch * KReadVec;
            // ...
            if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
            // ...
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
            {
                // ...
                std::cout << "Arg K (" << karg.K
                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
                // ...
            }
        // ...
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
            {
                // ...
                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
                // ...
            }
        // ...
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
            {
                // ...
                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
                // ...
            }
        // ...
            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
            {
                // ...
                std::cout << "Arg K (" << karg.K
                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
                // ...
            }
        // ...
                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          // ...
                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
        // ...
                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          // ...
                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
        // ...
        if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
        // ...

        const index_t num_loop = K / KPerBlock;
        // ...
        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
    // ...
        const index_t num_loop = K / KPerBlock;
        // ...
        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
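
    // CalculateHasMainKBlockLoop / CalculateKBlockLoopTailNum ask the selected BlockwiseGemmPipe how
    // the K / KPerBlock loop count maps onto hot-loop plus tail execution, while CheckValidity
    // compares the per-split loop count (AK0 / (KPerBlock / AK1Value)) against the pipeline's
    // PrefetchStages; a configuration whose split is too short cannot fill the prefetch pipeline.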
    template <typename CGridDesc>
    static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        // ...
        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }
    template <bool HasMainKBlockLoop,
              // ...
    static __device__ void Run(const index_t* p_sorted_token_ids,
                               const index_t* p_sorted_expert_ids,
                               const index_t* p_max_token_id,
                               const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               // ...
                               CDataType* p_c_grid,
                               // ...
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CElementwiseOperation c_element_op)
    {
        // ...
        const auto b_grid_desc_bpreshuffled =
        // ...
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
        // ...
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        // ...
        const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        // ...
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
        // ...
        const index_t expert_id =
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
            {
                const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
                // ...
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                // ...
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
                // ...
            }
            else
            {
                return {blockIdx.x, blockIdx.y};
            }
        }();
        // ...
        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
        const index_t token0 =
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
        // ...
        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
        // ...
        if(token_pos >= max_token_id || token0 >= problem.NumTokens)
        // ...
            const index_t fused_token = p_sorted_token_ids[token_pos + m0];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
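
        // Each entry of p_sorted_token_ids packs a routing slot: bits [23:0] hold the token id and
        // bits [31:24] the top-k slot. For the gate/up ("input") GEMM rows of A are gathered by raw
        // token id; for the down GEMM (!IsInputGemm) the row index is token * TopK + slot, so every
        // routed copy of a token reads its own intermediate row. The gather offset is that row index
        // times the K stride of A.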
        const IndexType expert_stride =
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
        const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
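
        // B is laid out expert-major: each expert owns an N x K weight slice (two consecutive
        // slices, gate then up, when IsInputGemm), so expert_id * expert_stride / BPackedSize jumps
        // to this expert's weights. BPackedSize accounts for sub-byte packed types such as pk_i4_t,
        // which stores two values per byte.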
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
        // ...
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
        // ...
            AElementwiseOperation,
            // ...
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            // ...
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            // ...
            ABlockTransferSrcVectorDim,
            // ...
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            // ...
            AThreadTransferSrcResetCoordinateAfterRun,
            // ...
            BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                                // ...
                                                a_block_desc_ak0_m_ak1,
                                                // ...
        // ...
        auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        // ...
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            // ...
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(b_grid_desc_bpreshuffled,
                  // ...
        // ...
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        // ...
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        // ...
        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;
        // ...
            c_thread_buf.num_of_v_,
            c_thread_buf.s_per_v,
            // ...
        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            // ...
        if constexpr(IsInputGemm)
        {
            // ...
            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
            // ...
                decltype(b_grid_desc_bpreshuffled),
                decltype(b_block_desc_bk0_n_bk1),
                // ...
                BBlockTransferSrcScalarPerVector,
                BThreadTransferSrcResetCoordinateAfterRun,
                true>(b_grid_desc_bpreshuffled,
                      // ...
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                // ...
                b_blockwise_copy_up,
                // ...
                b_block_slice_copy_step,
                // ...
                num_k_block_main_loop);
        }
        else
        {
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                // ...
                b_block_slice_copy_step,
                // ...
                num_k_block_main_loop);
        }
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                      // ...
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        // ...
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
        // ...
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
        // ...
        const float* p_sorted_weights_0 = p_ds_grid[I0];
        const float* p_scale_b          = p_ds_grid[I1];
        // ...
        static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
        static_assert(M4 == 4 || M4 == 8);
        // ...
        if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
        {
            if constexpr(PerTokenQuant)
            {
                constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
                p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
                // ...
            }
            else
            {
                p_scale_b += expert_id;
            }
            // ...
                const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
                // ...
                    const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
                                          m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
                    if constexpr(PerTokenQuant)
                    {
                        // ...
                            *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
                                p_sorted_token_ids + m_pos);
                    }
                    if constexpr(MulRoutedWeight)
                    {
                        topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
                            p_ds_grid[I2] + m_pos);
                    }
                    // ...
                        float scale_a = [&]() {
                            if constexpr(PerTokenQuant)
                            {
                                // ...
                                    scale_token_ids.template AsType<index_t>()[m4];
                                const index_t token_offset = fused_token & 0xffffff;
                                // ...
                                    ? p_sorted_weights_0[IsInputGemm
                                // ...
                            }
                            else
                            {
                                return p_sorted_weights_0[0];
                            }
                        }();
                        // ...
                            blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                        // ...
                        if constexpr(IsInputGemm)
                        {
                            // ...
                            const float scale_up =
                                p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
                            // ...
                            float gate = scale_a * scale_b * c_thread_buf[cidx];
                            float up   = scale_a * scale_up * c_thread_buf_up[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                gate = gate * topk_weights.template AsType<float>()[m4];
                                up   = up * topk_weights.template AsType<float>()[m4];
                            }
                            // ...
                            c_thread_buf_fp32(cidx) = gate * up;
                            // ...
                            const float scale_up =
                                p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
                            // ...
                            float gate = scale_a * scale_b * c_thread_buf[cidx];
                            float up   = scale_a * scale_up * c_thread_buf_up[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                gate = gate * topk_weights.template AsType<float>()[m4];
                                up   = up * topk_weights.template AsType<float>()[m4];
                            }
                            // ...
                            c_thread_buf_fp32(cidx) = gate * up;
                        }
                        else
                        {
                            c_thread_buf_fp32(cidx) =
                                scale_a * scale_b * c_thread_buf[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                c_thread_buf_fp32(cidx) =
                                    c_thread_buf_fp32(cidx) *
                                    topk_weights.template AsType<float>()[m4];
                            }
                        }
        // ...
                    const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
                                          m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
                    if constexpr(MulRoutedWeight)
                    {
                        topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
                            p_ds_grid[I2] + m_pos);
                    }
                    // ...
                        blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                    // ...
                        if constexpr(IsInputGemm)
                        {
                            // ...
                            float gate = c_thread_buf[cidx];
                            float up   = c_thread_buf_up[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                gate = gate * topk_weights.template AsType<float>()[m4];
                                up   = up * topk_weights.template AsType<float>()[m4];
                            }
                            // ...
                            c_thread_buf_fp32(cidx) = gate * up;
                            // ...
                            float gate = c_thread_buf[cidx];
                            float up   = c_thread_buf_up[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                gate = gate * topk_weights.template AsType<float>()[m4];
                                up   = up * topk_weights.template AsType<float>()[m4];
                            }
                            // ...
                            c_thread_buf_fp32(cidx) = gate * up;
                        }
                        else
                        {
                            c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                c_thread_buf_fp32(cidx) =
                                    topk_weights.template AsType<float>()[m4] *
                                    c_thread_buf_fp32[cidx];
                            }
                        }
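
        // Epilogue fusion for the gate/up GEMM: c_thread_buf holds the gate projection and
        // c_thread_buf_up the up projection of the same output element. The elided code in between
        // presumably applies the selected ActivationOperation (gelu_and_mul / silu_and_mul) to the
        // gate operand before the gate * up product is stored, optionally pre-scaled by the dequant
        // factors (scale_a, scale_b / scale_up) and by the routing weight when MulRoutedWeight is
        // set. For the down GEMM only the scales and/or routing weight are applied.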
        // ...
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        // ...
        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        // ...
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            // ...
        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
        // ...
        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
        // ...
        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
        // ...
        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
        // ...
        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
        // ...
        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
        // ...
        auto c_thread_copy_vgpr_to_lds =
        // ...
            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            // ...
            Sequence<CShuffleMXdlPerWavePerShuffle,
                     CShuffleNXdlPerWavePerShuffle,
                     // ...
            c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
            // ...
                             m_thread_data_on_block_idx[I1],
                             n_thread_data_on_block_idx[I1],
                             m_thread_data_on_block_idx[I2],
                             m_thread_data_on_block_idx[I3],
                             m_thread_data_on_block_idx[I4],
                             n_thread_data_on_block_idx[I2]),
            // ...
        using EDataType = CDataType;
        // ...
        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
        // ...
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
        // ...
            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            // ...
            { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
        // ...
            tie(c_shuffle_block_buf),
            // ...
            { return ds_grid_buf[i]; },
        // ...
        const auto idx_c_ds_block_begin =
        // ...
        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            c_grid_desc_mblock_mperblock_nblock_nperblock;
        // ...
        using CDEBlockTransferCluster =
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
        const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
        constexpr index_t scatter_weight_idx  = 3;
        // ...
            decltype(c_ds_desc_refs),
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CElementwiseOperation,
            // ...
            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
            // ...
            CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
            CDEBlockTransferCluster,
            // ...
            CDEShuffleBlockTransferScalarPerVectors,
            // ...
            idx_c_ds_block_begin,
            tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
            // ...
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        constexpr auto sfc_c_vgpr =
        // ...
            Sequence<CShuffleMXdlPerWavePerShuffle,
                     CShuffleNXdlPerWavePerShuffle,
                     // ...
        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
        // ...
        constexpr auto sfc_cde_block =
            // ...
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     // ...
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
        // ...
        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
        constexpr auto EMThreads =
            CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
        constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
        constexpr auto ENThreads =
            CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
        // ...
            auto dstidx = sfc_cde_block.GetIndex(access_id);
            // ...
                block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
            // ...
                const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                IndexType token_offset    = fused_token & 0xffffff;
                if constexpr(IsInputGemm)
                {
                    token_offset = token_offset * problem.TopK + (fused_token >> 24);
                }
                scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
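
                // Mirror of the gather on the A side: the accumulated tile is scattered back to
                // global memory row by row through scatter_offsets. For the gate/up GEMM each
                // (token, top-k slot) pair gets its own output row (token * TopK + slot); the down
                // GEMM writes to the original token row, and how overlapping top-k contributions
                // are combined there is governed by CGlobalMemoryDataOperation.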
            // ...
            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          // ...
                                          c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          c_shuffle_block_buf);
            // ...
            cde_block_copy_lds_and_global.Run(
                // ...
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                // ...
            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_lds_and_global_step =
                    sfc_cde_block.GetForwardStep(access_id);
                // ...
                    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                // ...
                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    // ...
                    cde_lds_and_global_step);
            }
    template <bool HasMainKBlockLoop,
              // ...
    static __device__ void Run_2Lds(const index_t* p_sorted_token_ids,
                                    const index_t* p_sorted_expert_ids,
                                    const index_t* p_max_token_id,
                                    const ADataType* p_a_grid,
                                    const BDataType* p_b_grid,
                                    // ...
                                    CDataType* p_c_grid,
                                    // ...
                                    AElementwiseOperation a_element_op,
                                    BElementwiseOperation b_element_op,
                                    CElementwiseOperation c_element_op)
    {
        // ...
        const auto b_grid_desc_bpreshuffled =
        // ...
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
        // ...
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        // ...
        const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        // ...
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
        // ...
        const index_t expert_id =
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
            {
                const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
                // ...
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                // ...
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
                // ...
            }
            else
            {
                return {blockIdx.x, blockIdx.y};
            }
        }();
        // ...
        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
        const index_t token0 =
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
        // ...
        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
        // ...
        if(token_pos >= max_token_id || token0 >= problem.NumTokens)
        // ...
            const index_t fused_token = p_sorted_token_ids[token_pos + m0];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
        // ...
        const IndexType expert_stride =
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
        const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
        // ...
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
        // ...
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
        // ...
            AElementwiseOperation,
            // ...
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            // ...
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            // ...
            ABlockTransferSrcVectorDim,
            // ...
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            // ...
            AThreadTransferSrcResetCoordinateAfterRun,
            // ...
            2>(a_grid_desc_ak0_m_ak1,
               // ...
               a_block_desc_ak0_m_ak1,
               // ...
        // ...
        auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
        // ...
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            // ...
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(b_grid_desc_bpreshuffled,
                  // ...
        // ...
        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
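
        // Run_2Lds keeps two complete A tiles resident, ping/pong LDS buffers backed by p_shared and
        // p_shared1, plus matching ping/pong VGPR buffers for the preshuffled B tile, so the pipeline
        // can prefetch the next K block into one buffer while the MFMAs consume the other.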
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        // ...
        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;
        // ...
            c_thread_buf.num_of_v_,
            c_thread_buf.s_per_v,
            // ...
        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            // ...
        if constexpr(IsInputGemm)
        {
            // ...
            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
            // ...
                decltype(b_grid_desc_bpreshuffled),
                decltype(b_block_desc_bk0_n_bk1),
                // ...
                BBlockTransferSrcScalarPerVector,
                BThreadTransferSrcResetCoordinateAfterRun,
                true>(b_grid_desc_bpreshuffled,
                      // ...
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                // ...
                b_blockwise_copy_up,
                // ...
                b_block_slice_copy_step,
                // ...
                num_k_block_main_loop);
        }
        else
        {
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                // ...
                b_block_slice_copy_step,
                // ...
                num_k_block_main_loop);
        }
        // ...
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                      // ...
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        // ...
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
        // ...
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
        // ...
        const float* p_sorted_weights_0 = p_ds_grid[I0];
        const float* p_scale_b          = p_ds_grid[I1];
        // ...
        static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
        static_assert(M4 == 4 || M4 == 8);
        // ...
        if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
        {
            if constexpr(PerTokenQuant)
            {
                constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
                p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
                // ...
            }
            else
            {
                p_scale_b += expert_id;
            }
            // ...
                const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
                // ...
                    const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
                                          m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
                    if constexpr(PerTokenQuant)
                    {
                        // ...
                            *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
                                p_sorted_token_ids + m_pos);
                    }
                    if constexpr(MulRoutedWeight)
                    {
                        topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
                            p_ds_grid[I2] + m_pos);
                    }
                    // ...
                        float scale_a = [&]() {
                            if constexpr(PerTokenQuant)
                            {
                                // ...
                                    scale_token_ids.template AsType<index_t>()[m4];
                                const index_t token_offset = fused_token & 0xffffff;
                                // ...
                                    ? p_sorted_weights_0[IsInputGemm
                                // ...
                            }
                            else
                            {
                                return p_sorted_weights_0[0];
                            }
                        }();
                        // ...
                            blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                        // ...
                        if constexpr(IsInputGemm)
                        {
                            // ...
                            const float scale_up =
                                p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
                            // ...
                            float gate = scale_a * scale_b * c_thread_buf[cidx];
                            float up   = scale_a * scale_up * c_thread_buf_up[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                gate = gate * topk_weights.template AsType<float>()[m4];
                                up   = up * topk_weights.template AsType<float>()[m4];
                            }
                            // ...
                            c_thread_buf_fp32(cidx) = gate * up;
                            // ...
                            const float scale_up =
                                p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
                            // ...
                            float gate = scale_a * scale_b * c_thread_buf[cidx];
                            float up   = scale_a * scale_up * c_thread_buf_up[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                gate = gate * topk_weights.template AsType<float>()[m4];
                                up   = up * topk_weights.template AsType<float>()[m4];
                            }
                            // ...
                            c_thread_buf_fp32(cidx) = gate * up;
                        }
                        else
                        {
                            c_thread_buf_fp32(cidx) =
                                scale_a * scale_b * c_thread_buf[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                c_thread_buf_fp32(cidx) =
                                    c_thread_buf_fp32(cidx) *
                                    topk_weights.template AsType<float>()[m4];
                            }
                        }
        // ...
                    const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
                                          m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
                    if constexpr(MulRoutedWeight)
                    {
                        topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
                            p_ds_grid[I2] + m_pos);
                    }
                    // ...
                        blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                    // ...
                        if constexpr(IsInputGemm)
                        {
                            // ...
                            float gate = c_thread_buf[cidx];
                            float up   = c_thread_buf_up[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                gate = gate * topk_weights.template AsType<float>()[m4];
                                up   = up * topk_weights.template AsType<float>()[m4];
                            }
                            // ...
                            c_thread_buf_fp32(cidx) = gate * up;
                            // ...
                            float gate = c_thread_buf[cidx];
                            float up   = c_thread_buf_up[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                gate = gate * topk_weights.template AsType<float>()[m4];
                                up   = up * topk_weights.template AsType<float>()[m4];
                            }
                            // ...
                            c_thread_buf_fp32(cidx) = gate * up;
                        }
                        else
                        {
                            c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                c_thread_buf_fp32(cidx) =
                                    topk_weights.template AsType<float>()[m4] *
                                    c_thread_buf_fp32[cidx];
                            }
                        }
        // ...
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        // ...
        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        // ...
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            // ...
        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
        // ...
        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
        // ...
        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
        // ...
        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
        // ...
        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
        // ...
        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
        // ...
        auto c_thread_copy_vgpr_to_lds =
        // ...
            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            // ...
            Sequence<CShuffleMXdlPerWavePerShuffle,
                     CShuffleNXdlPerWavePerShuffle,
                     // ...
            c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
            // ...
                             m_thread_data_on_block_idx[I1],
                             n_thread_data_on_block_idx[I1],
                             m_thread_data_on_block_idx[I2],
                             m_thread_data_on_block_idx[I3],
                             m_thread_data_on_block_idx[I4],
                             n_thread_data_on_block_idx[I2]),
            // ...
        using EDataType = CDataType;
        // ...
        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
        // ...
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
        // ...
            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            // ...
            { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
        // ...
            tie(c_shuffle_block_buf),
            // ...
            { return ds_grid_buf[i]; },
        // ...
        const auto idx_c_ds_block_begin =
        // ...
        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            c_grid_desc_mblock_mperblock_nblock_nperblock;
        // ...
        using CDEBlockTransferCluster =
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
        const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
        constexpr index_t scatter_weight_idx  = 3;
        // ...
            decltype(c_ds_desc_refs),
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CElementwiseOperation,
            // ...
            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
            // ...
            CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
            CDEBlockTransferCluster,
            // ...
            CDEShuffleBlockTransferScalarPerVectors,
            // ...
            idx_c_ds_block_begin,
            tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
            // ...
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        constexpr auto sfc_c_vgpr =
        // ...
            Sequence<CShuffleMXdlPerWavePerShuffle,
                     CShuffleNXdlPerWavePerShuffle,
                     // ...
        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
        // ...
        constexpr auto sfc_cde_block =
            // ...
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     // ...
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
        // ...
        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
        constexpr auto EMThreads =
            CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
        constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
        constexpr auto ENThreads =
            CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
        // ...
            auto dstidx = sfc_cde_block.GetIndex(access_id);
            // ...
                block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
            // ...
                const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                IndexType token_offset    = fused_token & 0xffffff;
                if constexpr(IsInputGemm)
                {
                    token_offset = token_offset * problem.TopK + (fused_token >> 24);
                }
                scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
            // ...
            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          // ...
                                          c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          c_shuffle_block_buf);
            // ...
            cde_block_copy_lds_and_global.Run(
                // ...
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                // ...
            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_lds_and_global_step =
                    sfc_cde_block.GetForwardStep(access_id);
                // ...
                    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                // ...
                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    // ...
                    cde_lds_and_global_step);
            }