template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          // ...
#if CK_USE_LAUNCH_BOUNDS
// ...
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
    {
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_sorted_token_ids,
            karg.p_sorted_expert_ids,
            // ...
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            // ...
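// Launch-shape note: blockIdx.z selects the split-K batch, and SplitKBatchOffset shifts the A/B
// base pointers so that each z-slice of the grid reduces its own K-range; the host-side device op
// is therefore expected to launch this kernel with gridDim.z == karg.KBatch.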
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          // ...
#if CK_USE_LAUNCH_BOUNDS
// ...
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
    {
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
        __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_sorted_token_ids,
            karg.p_sorted_expert_ids,
            // ...
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            // ...
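// The 2-LDS entry point allocates two identically sized LDS buffers (p_shared / p_shared1).
// Run_2Lds ping-pongs the A tile between them (see the a_block_bufs tuple further down), so the
// global->LDS copy of the next K tile can overlap with the MFMA work on the current one.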
template <typename ALayout,
          // ...
          typename AccDataType,
          typename CShuffleDataType,
          // ...
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          // ...
          index_t ActivationOperation = 0,
          bool NSwizzle               = false,
          bool IsInputGemm            = true,
          bool MulRoutedWeight        = true,
          // ...
          typename ComputeTypeA = CDataType,
          typename ComputeTypeB = ComputeTypeA,
          typename LDSTypeA     = ADataType,
          typename LDSTypeB     = BDataType>
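// MoE-specific knobs, as far as the code below shows: ActivationOperation selects the gate
// activation (the Activation enum lists gelu_and_mul and silu_and_mul), IsInputGemm marks the
// fused gate+up projection (expert weight/scale strides are doubled and the "up" half is
// addressed at expert_stride / 2), and MulRoutedWeight folds the per-token routing weight
// carried in D[0] into the output.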
    static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock =
        CDEShuffleBlockTransferScalarPerVectors{}[I0];
    // ...
    // MakeDsGridPointer()
        // ...
            return static_cast<const DDataType*>(nullptr);
        // CalculateGridSize
        const index_t gridx = NSwizzle ? nblock * mblock : nblock;
        const index_t gridy = NSwizzle ? 1 : mblock;
        // ...
        // CalculateAK0Padded
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
        // ...
        // CalculateBK0Padded
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
        // ...
        // CalculateKPadded
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;
        // ...
        // CalculateKRead
        auto K_t = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
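// Worked example of the padded split-K arithmetic above (illustrative sizes, not from the header):
// with K = 4096, KBatch = 2, KPerBlock = 256 and AK1 = 8, K_t = 512 and ceil(K / K_t) = 8 K-tiles
// per split, so AK0 = 8 * (256 / 8) = 256 and the KPerBlock-aligned K per split is 8 * 256 = 2048.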
    template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
        IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
                // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
                // ...
        }();

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            const auto a_grid_desc_m_k = // ...
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // ... transform a_grid_desc_mraw_kraw ...
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // ... transform a_grid_desc_mraw_kraw ...
            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // ... transform a_grid_desc_mraw_kraw ...
            return a_grid_desc_ak0_m_ak1;
        }
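// In every GemmSpecialization branch the raw (M, K) view of A is right-padded as requested (M
// and/or K) and K is then unmerged into (AK0, AK1) with AK1 innermost, producing the AK0 x M x AK1
// descriptor that the block-wise A copy and the LDS block descriptor below expect.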
        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);
        // ...
            make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
        const auto b_grid_desc_nraw_kraw = [&]() {
            // ...
        }();

        static_assert(!(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> &&
                        GemmSpec != GemmSpecialization::Default),
                      "pk_i4_t does not support padding");

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            const auto b_grid_desc_n_k = // ...
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // ... transform b_grid_desc_nraw_kraw ...
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // ... transform b_grid_desc_nraw_kraw ...
            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // ... transform b_grid_desc_nraw_kraw ...
            return b_grid_desc_bk0_n_bk1;
        }
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
    {
        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
    }
    template <typename ELayout>
    __host__ __device__ static auto MakeCGridDescriptor_M_N(
        IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            // ...

    template <typename DLayout>
    __host__ __device__ static auto MakeDGridDescriptor_M_N(
        // ...
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            // ...

            return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
    template <typename DsGridDesc>
    static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        // ...
                ds_grid_desc_m_n[i], MBlock, NBlock);
                std::array<index_t, NumDTensor> StrideDs_,
        // ...

        __host__ void Print() const
        {
            std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
                      << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
                      // ...
                      << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
                      << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
                      << "NBlock: " << NBlock << "}" << std::endl;
        }
                 const index_t* p_sorted_expert_ids_,
                 const index_t* p_max_token_id_,
                 const ADataType* p_a_grid_,
                 const BDataType* p_b_grid_,
                 std::array<const void*, NumDTensor> p_ds_grid_,
                 CDataType* p_c_grid_,
                 // ...
                 std::array<index_t, NumDTensor> StrideDs_,
                 // ...
                 AElementwiseOperation a_element_op_,
                 BElementwiseOperation b_element_op_,
                 CElementwiseOperation c_element_op_)
        // ...
                p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            // ...
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            // ...
            if(k_id < karg.KBatch - 1)
            // ...
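// SplitKBatchOffset: the layout-dependent branches above (elided here) presumably convert the
// per-split K offset into the element offsets a_k_split_offset / b_k_split_offset for A and B,
// while the last split (k_id == KBatch - 1) takes whatever K remainder is left after the evenly
// sized slices.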
        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);

        if constexpr(ABlockLdsExtraM)
        {
            // ...
        }
        // ...
        {
            constexpr auto a_lds_block_desc = // ...
            // ...
            return a_lds_block_desc_permuted;
        }
        // ...
        {
            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = WaveSize / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
            // ...
                Number<kfold * M0 / mpair>{},
            // ...
                a_lds_block_desc_permuted,
            // ...
                a_lds_block_desc_unmerged,
                // ...
                Number<KThreadWrite / kfold / KThreadReadPerm>{},
            // ...
            return a_lds_block_desc_ak0_m_ak1;
        }
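// The kfold / mpair / KThreadReadPerm constants feed the xor/unmerge transforms (elided) that
// build the permuted A LDS descriptor; the comparisons against 128 bytes suggest the folding is
// chosen so each LDS access spans a full 128-byte line and the MFMA-side reads avoid bank
// conflicts.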
    // GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            // ...
        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;

    // (BlockwiseGemmPipe pipeline selector arguments)
        ABlockTransferSrcScalarPerVector,
        BBlockTransferSrcScalarPerVector,
        // ...

    // GetSharedMemoryNumberOfByte()
        // ...
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            // ...
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
        // ...
                        c_block_size * sizeof(CShuffleDataType));

    static constexpr __host__ bool CheckValidity(const Argument& karg)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");
        if(!(karg.M % MPerBlock == 0))
        {
            std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                      << std::endl;
            // ...
        }

        if(!(karg.N % NPerBlock == 0))
        {
            std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                      << std::endl;
            // ...
        }

        auto K_t = karg.KBatch * KPerBlock;
        if(!(karg.K % K_t == 0))
        {
            std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                      << karg.K << " " << __FILE__ << ":" << __LINE__
                      << ", in function: " << __func__ << std::endl;
            // ...
        }
        // ...
        auto K_t = karg.KBatch * KReadVec;
        // ...
        if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
        // ...

        if(karg.K % ABlockTransferSrcScalarPerVector != 0)
        {
            std::cout << "Arg K (" << karg.K
                      << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                      << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                      << __LINE__ << ", in function: " << __func__ << std::endl;
            // ...
        }

        if(karg.M % ABlockTransferSrcScalarPerVector != 0)
        {
            std::cout << "Arg M (" << karg.M
                      << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                      << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                      << __LINE__ << ", in function: " << __func__ << std::endl;
            // ...
        }

        if(karg.N % BBlockTransferSrcScalarPerVector != 0)
        {
            std::cout << "Arg N (" << karg.N
                      << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                      << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                      << __LINE__ << ", in function: " << __func__ << std::endl;
            // ...
        }

        if(karg.K % BBlockTransferSrcScalarPerVector != 0)
        {
            std::cout << "Arg K (" << karg.K
                      << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                      << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                      << __LINE__ << ", in function: " << __func__ << std::endl;
            // ...
        }

        // ...
            std::cout << "Arg N (" << karg.N
                      << ") value is not a multiple of "
                         "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                      // ...
                      << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
        // ...
            std::cout << "Arg M (" << karg.M
                      << ") value is not a multiple of "
                         "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                      // ...
                      << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
        // ...

        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
        // ...
        if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
        // ...
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
    }

    __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
    }
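// Typical host-side use of the helpers above (a sketch of the calling convention, not code from
// this header; the per-split K value passed in is an assumption):
//
//   if(!GridwiseGemm::CheckValidity(karg)) { /* reject this tuning instance */ }
//   const index_t k_per_split   = karg.K / karg.KBatch;                        // assumed input
//   const bool    has_main_loop = GridwiseGemm::CalculateHasMainKBlockLoop(k_per_split);
//   const auto    tail_num      = GridwiseGemm::CalculateKBlockLoopTailNum(k_per_split);
//   // ...then dispatch kernel_moe_gemm<GridwiseGemm, has_main_loop, ...>(karg) accordingly.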
    template <typename CGridDesc>
    static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        // ...
        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }
    template <bool HasMainKBlockLoop,
              // ...
    static __device__ void Run(const index_t* p_sorted_token_ids,
                               const index_t* p_sorted_expert_ids,
                               const index_t* p_max_token_id,
                               const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               // ...
                               CDataType* p_c_grid,
                               // ...
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CElementwiseOperation c_element_op)
        const auto b_grid_desc_bpreshuffled = // ...
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
            // ...
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            // ...

        const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
            return;
        const index_t expert_id =
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
            {
                const index_t ecnt_prefix    = p_max_token_id[1 + expert_id];
                // ...
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                const index_t mid =
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
                // ...
            }
            else
            {
                return {blockIdx.x, blockIdx.y};
            }
        }();
        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
        const index_t token0 =
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
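// With NSwizzle enabled the expert's work-groups are renumbered: runs of 8 consecutive N-tiles
// are walked across the expert's M-tiles before advancing to the next 8, which appears intended
// to keep neighbouring work-groups on overlapping B tiles. token0 is the first token index
// covered by this M-tile, taken from the low 24 bits of the sorted token ids.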
        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;

        if(token_pos >= max_token_id || token0 >= problem.NumTokens)
            return;
        // ...
        {
            const index_t fused_token = p_sorted_token_ids[token_pos + m0];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
        }
        const index_t expert_stride =
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
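// The sorted token ids pack the destination row in the low 24 bits and the top-k slot in the high
// 8 bits, which is what the 0xffffff mask and the >> 24 shift above extract. A standalone
// illustration (hypothetical values, not from the kernel):
//
//   const uint32_t fused_token = (2u << 24) | 12345u;    // top-k slot 2, token 12345
//   const uint32_t token_index = fused_token & 0xffffff; // 12345
//   const uint32_t topk_slot   = fused_token >> 24;      // 2
//
// Rows of A are gathered from token_index * K; for the down-projection GEMM (!IsInputGemm) the
// row is first expanded to token_index * TopK + topk_slot.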
        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
            // ...

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            // ...
            b_grid_desc_bpreshuffled.GetElementSpaceSize());
        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + expert_id * expert_scale_stride,
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
            // (A block-wise copy, gathering rows of A via gather_offsets)
            AElementwiseOperation,
            // ...
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            // ...
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            // ...
            ABlockTransferSrcVectorDim,
            // ...
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            // ...
            AThreadTransferSrcResetCoordinateAfterRun,
            // ...
            BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                                // ...
                                                a_block_desc_ak0_m_ak1,
                                                // ...

        auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

            // (B thread-wise copy from the preshuffled B grid)
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            // ...
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(b_grid_desc_bpreshuffled,
                  // ...

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        // ...
        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            // ...

        constexpr index_t ScaleSliceSizeM = MXdlPerWave;
        // ...
        constexpr index_t MWaves   = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWaves   = NPerBlock / (NXdlPerWave * NPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
        // ...
        const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
        // ...
        if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
            return;
        // ...
        {
            const index_t fused_token =
                p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
            index_t token_offset = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            scale_gather_offsets(m0) =
                // ...
        }

        auto a_scale_thread_copy =
            // ...
                decltype(a_scale_grid_desc_am_ak),
                decltype(a_scale_thread_desc),
                // ...

        auto b_scale_thread_copy =
            // ...
                decltype(b_scale_grid_desc_bn_ak),
                decltype(b_scale_thread_desc),
                // ...
                b_scale_grid_desc_bn_ak,
                make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));

        constexpr auto a_scale_thread_slice_copy_step =
            // ...
        constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
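// Block-scale handling: the A scales are gathered with the same sorted-token indirection as A
// itself (one scale per ScaleBlockM x ScaleBlockK tile), the B scales are per expert and
// addressed at expert_id * expert_scale_stride, and both thread copies step through K by
// ScaleSliceSizeK per main-loop iteration via the slice-copy steps above.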
        if constexpr(IsInputGemm)
        {
            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
            const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                // ...
                b_grid_desc_bpreshuffled.GetElementSpaceSize());
            // (second B copy for the "up" half of the fused gate+up weights)
                decltype(b_grid_desc_bpreshuffled),
                decltype(b_block_desc_bk0_n_bk1),
                // ...
                BBlockTransferSrcScalarPerVector,
                BThreadTransferSrcResetCoordinateAfterRun,
                true>(b_grid_desc_bpreshuffled,
                      // ...
            const BScaleType* p_b_scale_grid_up =
                p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride,
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());
            auto b_scale_thread_copy_up =
                // ...
                    decltype(b_scale_grid_desc_bn_ak),
                    decltype(b_scale_thread_desc),
                    // ...
                    b_scale_grid_desc_bn_ak,
                    // ...

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                // ...
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                // ...
                b_blockwise_copy_up,
                // ...
                b_block_slice_copy_step,
                // ...
                c_scale_thread_desc,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_desc,
                a_scale_thread_copy,
                // ...
                a_scale_thread_slice_copy_step,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_desc,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                // ...
                b_scale_grid_buf_up,
                b_scale_thread_slice_copy_step,
                // ...
                num_k_block_main_loop);
        }
        else
        {
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                // ...
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                // ...
                b_block_slice_copy_step,
                // ...
                c_scale_thread_desc,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_desc,
                a_scale_thread_copy,
                // ...
                a_scale_thread_slice_copy_step,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_desc,
                b_scale_thread_copy,
                // ...
                b_scale_thread_slice_copy_step,
                // ...
                num_k_block_main_loop);
        }
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                      // ...

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        // ...
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
        // ...
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
        constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
        constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);

        static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
        static_assert(M0 * M1 * M2 == MPerBlock);
        static_assert(N4 == 4 || N4 == 8);
                    if constexpr(MulRoutedWeight)
                    {
                        const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
                        topk_weight = p_ds_grid[I0][m_pos];
                    }
                    // ...
                        blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                        // ...
                    if constexpr(IsInputGemm)
                    {
                        // ... (per-ActivationOperation branch)
                            float gate = c_thread_buf[cidx];
                            float up   = c_thread_buf_up[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                gate = gate * topk_weight;
                                up   = up * topk_weight;
                            }
                            // ... activation applied to gate ...
                            c_thread_buf(cidx) = gate * up;
                        // ... (same pattern in the other activation branch)
                            float gate = c_thread_buf[cidx];
                            float up   = c_thread_buf_up[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                gate = gate * topk_weight;
                                up   = up * topk_weight;
                            }
                            // ...
                            c_thread_buf(cidx) = gate * up;
                    }
                    else
                    {
                        if constexpr(MulRoutedWeight)
                        {
                            c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
                        }
                    }
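// For the fused gate+up GEMM (IsInputGemm) the two accumulator buffers hold the gate and up
// halves; the routing weight from D[0] is optionally folded in and the selected activation is
// applied to the gate before the elementwise product, i.e. with the usual definitions
//
//   silu_and_mul: out = (gate / (1.f + expf(-gate))) * up
//   gelu_and_mul: out = gelu(gate) * up   // assuming the standard GELU
//
// For the down-projection GEMM only the routing-weight multiply remains.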
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            // ...
        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        // ...
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            // ...

        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
            // ...
        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
                // ...
        const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
            // ...
        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
                // ...

        auto c_thread_copy_vgpr_to_lds =
            // ...
                decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
                decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
                // ...
                Sequence<CShuffleMXdlPerWavePerShuffle,
                         CShuffleNXdlPerWavePerShuffle,
                         // ...
                c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                // ...
                                 m_thread_data_on_block_idx[I1],
                                 n_thread_data_on_block_idx[I1],
                                 m_thread_data_on_block_idx[I2],
                                 n_thread_data_on_block_idx[I2],
                                 n_thread_data_on_block_idx[I3],
                                 n_thread_data_on_block_idx[I4]),
                // ...
        using EDataType = CDataType;
        // ...
        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
            // ...
                const DDataType* ptr_ = p_ds_grid[i];
                // ...
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
        // ...
            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            // ...
            { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
        // ...
            tie(c_shuffle_block_buf),
            // ...
            { return ds_grid_buf[i]; },
        // ...
        const auto idx_c_ds_block_begin =
            // ...
        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            c_grid_desc_mblock_mperblock_nblock_nperblock;
        // ...
        using CDEBlockTransferCluster =
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
        const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
        constexpr index_t scatter_weight_idx  = IsInputGemm ? 1 : 1;
            // (CDE block-wise copy, LDS + Ds -> global, with scatter over sorted tokens)
            decltype(c_ds_desc_refs),
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CElementwiseOperation,
            // ...
            Sequence<1,
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     1,
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
            CDEBlockTransferCluster,
            // ...
            CDEShuffleBlockTransferScalarPerVectors,
            // ...
                idx_c_ds_block_begin,
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                // ...

        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        constexpr auto sfc_c_vgpr =
            // ...
                Sequence<CShuffleMXdlPerWavePerShuffle,
                         CShuffleNXdlPerWavePerShuffle,
                         // ...

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        constexpr auto sfc_cde_block =
            // ...
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

        constexpr auto EMThreads =
            CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
        constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
        constexpr auto ENThreads =
            CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
        // ...
            auto dstidx = sfc_cde_block.GetIndex(access_id);
            const index_t c_token_pos =
                block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
            // ...
            {
                const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                index_t token_offset      = fused_token & 0xffffff;
                if constexpr(IsInputGemm)
                {
                    token_offset = token_offset * problem.TopK + (fused_token >> 24);
                }
                scatter_offsets(m0) = token_offset * problem.N;
            }
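// Epilogue scatter: each shuffled output row is written back to its original token row, with
// scatter_offsets holding a row offset of token_offset * N into C. For the fused gate+up GEMM the
// destination row is expanded by TopK and the top-k slot, mirroring the gather on the A side.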
            // VGPR -> LDS shuffle, then LDS + Ds -> global with scatter
            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          // ...
                                          c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                          c_shuffle_block_buf);
            // ...
            cde_block_copy_lds_and_global.Run(
                // ...
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                // ...

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_lds_and_global_step =
                    sfc_cde_block.GetForwardStep(access_id);
                // ...
                cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                    c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                // ...
                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    // ...
                    cde_lds_and_global_step);
            }
    template <bool HasMainKBlockLoop,
              // ...
    static __device__ void Run_2Lds(const index_t* p_sorted_token_ids,
                                    const index_t* p_sorted_expert_ids,
                                    const index_t* p_max_token_id,
                                    const ADataType* p_a_grid,
                                    const BDataType* p_b_grid,
                                    // ...
                                    CDataType* p_c_grid,
                                    // ...
                                    AElementwiseOperation a_element_op,
                                    BElementwiseOperation b_element_op,
                                    CElementwiseOperation c_element_op)
        const auto b_grid_desc_bpreshuffled = // ...
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
            // ...
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            // ...

        const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
            return;
        const index_t expert_id =
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
            {
                const index_t ecnt_prefix    = p_max_token_id[1 + expert_id];
                // ...
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                const index_t mid =
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
                // ...
            }
            else
            {
                return {blockIdx.x, blockIdx.y};
            }
        }();
        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
        const index_t token0 =
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;

        if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
           // ...
        {
            const index_t fused_token = p_sorted_token_ids[token_pos + m0];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
        }
        const index_t expert_stride =
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
            // ...

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            // ...
            b_grid_desc_bpreshuffled.GetElementSpaceSize());
        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + expert_id * expert_scale_stride,
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
            // (A block-wise copy, as in Run)
            AElementwiseOperation,
            // ...
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            // ...
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            // ...
            ABlockTransferSrcVectorDim,
            // ...
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            // ...
            AThreadTransferSrcResetCoordinateAfterRun,
            // ...
            BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                                // ...
                                                a_block_desc_ak0_m_ak1,
                                                // ...

        auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);

            // (B thread-wise copy from the preshuffled B grid)
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            // ...
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(b_grid_desc_bpreshuffled,
                  // ...

        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
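// Double buffering in the 2-LDS path: the A tile ping-pongs between the two LDS allocations
// (p_shared / p_shared1) and the preshuffled B tile between two VGPR buffers, so the pipeline can
// prefetch K-tile i+1 while the MFMAs consume K-tile i.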
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        // ...
        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            // ...

        constexpr index_t ScaleSliceSizeM = MXdlPerWave;
        // ...
        constexpr index_t MWaves   = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWaves   = NPerBlock / (NXdlPerWave * NPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
        // ...
        const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
        // ...
        if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
            return;
        // ...
        {
            const index_t fused_token =
                p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
            index_t token_offset = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
                // ...
        }

        auto a_scale_thread_copy =
            // ...
                decltype(a_scale_grid_desc_am_ak),
                decltype(a_scale_thread_desc),
                // ...

        auto b_scale_thread_copy =
            // ...
                decltype(b_scale_grid_desc_bn_ak),
                decltype(b_scale_thread_desc),
                // ...
                b_scale_grid_desc_bn_ak,
                make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));

        constexpr auto a_scale_thread_slice_copy_step =
            // ...
        constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
        if constexpr(IsInputGemm)
        {
            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
            const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                // ...
                b_grid_desc_bpreshuffled.GetElementSpaceSize());
            // (second B copy for the "up" half of the fused gate+up weights)
                decltype(b_grid_desc_bpreshuffled),
                decltype(b_block_desc_bk0_n_bk1),
                // ...
                BBlockTransferSrcScalarPerVector,
                BThreadTransferSrcResetCoordinateAfterRun,
                true>(b_grid_desc_bpreshuffled,
                      // ...
            const BScaleType* p_b_scale_grid_up =
                p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());
            auto b_scale_thread_copy_up =
                // ...
                    decltype(b_scale_grid_desc_bn_ak),
                    decltype(b_scale_thread_desc),
                    // ...
                    b_scale_grid_desc_bn_ak,
                    // ...

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                // ...
                b_blockwise_copy_up,
                // ...
                b_block_slice_copy_step,
                c_scale_thread_desc,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_desc,
                a_scale_thread_copy,
                // ...
                a_scale_thread_slice_copy_step,
                b_scale_grid_desc_bn_ak,
                b_scale_thread_desc,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                // ...
                b_scale_grid_buf_up,
                b_scale_thread_slice_copy_step,
                num_k_block_main_loop);
        }
        else
        {
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                // ...
                b_block_slice_copy_step,
                c_scale_thread_desc,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_desc,
                a_scale_thread_copy,
                // ...
                a_scale_thread_slice_copy_step,
                b_scale_grid_desc_bn_ak,
                b_scale_thread_desc,
                b_scale_thread_copy,
                // ...
                b_scale_thread_slice_copy_step,
                num_k_block_main_loop);
        }
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                      // ...

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        // ...
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
        // ...
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
        constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
        constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);

        static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
        static_assert(M0 * M1 * M2 == MPerBlock);
        static_assert(N4 == 4 || N4 == 8);
                    if constexpr(MulRoutedWeight)
                    {
                        const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
                        topk_weight = p_ds_grid[I0][m_pos];
                    }
                    // ...
                        blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                        // ...
                    if constexpr(IsInputGemm)
                    {
                        // ... (per-ActivationOperation branch)
                            float gate = c_thread_buf[cidx];
                            float up   = c_thread_buf_up[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                gate = gate * topk_weight;
                                up   = up * topk_weight;
                            }
                            // ... activation applied to gate ...
                            c_thread_buf(cidx) = gate * up;
                        // ... (same pattern in the other activation branch)
                            float gate = c_thread_buf[cidx];
                            float up   = c_thread_buf_up[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                gate = gate * topk_weight;
                                up   = up * topk_weight;
                            }
                            // ...
                            c_thread_buf(cidx) = gate * up;
                    }
                    else
                    {
                        if constexpr(MulRoutedWeight)
                        {
                            c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
                        }
                    }
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            // ...
        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        // ...
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            // ...

        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
            // ...
        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
                // ...
        const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
            // ...
        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
                // ...

        auto c_thread_copy_vgpr_to_lds =
            // ...
                decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
                decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
                // ...
                Sequence<CShuffleMXdlPerWavePerShuffle,
                         CShuffleNXdlPerWavePerShuffle,
                         // ...
                c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                // ...
                                 m_thread_data_on_block_idx[I1],
                                 n_thread_data_on_block_idx[I1],
                                 m_thread_data_on_block_idx[I2],
                                 n_thread_data_on_block_idx[I2],
                                 n_thread_data_on_block_idx[I3],
                                 n_thread_data_on_block_idx[I4]),
                // ...
        using EDataType = CDataType;
        // ...
        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
            // ...
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
        // ...
            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            // ...
            { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
        // ...
            tie(c_shuffle_block_buf),
            // ...
            { return ds_grid_buf[i]; },
        // ...
        const auto idx_c_ds_block_begin =
            // ...
        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            c_grid_desc_mblock_mperblock_nblock_nperblock;
        // ...
        using CDEBlockTransferCluster =
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
        const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
        constexpr index_t scatter_weight_idx  = IsInputGemm ? 1 : 1;
            // (CDE block-wise copy, LDS + Ds -> global, with scatter over sorted tokens)
            decltype(c_ds_desc_refs),
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CElementwiseOperation,
            // ...
            Sequence<1,
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     1,
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
            CDEBlockTransferCluster,
            // ...
            CDEShuffleBlockTransferScalarPerVectors,
            // ...
                idx_c_ds_block_begin,
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                // ...

        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        constexpr auto sfc_c_vgpr =
            // ...
                Sequence<CShuffleMXdlPerWavePerShuffle,
                         CShuffleNXdlPerWavePerShuffle,
                         // ...

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        constexpr auto sfc_cde_block =
            // ...
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

        constexpr auto EMThreads =
            CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
        constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
        constexpr auto ENThreads =
            CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
        // ...
            auto dstidx = sfc_cde_block.GetIndex(access_id);
            const index_t c_token_pos =
                block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
            // ...
            {
                const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                index_t token_offset      = fused_token & 0xffffff;
                if constexpr(IsInputGemm)
                {
                    token_offset = token_offset * problem.TopK + (fused_token >> 24);
                }
                scatter_offsets(m0) = token_offset * problem.N;
            }
            // VGPR -> LDS shuffle, then LDS + Ds -> global with scatter
            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          // ...
                                          c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                          c_shuffle_block_buf);
            // ...
            cde_block_copy_lds_and_global.Run(
                // ...
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                // ...

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_lds_and_global_step =
                    sfc_cde_block.GetForwardStep(access_id);
                // ...
                cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                    c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                // ...
                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    // ...
                    cde_lds_and_global_step);
            }