template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          /* ... */>
#if CK_USE_LAUNCH_BOUNDS
    // ...
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_sorted_token_ids,
        karg.p_sorted_expert_ids,
        // ...
        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
        // ...
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          /* ... */>
#if CK_USE_LAUNCH_BOUNDS
    // ...
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

    GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_sorted_token_ids,
        karg.p_sorted_expert_ids,
        // ...
        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
        // ...
template <typename ALayout,
          // ...
          typename AccDataType,
          typename CShuffleDataType,
          // ...
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          // ...
          index_t ActivationOperation = 0,
          bool NSwizzle              = false,
          bool IsInputGemm           = true,
          bool MulRoutedWeight       = true,
          // ...
          typename ComputeTypeA = CDataType,
          typename ComputeTypeB = ComputeTypeA,
          typename LDSTypeA     = ADataType,
          typename LDSTypeB     = BDataType>
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock =
    CDEShuffleBlockTransferScalarPerVectors{}[I0];
// ...
return static_cast<const DDataType*>(nullptr);
// ...
const index_t gridx = NSwizzle ? nblock * mblock : nblock;
const index_t gridy = NSwizzle ? 1 : mblock;
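// ---------------------------------------------------------------------------
// [Editor's sketch] Standalone mirror of the grid-size split above: with
// NSwizzle the whole M x N block tiling is flattened into gridDim.x and
// gridDim.y stays 1; otherwise N blocks map to x and M blocks to y. Plain
// host C++ under those assumptions; calculate_grid_size is a hypothetical
// stand-in for CalculateGridSize.
#include <cstdio>

void calculate_grid_size(int mblock, int nblock, bool n_swizzle, int& gridx, int& gridy)
{
    gridx = n_swizzle ? nblock * mblock : nblock; // flattened when swizzled
    gridy = n_swizzle ? 1 : mblock;
}

int main()
{
    int x, y;
    calculate_grid_size(/*mblock=*/4, /*nblock=*/8, /*n_swizzle=*/true, x, y);
    std::printf("swizzled: %d x %d\n", x, y); // 32 x 1 workgroups
    calculate_grid_size(4, 8, false, x, y);
    std::printf("plain:    %d x %d\n", x, y); // 8 x 4 workgroups
}
// ---------------------------------------------------------------------------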
// ...
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
// ...
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
// ...
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * KPerBlock;
// ...
auto K_t = K_Batch * KReadVec;
return (K + K_t - 1) / K_t * KReadVec;
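// ---------------------------------------------------------------------------
// [Editor's sketch] The Calculate*Padded helpers above share one idea: round
// K up so each of the K_Batch split-K slices covers a whole number of
// KPerBlock tiles. A minimal host-side check with hypothetical sizes:
#include <cstdio>

int calculate_k_padded(int K, int K_Batch, int KPerBlock)
{
    const int K_t = K_Batch * KPerBlock;    // K granularity of one split-K round
    return (K + K_t - 1) / K_t * KPerBlock; // K handled per batch, rounded up
}

int main()
{
    // K=1000 with 4-way split-K and KPerBlock=64: each slice loops over 256
    // elements, so 4 slices cover 1024 >= 1000 (the tail is padded).
    std::printf("%d\n", calculate_k_padded(1000, 4, 64)); // prints 256
}
// ---------------------------------------------------------------------------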
// ...
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
// ...
    IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
{
    const auto a_grid_desc_mraw_kraw = [&]() {
        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            // ...
        else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            // ...
    }();

    if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                 GemmSpec == GemmSpecialization::MNKPadding)
    {
        const auto a_grid_desc_m_k = /* ... */;
        // ...
        return a_grid_desc_ak0_m_ak1;
    }
    else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                      GemmSpec == GemmSpecialization::MNPadding)
    {
        // ... pad/transform a_grid_desc_mraw_kraw ...
        return a_grid_desc_ak0_m_ak1;
    }
    else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                      GemmSpec == GemmSpecialization::NKPadding)
    {
        // ... pad/transform a_grid_desc_mraw_kraw ...
        return a_grid_desc_ak0_m_ak1;
    }
    else
    {
        // ... transform a_grid_desc_mraw_kraw ...
        return a_grid_desc_ak0_m_ak1;
    }
}
// ...
constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
// ...
    make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
// ...
const auto b_grid_desc_nraw_kraw = [&]() {
    // ...
}();

static_assert(!(/* ... */ && GemmSpec != GemmSpecialization::Default),
              "pk_i4_t does not support padding");

if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
             GemmSpec == GemmSpecialization::MNKPadding)
{
    const auto b_grid_desc_n_k = /* ... */;
    // ...
    return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                  GemmSpec == GemmSpecialization::MNPadding)
{
    // ... pad/transform b_grid_desc_nraw_kraw ...
    return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                  GemmSpec == GemmSpecialization::MKPadding)
{
    // ... pad/transform b_grid_desc_nraw_kraw ...
    return b_grid_desc_bk0_n_bk1;
}
else
{
    // ... transform b_grid_desc_nraw_kraw ...
    return b_grid_desc_bk0_n_bk1;
}
// ...
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
    constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
    return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
}

template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
    return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
}
// ...
template <typename ELayout>
__host__ static __device__ auto MakeCGridDescriptor_M_N(
    IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
{
    const auto c_grid_desc_mraw_nraw = [&]() {
        // ...
    }();
    // ...
}

template <typename DLayout>
__host__ __device__ static auto MakeDGridDescriptor_M_N(/* ... */)
{
    const auto c_grid_desc_mraw_nraw = [&]() {
        // ...
    }();
    // ...
}

// ... (inside MakeDsGridDescriptor_M_N)
        return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
// ...

template <typename DsGridDesc>
// ...
    const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
    // ...
        ds_grid_desc_m_n[i], MBlock, NBlock);
}
// ...
    std::array<index_t, NumDTensor> StrideDs_,
// ...
std::cout << "problem {"
          << "NumTokens:" << NumTokens << ", "
          << "TopK:" << TopK << ", "
          << "M:" << M << ", "
          << "N:" << N << ", "
          << "K:" << K << ", "
          // ...
          << "KRead:" << KRead << ", "
          << "KP:" << KPadded << ", "
          << "AK0:" << AK0 << ", "
          << "BK0:" << BK0 << ", "
          << "MBlock: " << MBlock << ", "
          << "NBlock: " << NBlock << "}" << std::endl;
// ...
    const index_t* p_sorted_expert_ids_,
    const index_t* p_max_token_id_,
    const ADataType* p_a_grid_,
    const BDataType* p_b_grid_,
    std::array<const void*, NumDTensor> p_ds_grid_,
    CDataType* p_c_grid_,
    // ...
    std::array<index_t, NumDTensor> StrideDs_,
    // ...
    AElementwiseOperation a_element_op_,
    BElementwiseOperation b_element_op_,
    CElementwiseOperation c_element_op_)
// ...
    p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
// ...
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
    // ...
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
    // ...

if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
    // ...
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
    // ...

if(k_id < karg.KBatch - 1)
    // ...
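// ---------------------------------------------------------------------------
// [Editor's sketch] A plausible reading of the elided SplitKBatchOffset
// branches, reduced to plain C++: slice k_id starts KRead elements further
// along K, which is a stride-1 step for the K-contiguous layout and a
// KRead * Stride step otherwise. Names mirror the struct above; the boolean
// layout flags stand in for the is_same_v<...> checks and are assumptions.
struct SplitOffsets
{
    long a_k_split_offset;
    long b_k_split_offset;
};

SplitOffsets make_splitk_offsets(long k_id, long KRead, long StrideA, long StrideB,
                                 bool a_row_major, bool b_row_major)
{
    SplitOffsets o{};
    o.a_k_split_offset = a_row_major ? k_id * KRead : k_id * KRead * StrideA;
    o.b_k_split_offset = b_row_major ? k_id * KRead * StrideB : k_id * KRead;
    return o;
}
// ---------------------------------------------------------------------------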
// ...
constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t WaveSize = BlockSize / (MWave * NWave);

if constexpr(ABlockLdsExtraM)
{
    // ...
}
// ...
constexpr auto a_lds_block_desc = /* ... */;
// ...
return a_lds_block_desc_permuted;
// ...
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto M1 = MPerBlock / M0;

constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
constexpr auto KThreadRead      = WaveSize / MPerXdl;
constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
                           ? 1
                           : 128 / (AK1Number * M0 * sizeof(LDSTypeA));

constexpr auto KThreadReadPerm =
    (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
        ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
        /* ... */;

constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
                           ? 1
                           : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
                                  ? M0
                                  : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
// ...
    Number<kfold * M0 / mpair>{},
// ...
    a_lds_block_desc_permuted,
// ...
    a_lds_block_desc_unmerged,
// ...
    Number<KThreadWrite / kfold / KThreadReadPerm>{},
// ...
return a_lds_block_desc_ak0_m_ak1;
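// ---------------------------------------------------------------------------
// [Editor's sketch] The kfold/mpair arithmetic above chooses folding factors
// so one LDS store burst stays within a 128-byte window (bank-conflict
// avoidance). Constant-folded host version with hypothetical tile numbers;
// elem_bytes stands in for sizeof(LDSTypeA):
#include <cstdio>

int main()
{
    constexpr int AK1 = 8, M0 = 4, MPerXdl = 32;
    constexpr int elem_bytes = 2; // e.g. a 16-bit LDS element type

    constexpr int kfold = (AK1 * M0 * elem_bytes > 128)
                              ? 1
                              : 128 / (AK1 * M0 * elem_bytes);
    constexpr int mpair = (AK1 * MPerXdl * elem_bytes > 128)
                              ? 1
                              : ((128 / (AK1 * MPerXdl * elem_bytes)) > M0
                                     ? M0
                                     : 128 / (AK1 * MPerXdl * elem_bytes));
    std::printf("kfold=%d mpair=%d\n", kfold, mpair); // kfold=2 mpair=1
}
// ---------------------------------------------------------------------------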
// ...
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
// ...
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = /* ... */;
// ...
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
// ...
    ABlockTransferSrcScalarPerVector,
    BBlockTransferSrcScalarPerVector,
// ...
    a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
// ...
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = /* ... */;
constexpr auto c_block_size =
    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
// ...
    c_block_size * sizeof(CShuffleDataType));
// ...
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
              "Invalid tuning param!");
// ...
if(!(karg.M % MPerBlock == 0))
{
    // ...
    std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
    // ...
}
// ...
if(!(karg.N % NPerBlock == 0))
{
    // ...
    std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
    // ...
}
// ...
auto K_t = karg.KBatch * KPerBlock;
if(!(karg.K % K_t == 0))
{
    // ...
    std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
              << karg.K << " " << __FILE__ << ":" << __LINE__
              << ", in function: " << __func__ << std::endl;
    // ...
}
// ...
auto K_t = karg.KBatch * KReadVec;
// ...
if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
    // ...
// ...
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
{
    // ...
    std::cout << "Arg K (" << karg.K
              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
              << __LINE__ << ", in function: " << __func__ << std::endl;
    // ...
}
// ...
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
{
    // ...
    std::cout << "Arg M (" << karg.M
              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
              << __LINE__ << ", in function: " << __func__ << std::endl;
    // ...
}
// ...
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
{
    // ...
    std::cout << "Arg N (" << karg.N
              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
              << __LINE__ << ", in function: " << __func__ << std::endl;
    // ...
}
// ...
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
{
    // ...
    std::cout << "Arg K (" << karg.K
              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
              << __LINE__ << ", in function: " << __func__ << std::endl;
    // ...
}
// ...
    std::cout << "Arg N (" << karg.N
              << ") value is not a multiple of "
                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
              // ...
              << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
// ...
    std::cout << "Arg M (" << karg.M
              << ") value is not a multiple of "
                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
              // ...
              << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
// ...
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
// ...
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
    // ...
// ...
const index_t num_loop = K / KPerBlock;
return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
// ...
const index_t num_loop = K / KPerBlock;
return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
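// ---------------------------------------------------------------------------
// [Editor's sketch] What the loop-count checks above decide on the host:
// num_loop is the number of KPerBlock tiles, and the software pipeline only
// has a "hot" main loop when there are more tiles than prefetch stages.
// PrefetchStages = 2 is an assumed value; the real one comes from the
// selected BlockwiseGemmPipe.
#include <cstdio>

int main()
{
    constexpr int KPerBlock = 64, PrefetchStages = 2;
    for(int K : {64, 128, 256})
    {
        const int num_loop = K / KPerBlock;
        std::printf("K=%d num_loop=%d has_main_loop=%d\n",
                    K, num_loop, static_cast<int>(num_loop > PrefetchStages));
    }
}
// ---------------------------------------------------------------------------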
// ...
template <typename CGridDesc>
// ...
    const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
    // ...
    return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
// ...
template <bool HasMainKBlockLoop, /* ... */>
static __device__ void Run(const index_t* p_sorted_token_ids,
                           const index_t* p_sorted_expert_ids,
                           const index_t* p_max_token_id,
                           const ADataType* p_a_grid,
                           const BDataType* p_b_grid,
                           // ...
                           CDataType* p_c_grid,
                           // ...
                           AElementwiseOperation a_element_op,
                           BElementwiseOperation b_element_op,
                           CElementwiseOperation c_element_op)
// ...
const auto b_grid_desc_bpreshuffled = /* ... */;
const auto c_grid_desc_m_n          = MakeCGridDescriptor_M_N<CLayout>(
    /* ... */);
// ...
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = /* ... */;

const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
if(expert_block_id * MPerBlock >= max_token_id)
    // ...
// ...
    __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
const auto block_mn = [&]() -> std::pair<int, int> {
    if constexpr(NSwizzle)
    {
        const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
        // ...
        const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
        const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
        const index_t bid_new        = blockIdx.x - prefix_block;
        const index_t nid            = __builtin_amdgcn_readfirstlane(
            bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
        // ...
            __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
        // ...
    }
    else
    {
        return {blockIdx.x, blockIdx.y};
    }
}();
const index_t block_n_id = block_mn.first;
const index_t block_m_id = block_mn.second;
// ...
    __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
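// ---------------------------------------------------------------------------
// [Editor's sketch] The NSwizzle remap above, re-run on the host: it walks C
// tiles in 8-wide N bands, visiting every M row of the current expert inside
// a band before moving to the next band of 8 N tiles. expert_swizzle is the
// number of M blocks owned by the expert; the ecnt_prefix offset is dropped.
#include <cstdio>

int main()
{
    const int expert_swizzle = 3; // hypothetical: expert owns 3 M blocks
    for(int bid_new = 0; bid_new < 32; ++bid_new)
    {
        const int nid = bid_new % 8 + bid_new / (8 * expert_swizzle) * 8;
        const int mid = bid_new / 8 % expert_swizzle; // + ecnt_prefix in-kernel
        std::printf("bid=%2d -> nid=%2d mid=%d\n", bid_new, nid, mid);
    }
}
// ---------------------------------------------------------------------------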
// ...
constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
constexpr auto AKThreads  = AK0Threads * AK1Threads;
constexpr auto AMRepeats  = MPerBlock / AMThreads;
const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;

if(token_pos >= max_token_id || token0 >= problem.NumTokens)
    // ...
// ...
{
    const index_t fused_token = p_sorted_token_ids[token_pos + m0];
    index_t token_offset      = fused_token & 0xffffff;
    if constexpr(!IsInputGemm)
    {
        token_offset = token_offset * problem.TopK + (fused_token >> 24);
    }
    gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
}
// ...
    __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
    /* ... */);
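// ---------------------------------------------------------------------------
// [Editor's sketch] Layout of one sorted-token entry as decoded above: the
// low 24 bits hold the token id and the high 8 bits the top-k slot, so the
// gather row for the expanded activations is token * TopK + slot.
#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t fused_token = (2u << 24) | 1234u; // token 1234, slot 2
    const int TopK = 4;
    int token_offset = fused_token & 0xffffff;                // 1234
    token_offset = token_offset * TopK + (fused_token >> 24); // 4938
    std::printf("expanded row = %d\n", token_offset);
}
// ---------------------------------------------------------------------------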
// ...
const index_t n_block_data_idx_on_grid =
    __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);

const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
    p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
    /* ... */, b_grid_desc_bpreshuffled.GetElementSpaceSize());
// ...
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
    p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
    p_b_scale_grid + expert_id * expert_scale_stride,
    b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// ...
    AElementwiseOperation,
    // ...
    ABlockTransferThreadClusterLengths_AK0_M_AK1,
    ABlockTransferThreadClusterArrangeOrder,
    // ...
    decltype(a_grid_desc_ak0_m_ak1),
    decltype(a_block_desc_ak0_m_ak1),
    ABlockTransferSrcAccessOrder,
    // ...
    ABlockTransferSrcVectorDim,
    // ...
    ABlockTransferSrcScalarPerVector,
    ABlockTransferDstScalarPerVector_AK1,
    // ...
    AThreadTransferSrcResetCoordinateAfterRun,
    // ...
    BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                        // ...
                                        a_block_desc_ak0_m_ak1,
                                        // ...

auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
    b_block_desc_bk0_n_bk1.GetElementSpaceSize());
// ...
    decltype(b_grid_desc_bpreshuffled),
    decltype(b_block_desc_bk0_n_bk1),
    // ...
    BBlockTransferSrcScalarPerVector,
    BThreadTransferSrcResetCoordinateAfterRun,
    true>(b_grid_desc_bpreshuffled,
          // ...

auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
    static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
// ...
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
// ...
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
decltype(c_thread_buf) c_thread_buf_up;

const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
    (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
    /* ... */);

constexpr index_t ScaleSliceSizeM = MXdlPerWave;
// ...
constexpr index_t MWaves   = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWaves   = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
// ...
const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
// ...
if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
    // ...
// ...
{
    const index_t fused_token =
        p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
    index_t token_offset = fused_token & 0xffffff;
    if constexpr(!IsInputGemm)
    {
        token_offset = token_offset * problem.TopK + (fused_token >> 24);
    }
    scale_gather_offsets(m0) = /* ... */;
}

auto a_scale_thread_copy =
    // ...
    decltype(a_scale_grid_desc_am_ak),
    decltype(a_scale_thread_desc),
    // ...

auto b_scale_thread_copy =
    // ...
    decltype(b_scale_grid_desc_bn_ak),
    decltype(b_scale_thread_desc),
    // ...
    b_scale_grid_desc_bn_ak,
    make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));

constexpr auto a_scale_thread_slice_copy_step = /* ... */;
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
if constexpr(IsInputGemm)
{
    const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
    const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
        /* ... */, b_grid_desc_bpreshuffled.GetElementSpaceSize());
    // ...
        decltype(b_grid_desc_bpreshuffled),
        decltype(b_block_desc_bk0_n_bk1),
        // ...
        BBlockTransferSrcScalarPerVector,
        BThreadTransferSrcResetCoordinateAfterRun,
        true>(b_grid_desc_bpreshuffled,
              // ...
    const BScaleType* p_b_scale_grid_up =
        p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
    const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_b_scale_grid_up + expert_id * expert_scale_stride,
        b_scale_grid_desc_bn_ak.GetElementSpaceSize());
    auto b_scale_thread_copy_up =
        // ...
        decltype(b_scale_grid_desc_bn_ak),
        decltype(b_scale_thread_desc),
        // ...
        b_scale_grid_desc_bn_ak,
        // ...

    blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
        a_grid_desc_ak0_m_ak1,
        a_block_desc_ak0_m_ak1,
        // ...
        a_block_slice_copy_step,
        // ...
        b_grid_desc_bpreshuffled,
        b_block_desc_bk0_n_bk1,
        // ...
        b_blockwise_copy_up,
        // ...
        b_block_slice_copy_step,
        // ...
        c_scale_thread_desc,
        // ...
        a_scale_grid_desc_am_ak,
        a_scale_thread_desc,
        a_scale_thread_copy,
        // ...
        a_scale_thread_slice_copy_step,
        // ...
        b_scale_grid_desc_bn_ak,
        b_scale_thread_desc,
        b_scale_thread_copy,
        b_scale_thread_copy_up,
        // ...
        b_scale_grid_buf_up,
        b_scale_thread_slice_copy_step,
        // ...
        num_k_block_main_loop);
}
else
{
    blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
        a_grid_desc_ak0_m_ak1,
        a_block_desc_ak0_m_ak1,
        // ...
        a_block_slice_copy_step,
        // ...
        b_grid_desc_bpreshuffled,
        b_block_desc_bk0_n_bk1,
        // ...
        b_block_slice_copy_step,
        // ...
        c_scale_thread_desc,
        // ...
        a_scale_grid_desc_am_ak,
        a_scale_thread_desc,
        a_scale_thread_copy,
        // ...
        a_scale_thread_slice_copy_step,
        // ...
        b_scale_grid_desc_bn_ak,
        b_scale_thread_desc,
        b_scale_thread_copy,
        // ...
        b_scale_thread_slice_copy_step,
        // ...
        num_k_block_main_loop);
}
// ...
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
              /* ... */);

constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
// ...
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
    blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// ...
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
    blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);

static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
static_assert(M0 * M1 * M2 == MPerBlock);
static_assert(N4 == 4 || N4 == 8);
// ...
if constexpr(MulRoutedWeight)
{
    const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
    topk_weight         = p_ds_grid[I0][m_pos];
}
// ...
    blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
        /* ... */);
// ...
if constexpr(IsInputGemm)
{
    // ... first ActivationOperation branch:
    {
        float gate = c_thread_buf[cidx];
        float up   = c_thread_buf_up[cidx];
        if constexpr(MulRoutedWeight)
        {
            gate = gate * topk_weight;
            up   = up * topk_weight;
        }
        // ... activation applied to gate ...
        c_thread_buf(cidx) = gate * up;
    }
    // ... second ActivationOperation branch:
    {
        float gate = c_thread_buf[cidx];
        float up   = c_thread_buf_up[cidx];
        if constexpr(MulRoutedWeight)
        {
            gate = gate * topk_weight;
            up   = up * topk_weight;
        }
        // ... activation applied to gate ...
        c_thread_buf(cidx) = gate * up;
    }
}
else
{
    if constexpr(MulRoutedWeight)
    {
        c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
    }
}
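// ---------------------------------------------------------------------------
// [Editor's sketch] The gate/up fusion above implements the *_and_mul MoE
// activations: the doubled first GEMM yields a gate half and an up half, the
// (elided) activation is applied to the gate, and the two are multiplied.
// Reference math for a silu_and_mul-style branch; the routing-weight scaling
// mirrors the MulRoutedWeight path:
#include <cmath>
#include <cstdio>

float silu_and_mul(float gate, float up, float topk_weight, bool mul_routed_weight)
{
    if(mul_routed_weight)
    {
        gate *= topk_weight;
        up *= topk_weight;
    }
    const float act = gate / (1.0f + std::exp(-gate)); // silu(x) = x * sigmoid(x)
    return act * up;
}

int main() { std::printf("%f\n", silu_and_mul(1.5f, 0.5f, 1.0f, false)); }
// ---------------------------------------------------------------------------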
// ...
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = /* ... */;
// ...
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
    static_cast<CShuffleDataType*>(p_shared),
    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// ...
    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
// ...
const auto c_thread_mtx_on_block =
    blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = /* ... */;
// ...
const auto m_thread_data_on_block_idx =
    m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
        /* ... */);

const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = /* ... */;
// ...
const auto n_thread_data_on_block_idx =
    n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
        /* ... */);

auto c_thread_copy_vgpr_to_lds =
    // ...
    decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
    decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
    // ...
    Sequence<CShuffleMXdlPerWavePerShuffle,
             CShuffleNXdlPerWavePerShuffle,
             /* ... */>,
    // ...
        c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
        // ...
        m_thread_data_on_block_idx[I1],
        n_thread_data_on_block_idx[I1],
        m_thread_data_on_block_idx[I2],
        n_thread_data_on_block_idx[I2],
        n_thread_data_on_block_idx[I3],
        n_thread_data_on_block_idx[I4]),
        // ...
using EDataType = CDataType;
// ...
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = /* ... */;
// ...
    const DDataType* ptr_ = p_ds_grid[i];
    // ...
    return make_dynamic_buffer<AddressSpaceEnum::Global>(
        ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
// ...
    tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
    // ...
    { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
    // ...
    tie(c_shuffle_block_buf),
    // ...
    { return ds_grid_buf[i]; },
    // ...
const auto idx_c_ds_block_begin = /* ... */;
// ...
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
    c_grid_desc_mblock_mperblock_nblock_nperblock;

using CDEBlockTransferCluster =
    CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr index_t scatter_weight_idx  = IsInputGemm ? 1 : 1; // both paths use index 1
// ...
    decltype(c_ds_desc_refs),
    decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
    CElementwiseOperation,
    // ...
    Sequence<1,
             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
             1,
             CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
    CDEBlockTransferCluster,
    // ...
    CDEShuffleBlockTransferScalarPerVectors,
    // ...
    idx_c_ds_block_begin,
    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
    // ...

auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
    p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

constexpr auto sfc_c_vgpr =
    // ...
    Sequence<CShuffleMXdlPerWavePerShuffle,
             CShuffleNXdlPerWavePerShuffle,
             /* ... */>
    // ...
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

constexpr auto sfc_cde_block =
    // ...
    Sequence<1,
             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
             1,
             CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

constexpr auto EMThreads =
    CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
constexpr auto ENThreads =
    CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
// ...
auto dstidx = sfc_cde_block.GetIndex(access_id);
// ...
    block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
// ...
{
    const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
    index_t token_offset      = fused_token & 0xffffff;
    if constexpr(IsInputGemm)
    {
        token_offset = token_offset * problem.TopK + (fused_token >> 24);
    }
    scatter_offsets(m0) = token_offset * problem.N;
}
// ...
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                              // ...
                              c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                              c_shuffle_block_buf);
// ...
cde_block_copy_lds_and_global.Run(
    // ...
    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
    // ...

if constexpr(access_id < num_access - 1)
{
    constexpr auto cde_lds_and_global_step = sfc_cde_block.GetForwardStep(access_id);
    // ...
    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
        c_ds_desc_refs, i + I1, cde_lds_and_global_step);
    // ...
    cde_block_copy_lds_and_global.MoveDstSliceWindow(
        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
        // ...
        cde_lds_and_global_step);
}
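// ---------------------------------------------------------------------------
// [Editor's sketch] The output scatter used above, in isolation: each shuffled
// M row is written to row token * TopK + slot (IsInputGemm path) of a C matrix
// with row stride N, restoring the expanded token order on the fly.
#include <cstdint>
#include <cstdio>

int64_t scatter_row_offset(uint32_t fused_token, int TopK, int N, bool is_input_gemm)
{
    int token_offset = fused_token & 0xffffff;
    if(is_input_gemm)
        token_offset = token_offset * TopK + (fused_token >> 24);
    return static_cast<int64_t>(token_offset) * N; // element offset of the row
}

int main()
{
    // token 1234, slot 2, TopK 4, N 512 -> row 4938 -> element 2528256
    std::printf("%lld\n",
                static_cast<long long>(
                    scatter_row_offset((2u << 24) | 1234u, 4, 512, true)));
}
// ---------------------------------------------------------------------------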
// ...
template <bool HasMainKBlockLoop, /* ... */>
static __device__ void Run_2Lds(const index_t* p_sorted_token_ids,
                                const index_t* p_sorted_expert_ids,
                                const index_t* p_max_token_id,
                                const ADataType* p_a_grid,
                                const BDataType* p_b_grid,
                                // ...
                                CDataType* p_c_grid,
                                // ...
                                AElementwiseOperation a_element_op,
                                BElementwiseOperation b_element_op,
                                CElementwiseOperation c_element_op)
// ...
const auto b_grid_desc_bpreshuffled = /* ... */;
const auto c_grid_desc_m_n          = MakeCGridDescriptor_M_N<CLayout>(
    /* ... */);
// ...
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = /* ... */;

const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
if(expert_block_id * MPerBlock >= max_token_id)
    // ...
// ...
    __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
const auto block_mn = [&]() -> std::pair<int, int> {
    if constexpr(NSwizzle)
    {
        const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
        // ...
        const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
        const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
        const index_t bid_new        = blockIdx.x - prefix_block;
        const index_t nid            = __builtin_amdgcn_readfirstlane(
            bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
        // ...
            __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
        // ...
    }
    else
    {
        return {blockIdx.x, blockIdx.y};
    }
}();
const index_t block_n_id = block_mn.first;
const index_t block_m_id = block_mn.second;
// ...
    __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
// ...
constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
constexpr auto AKThreads  = AK0Threads * AK1Threads;
constexpr auto AMRepeats  = MPerBlock / AMThreads;
const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;

if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
   /* ... */)
    // ...
// ...
{
    const index_t fused_token = p_sorted_token_ids[token_pos + m0];
    index_t token_offset      = fused_token & 0xffffff;
    if constexpr(!IsInputGemm)
    {
        token_offset = token_offset * problem.TopK + (fused_token >> 24);
    }
    gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
}
// ...
    __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
    /* ... */);
// ...
const index_t n_block_data_idx_on_grid =
    __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);

const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
    p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
    /* ... */, b_grid_desc_bpreshuffled.GetElementSpaceSize());
// ...
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
    p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
    p_b_scale_grid + expert_id * expert_scale_stride,
    b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// ...
    AElementwiseOperation,
    // ...
    ABlockTransferThreadClusterLengths_AK0_M_AK1,
    ABlockTransferThreadClusterArrangeOrder,
    // ...
    decltype(a_grid_desc_ak0_m_ak1),
    decltype(a_block_desc_ak0_m_ak1),
    ABlockTransferSrcAccessOrder,
    // ...
    ABlockTransferSrcVectorDim,
    // ...
    ABlockTransferSrcScalarPerVector,
    ABlockTransferDstScalarPerVector_AK1,
    // ...
    AThreadTransferSrcResetCoordinateAfterRun,
    // ...
    BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                        // ...
                                        a_block_desc_ak0_m_ak1,
                                        // ...

auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
    b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
    b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
// ...
    decltype(b_grid_desc_bpreshuffled),
    decltype(b_block_desc_bk0_n_bk1),
    // ...
    BBlockTransferSrcScalarPerVector,
    BThreadTransferSrcResetCoordinateAfterRun,
    true>(b_grid_desc_bpreshuffled,
          // ...

auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
    static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
    static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
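// ---------------------------------------------------------------------------
// [Editor's note] Why Run_2Lds allocates p_shared and p_shared1: iteration k
// computes from one LDS buffer while the next A tile streams into the other,
// then the roles flip. A minimal shape of that ping-pong loop (comments mark
// where the real pipeline does its copies and MFMA work):
void pingpong_mainloop(int num_loop)
{
    int read_buf = 0; // index into a_block_bufs being consumed
    for(int k = 0; k < num_loop; ++k)
    {
        // global -> LDS copy of tile k+1 into a_block_bufs[1 - read_buf]
        // MFMA math on a_block_bufs[read_buf]
        read_buf = 1 - read_buf; // swap ping and pong
    }
}
// ---------------------------------------------------------------------------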
// ...
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
// ...
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
decltype(c_thread_buf) c_thread_buf_up;

const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
    (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
    /* ... */);

constexpr index_t ScaleSliceSizeM = MXdlPerWave;
// ...
constexpr index_t MWaves   = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWaves   = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
// ...
const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
// ...
if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
    // ...
// ...
{
    const index_t fused_token =
        p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
    index_t token_offset = fused_token & 0xffffff;
    if constexpr(!IsInputGemm)
    {
        token_offset = token_offset * problem.TopK + (fused_token >> 24);
    }
    scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
        /* ... */;
}

auto a_scale_thread_copy =
    // ...
    decltype(a_scale_grid_desc_am_ak),
    decltype(a_scale_thread_desc),
    // ...

auto b_scale_thread_copy =
    // ...
    decltype(b_scale_grid_desc_bn_ak),
    decltype(b_scale_thread_desc),
    // ...
    b_scale_grid_desc_bn_ak,
    make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));

constexpr auto a_scale_thread_slice_copy_step = /* ... */;
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);

if constexpr(IsInputGemm)
{
    const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
    const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
        /* ... */, b_grid_desc_bpreshuffled.GetElementSpaceSize());
    // ...
        decltype(b_grid_desc_bpreshuffled),
        decltype(b_block_desc_bk0_n_bk1),
        // ...
        BBlockTransferSrcScalarPerVector,
        BThreadTransferSrcResetCoordinateAfterRun,
        true>(b_grid_desc_bpreshuffled,
              // ...
    const BScaleType* p_b_scale_grid_up =
        p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
    const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
        b_scale_grid_desc_bn_ak.GetElementSpaceSize());
    auto b_scale_thread_copy_up =
        // ...
        decltype(b_scale_grid_desc_bn_ak),
        decltype(b_scale_thread_desc),
        // ...
        b_scale_grid_desc_bn_ak,
        // ...

    blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
        a_grid_desc_ak0_m_ak1,
        a_block_desc_ak0_m_ak1,
        // ...
        a_block_slice_copy_step,
        b_grid_desc_bpreshuffled,
        b_block_desc_bk0_n_bk1,
        // ...
        b_blockwise_copy_up,
        // ...
        b_block_slice_copy_step,
        c_scale_thread_desc,
        // ...
        a_scale_grid_desc_am_ak,
        a_scale_thread_desc,
        a_scale_thread_copy,
        // ...
        a_scale_thread_slice_copy_step,
        b_scale_grid_desc_bn_ak,
        b_scale_thread_desc,
        b_scale_thread_copy,
        b_scale_thread_copy_up,
        // ...
        b_scale_grid_buf_up,
        b_scale_thread_slice_copy_step,
        num_k_block_main_loop);
}
else
{
    blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
        a_grid_desc_ak0_m_ak1,
        a_block_desc_ak0_m_ak1,
        // ...
        a_block_slice_copy_step,
        b_grid_desc_bpreshuffled,
        b_block_desc_bk0_n_bk1,
        // ...
        b_block_slice_copy_step,
        c_scale_thread_desc,
        // ...
        a_scale_grid_desc_am_ak,
        a_scale_thread_desc,
        a_scale_thread_copy,
        // ...
        a_scale_thread_slice_copy_step,
        b_scale_grid_desc_bn_ak,
        b_scale_thread_desc,
        b_scale_thread_copy,
        // ...
        b_scale_thread_slice_copy_step,
        num_k_block_main_loop);
}
// ...
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
              /* ... */);

constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
// ...
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
    blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// ...
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
    blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);

static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
static_assert(M0 * M1 * M2 == MPerBlock);
static_assert(N4 == 4 || N4 == 8);
// ...
if constexpr(MulRoutedWeight)
{
    const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
    topk_weight         = p_ds_grid[I0][m_pos];
}
// ...
    blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
        /* ... */);
// ...
if constexpr(IsInputGemm)
{
    // ... first ActivationOperation branch:
    {
        float gate = c_thread_buf[cidx];
        float up   = c_thread_buf_up[cidx];
        if constexpr(MulRoutedWeight)
        {
            gate = gate * topk_weight;
            up   = up * topk_weight;
        }
        // ... activation applied to gate ...
        c_thread_buf(cidx) = gate * up;
    }
    // ... second ActivationOperation branch:
    {
        float gate = c_thread_buf[cidx];
        float up   = c_thread_buf_up[cidx];
        if constexpr(MulRoutedWeight)
        {
            gate = gate * topk_weight;
            up   = up * topk_weight;
        }
        // ... activation applied to gate ...
        c_thread_buf(cidx) = gate * up;
    }
}
else
{
    if constexpr(MulRoutedWeight)
    {
        c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
    }
}
// ...
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = /* ... */;
// ...
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
    static_cast<CShuffleDataType*>(p_shared),
    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// ...
    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
// ...
const auto c_thread_mtx_on_block =
    blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = /* ... */;
// ...
const auto m_thread_data_on_block_idx =
    m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
        /* ... */);

const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = /* ... */;
// ...
const auto n_thread_data_on_block_idx =
    n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
        /* ... */);

auto c_thread_copy_vgpr_to_lds =
    // ...
    decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
    decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
    // ...
    Sequence<CShuffleMXdlPerWavePerShuffle,
             CShuffleNXdlPerWavePerShuffle,
             /* ... */>,
    // ...
        c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
        // ...
        m_thread_data_on_block_idx[I1],
        n_thread_data_on_block_idx[I1],
        m_thread_data_on_block_idx[I2],
        n_thread_data_on_block_idx[I2],
        n_thread_data_on_block_idx[I3],
        n_thread_data_on_block_idx[I4]),
        // ...
using EDataType = CDataType;
// ...
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = /* ... */;
// ...
    return make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
// ...
    tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
    // ...
    { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
    // ...
    tie(c_shuffle_block_buf),
    // ...
    { return ds_grid_buf[i]; },
    // ...
const auto idx_c_ds_block_begin = /* ... */;
// ...
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
    c_grid_desc_mblock_mperblock_nblock_nperblock;

using CDEBlockTransferCluster =
    CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr index_t scatter_weight_idx  = IsInputGemm ? 1 : 1; // both paths use index 1
// ...
    decltype(c_ds_desc_refs),
    decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
    CElementwiseOperation,
    // ...
    Sequence<1,
             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
             1,
             CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
    CDEBlockTransferCluster,
    // ...
    CDEShuffleBlockTransferScalarPerVectors,
    // ...
    idx_c_ds_block_begin,
    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
    // ...

auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
    p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// ...
constexpr auto sfc_c_vgpr =
    // ...
    Sequence<CShuffleMXdlPerWavePerShuffle,
             CShuffleNXdlPerWavePerShuffle,
             /* ... */>
    // ...
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

constexpr auto sfc_cde_block =
    // ...
    Sequence<1,
             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
             1,
             CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

constexpr auto EMThreads =
    CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
constexpr auto ENThreads =
    CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
// ...
auto dstidx = sfc_cde_block.GetIndex(access_id);
// ...
    block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
// ...
{
    const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
    index_t token_offset      = fused_token & 0xffffff;
    if constexpr(IsInputGemm)
    {
        token_offset = token_offset * problem.TopK + (fused_token >> 24);
    }
    scatter_offsets(m0) = token_offset * problem.N;
}
// ...
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                              // ...
                              c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                              c_shuffle_block_buf);
// ...
cde_block_copy_lds_and_global.Run(
    // ...
    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
    // ...

if constexpr(access_id < num_access - 1)
{
    constexpr auto cde_lds_and_global_step = sfc_cde_block.GetForwardStep(access_id);
    // ...
    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
        c_ds_desc_refs, i + I1, cde_lds_and_global_step);
    // ...
    cde_block_copy_lds_and_global.MoveDstSliceWindow(
        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
        // ...
        cde_lds_and_global_step);
}