38 template <
typename GridwiseGemm,
39 bool HasMainKBlockLoop,
44 #if CK_USE_LAUNCH_BOUNDS
51 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
53 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
55 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
56 karg.p_sorted_token_ids,
57 karg.p_sorted_expert_ids,
59 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
60 karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
61 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
62 karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
76 template <
typename GridwiseGemm,
77 bool HasMainKBlockLoop,
82 #if CK_USE_LAUNCH_BOUNDS
// Two separate __shared__ byte buffers, each sized by
// GridwiseGemm::GetSharedMemoryNumberOfByte(), so this kernel reserves twice that amount.
// NOTE(review): presumably the pair backs the two-buffer Run_2Lds path called below — confirm.
89 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
90 __shared__
char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
94 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
95 karg.p_sorted_token_ids,
96 karg.p_sorted_expert_ids,
116 template <
typename ALayout,
121 typename AScaleDataType,
123 typename BScaleDataType,
124 typename AccDataType,
125 typename CShuffleDataType,
128 typename AElementwiseOperation,
129 typename BElementwiseOperation,
130 typename CElementwiseOperation,
143 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
144 typename ABlockTransferThreadClusterArrangeOrder,
145 typename ABlockTransferSrcAccessOrder,
146 index_t ABlockTransferSrcVectorDim,
147 index_t ABlockTransferSrcScalarPerVector,
148 index_t ABlockTransferDstScalarPerVector_AK1,
149 bool AThreadTransferSrcResetCoordinateAfterRun,
151 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
152 typename BBlockTransferThreadClusterArrangeOrder,
153 typename BBlockTransferSrcAccessOrder,
154 index_t BBlockTransferSrcVectorDim,
155 index_t BBlockTransferSrcScalarPerVector,
156 index_t BBlockTransferDstScalarPerVector_BK1,
157 bool BThreadTransferSrcResetCoordinateAfterRun,
159 index_t CShuffleMXdlPerWavePerShuffle,
160 index_t CShuffleNXdlPerWavePerShuffle,
161 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
162 typename CDEShuffleBlockTransferScalarPerVectors,
165 index_t ActivationOperation = 0,
166 bool NSwizzle =
false,
167 bool IsInputGemm =
true,
168 bool MulRoutedWeight =
true,
170 typename ComputeTypeA = ADataType,
171 typename ComputeTypeB = BDataType>
189 CDEShuffleBlockTransferScalarPerVectors{}[
I0];
225 return static_cast<const DDataType*
>(
nullptr);
// With NSwizzle the full tile count (nblock * mblock) is flattened into gridx and gridy
// collapses to 1; without it gridx carries nblock and gridy carries mblock.
238 const index_t gridx = NSwizzle ? nblock * mblock : nblock;
239 const index_t gridy = NSwizzle ? 1 : mblock;
// K_t = K extent covered by one KPerBlock tile from each of the K_Batch splits.
261 auto K_t = K_Batch * KPerBlock;
// ceil(K / K_t) tiles, each counted as KPerBlock / AK1Value units.
262 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
// Same rounding as above, expressed in units of KPerBlock / BK1Value.
267 auto K_t = K_Batch * KPerBlock;
268 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
// Same rounding, expressed directly in K elements (a multiple of KPerBlock).
273 auto K_t = K_Batch * KPerBlock;
274 return (K + K_t - 1) / K_t * KPerBlock;
// Here the rounding granularity is KReadVec instead of KPerBlock:
// ceil(K / (K_Batch * KReadVec)) * KReadVec.
280 auto K_t = K_Batch * KReadVec;
281 return (K + K_t - 1) / K_t * KReadVec;
294 template <
index_t MNXdlPerWave,
298 typename TileDesc_K0_MN_K1>
316 IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
318 const auto a_grid_desc_mraw_kraw = [&]() {
319 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
323 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
331 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
332 GemmSpec == GemmSpecialization::MNKPadding)
335 const auto a_grid_desc_m_k =
349 return a_grid_desc_ak0_m_ak1;
351 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
352 GemmSpec == GemmSpecialization::MNPadding)
356 a_grid_desc_mraw_kraw,
362 return a_grid_desc_ak0_m_ak1;
364 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
365 GemmSpec == GemmSpecialization::NKPadding)
369 a_grid_desc_mraw_kraw,
381 return a_grid_desc_ak0_m_ak1;
387 a_grid_desc_mraw_kraw,
393 return a_grid_desc_ak0_m_ak1;
400 const auto b_grid_desc_nraw_kraw = [&]() {
414 GemmSpec != GemmSpecialization::Default),
415 "pk_i4_t does not support padding");
417 GemmSpec != GemmSpecialization::Default),
418 "f4x2_pk_t does not support padding");
420 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
421 GemmSpec == GemmSpecialization::MNKPadding)
424 const auto b_grid_desc_n_k =
438 return b_grid_desc_bk0_n_bk1;
440 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
441 GemmSpec == GemmSpecialization::MNPadding)
445 b_grid_desc_nraw_kraw,
451 return b_grid_desc_bk0_n_bk1;
453 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
454 GemmSpec == GemmSpecialization::MKPadding)
458 b_grid_desc_nraw_kraw,
470 return b_grid_desc_bk0_n_bk1;
476 b_grid_desc_nraw_kraw,
482 return b_grid_desc_bk0_n_bk1;
486 template <
typename ABlockDesc_AK0_M_AK1>
487 __host__ __device__
static constexpr
auto
490 constexpr
index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
492 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl>(
493 ABlockDesc_AK0_M_AK1{});
496 template <
typename BBlockDesc_BK0_N_BK1>
497 __host__ __device__
static constexpr
auto
500 constexpr
index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
502 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl>(
503 BBlockDesc_BK0_N_BK1{});
506 template <
typename ELayout>
508 IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
510 const auto c_grid_desc_mraw_nraw = [&]() {
529 template <
typename DLayout>
530 __host__ __device__
static auto
533 const auto c_grid_desc_mraw_nraw = [&]() {
558 return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
563 template <
typename DsGr
idDesc>
565 const DsGridDesc& ds_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
570 ds_grid_desc_m_n[i], MBlock, NBlock);
586 std::array<index_t, NumDTensor> StrideDs_,
614 std::cout <<
"problem {" <<
"NumTokens:" <<
NumTokens <<
", " <<
"TopK:" <<
TopK <<
", "
615 <<
"M:" <<
M <<
", " <<
"N:" <<
N <<
", " <<
"K:" <<
K <<
", "
619 <<
", " <<
"KRead:" <<
KRead <<
", " <<
"KP:" <<
KPadded <<
", "
620 <<
"AK0:" <<
AK0 <<
", " <<
"BK0:" <<
BK0 <<
", " <<
"MBlock: " <<
MBlock
621 <<
", " <<
"NBlock: " <<
NBlock <<
"}" << std::endl;
650 const index_t* p_sorted_expert_ids_,
651 const index_t* p_max_token_id_,
652 const ADataType* p_a_grid_,
653 const AScaleDataType* p_a_scale_grid_,
654 const BDataType* p_b_grid_,
655 const BScaleDataType* p_b_scale_grid_,
656 std::array<const void*, NumDTensor> p_ds_grid_,
657 CDataType* p_c_grid_,
667 std::array<index_t, NumDTensor> StrideDs_,
670 AElementwiseOperation a_element_op_,
671 BElementwiseOperation b_element_op_,
672 CElementwiseOperation c_element_op_)
704 p_ds_grid(i) =
static_cast<const DDataType_*
>(p_ds_grid_[i]);
727 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
731 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
736 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
740 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
747 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
751 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
758 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
763 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
768 if(k_id < karg.
KBatch - 1)
// MWave / NWave: how many (XdlPerWave * PerXdl)-sized spans fit into the block tile along
// M and N respectively (integer division).
786 constexpr
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
787 constexpr
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// WaveSize is derived rather than hard-coded: BlockSize divided by the MWave * NWave product.
788 constexpr
index_t WaveSize = BlockSize / (MWave * NWave);
801 constexpr
auto a_lds_block_desc =
813 return a_lds_block_desc_permuted;
820 constexpr
auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
821 constexpr
auto M1 = MPerBlock / M0;
823 constexpr
auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
824 constexpr
auto K0PerThreadWrite =
AK0Number / KThreadWrite;
825 constexpr
auto KThreadRead = WaveSize / MPerXdl;
826 constexpr
auto K0PerThreadRead =
AK0Number / KThreadRead;
828 constexpr
auto kfold = (
AK1Number * M0 *
sizeof(ADataType) > 128)
830 : 128 / (
AK1Number * M0 *
sizeof(ADataType));
831 constexpr
auto KThreadReadPerm =
832 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
833 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
837 constexpr
auto mpair = (
AK1Number * MPerXdl *
sizeof(ADataType) > 128)
839 : ((128 / (
AK1Number * MPerXdl *
sizeof(ADataType))) > M0
841 : 128 / (
AK1Number * MPerXdl *
sizeof(ADataType)));
847 Number<kfold * M0 / mpair>{},
866 a_lds_block_desc_permuted,
888 a_lds_block_desc_unmerged,
891 Number<KThreadWrite / kfold / KThreadReadPerm>{},
900 return a_lds_block_desc_ak0_m_ak1;
906 constexpr
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
907 constexpr
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
908 constexpr
index_t WaveSize = BlockSize / (MWave * NWave);
920 constexpr
auto b_lds_block_desc =
932 return b_lds_block_desc_permuted;
936 constexpr
auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I1);
937 constexpr
auto N1 = NPerBlock / N0;
939 constexpr
auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I0);
940 constexpr
auto K0PerThreadWrite =
BK0Number / KThreadWrite;
941 constexpr
auto KThreadRead = WaveSize / NPerXdl;
942 constexpr
auto K0PerThreadRead =
BK0Number / KThreadRead;
944 constexpr
auto kfold = (
BK1Number * N0 *
sizeof(BDataType) > 128)
946 : 128 / (
BK1Number * N0 *
sizeof(BDataType));
947 constexpr
auto KThreadReadPerm =
948 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
949 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
953 constexpr
auto npair = (
BK1Number * NPerXdl *
sizeof(BDataType) > 128)
955 : ((128 / (
BK1Number * NPerXdl *
sizeof(BDataType))) > N0
957 : 128 / (
BK1Number * NPerXdl *
sizeof(BDataType)));
963 Number<kfold * N0 / npair>{},
982 b_lds_block_desc_permuted,
1004 b_lds_block_desc_unmerged,
1007 Number<KThreadWrite / kfold / KThreadReadPerm>{},
1016 return b_lds_block_desc_bk0_n_bk1;
1022 constexpr
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1023 constexpr
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1025 constexpr
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1032 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1053 ABlockTransferSrcScalarPerVector,
1054 BBlockTransferSrcScalarPerVector,
1075 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1078 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1081 constexpr
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1084 constexpr
auto c_block_size =
1085 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1087 if constexpr(IsInputGemm)
1089 return math::max((a_block_space_size_aligned *
sizeof(ADataType) +
1090 b_block_space_size_aligned *
sizeof(BDataType)) *
1092 c_block_size *
sizeof(CShuffleDataType));
1096 return math::max((a_block_space_size_aligned *
sizeof(ADataType) +
1097 b_block_space_size_aligned *
sizeof(BDataType)),
1098 c_block_size *
sizeof(CShuffleDataType));
// Compile-time tuning checks: the block tile must divide evenly by the per-wave XDL
// coverage along both M and N.
1105 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1106 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1107 "Invalid tuning param!");
// KPerBlock must be a multiple of ScaleBlockSize / BPackedSize.
// NOTE(review): the message text mentions only ScaleBlockSize and omits the BPackedSize
// divisor that the condition actually applies.
1109 static_assert(KPerBlock % (ScaleBlockSize /
BPackedSize) == 0,
1110 "KPerBlock should be multiple of ScaleBlockSize");
1118 if(!(karg.
M % MPerBlock == 0))
1122 std::cout <<
"Arg M value is not a multiple of MPerBlock! M: " << karg.
M <<
" "
1123 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1136 if(!(karg.
N % NPerBlock == 0))
1140 std::cout <<
"Arg N value is not a multiple of NPerBlock! N: " << karg.
N <<
" "
1141 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
// K must be an exact multiple of KBatch * KPerBlock; otherwise the diagnostic below is
// printed (the statements surrounding the print are outside this view).
1153 auto K_t = karg.
KBatch * KPerBlock;
1154 if(!(karg.
K % K_t == 0))
// NOTE(review): the message words the divisor as "K_Batch * K0PerBlock * K1", whereas the
// test above uses KBatch * KPerBlock — confirm the two are meant to be the same quantity.
1158 std::cout <<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1159 << karg.
K <<
" " << __FILE__ <<
":" << __LINE__
1160 <<
", in function: " << __func__ << std::endl;
1168 auto K_t = karg.
KBatch * KReadVec;
1170 if((KReadPadSplited * (karg.
KBatch - 1)) >= karg.
K)
1178 if(karg.
K % ABlockTransferSrcScalarPerVector != 0)
1182 std::cout <<
"Arg K (" << karg.
K
1183 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1184 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1185 << __LINE__ <<
", in function: " << __func__ << std::endl;
1192 if(karg.
M % ABlockTransferSrcScalarPerVector != 0)
1196 std::cout <<
"Arg M (" << karg.
M
1197 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1198 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1199 << __LINE__ <<
", in function: " << __func__ << std::endl;
1207 if(karg.
N % BBlockTransferSrcScalarPerVector != 0)
1211 std::cout <<
"Arg N (" << karg.
N
1212 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1213 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1214 << __LINE__ <<
", in function: " << __func__ << std::endl;
1221 if(karg.
K % BBlockTransferSrcScalarPerVector != 0)
1225 std::cout <<
"Arg K (" << karg.
K
1226 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1227 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1228 << __LINE__ <<
", in function: " << __func__ << std::endl;
1240 std::cout <<
"Arg N (" << karg.
N
1241 <<
") value is not a multiple of "
1242 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1244 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1256 std::cout <<
"Arg M (" << karg.
M
1257 <<
") value is not a multiple of "
1258 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1260 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
// Number of KPerBlock-sized steps implied by AK0: AK0 / (KPerBlock / AK1Value).
1270 const auto num_k_loop = karg.
AK0 / (KPerBlock / AK1Value);
// Special-cases problems whose step count does not exceed the pipeline's PrefetchStages
// (the branch body is outside this view).
1272 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
// Loop count is K / KPerBlock with integer (truncating) division; the pipeline type then
// decides whether a main "hot" loop exists for that count.
1283 const index_t num_loop = K / KPerBlock;
1285 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
// Same loop count, this time used to ask the pipeline which tail variant applies.
1290 const index_t num_loop = K / KPerBlock;
1292 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1295 template <
typename CGr
idDesc>
1297 const CGridDesc& c_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
1306 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1318 "A scale pack data type too large!");
1320 "B scale pack data type too large!");
1322 template <
bool HasMainKBlockLoop,
1326 const index_t* p_sorted_expert_ids,
1327 const index_t* p_max_token_id,
1328 const ADataType* p_a_grid,
1329 const AScaleDataType* p_a_scale_grid,
1330 const BDataType* p_b_grid,
1331 const BScaleDataType* p_b_scale_grid,
1333 CDataType* p_c_grid,
1336 AElementwiseOperation a_element_op,
1337 BElementwiseOperation b_element_op,
1338 CElementwiseOperation c_element_op)
1350 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1369 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1373 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1374 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.
NBlock : blockIdx.y;
1375 if(expert_block_id * MPerBlock >= max_token_id)
1378 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1380 const auto block_mn = [&]() -> std::pair<int, int> {
1381 if constexpr(NSwizzle)
1383 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1385 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1386 const index_t expert_swizzle =
1387 ecnt > 0 ? ecnt : 1;
1388 const index_t bid_new = blockIdx.x - prefix_block;
1389 const index_t nid = __builtin_amdgcn_readfirstlane(
1390 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1392 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1397 return {blockIdx.x, blockIdx.y};
1401 const index_t block_n_id = block_mn.first;
1402 const index_t block_m_id = block_mn.second;
1404 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1407 constexpr
auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
1408 constexpr
auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
1409 constexpr
auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I2);
1410 constexpr
auto AKThreads = AK0Threads * AK1Threads;
1411 constexpr
auto AMRepeats = MPerBlock / AMThreads;
1412 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1414 if(token_pos >= max_token_id || token0 >= problem.
NumTokens)
// Each sorted-token entry packs two fields: the low 24 bits are used as the token offset,
// and the bits from 24 upward are used as an additional per-token index.
1418 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1419 index_t token_offset = fused_token & 0xffffff;
// Only when this is NOT the input GEMM is the offset expanded to
// token_offset * TopK + (fused_token >> 24).
1420 if constexpr(!IsInputGemm)
1422 token_offset = token_offset * problem.
TopK + (fused_token >> 24);
// Offset for gather slot m0 = token offset * K; the cast to IndexType happens before the
// multiply (presumably to do the product in the wider type — confirm IndexType's width).
1424 gather_offsets(m0) =
static_cast<IndexType
>(token_offset) * problem.
K;
1428 __builtin_amdgcn_readfirstlane(problem.
N * problem.
K * (IsInputGemm ? 2 : 1));
1429 const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1430 problem.
N * (IsInputGemm ? 2 : 1) *
1434 const index_t n_block_data_idx_on_grid =
1435 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1438 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1439 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1440 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1441 p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1444 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1445 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1446 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1447 p_b_scale_grid + (expert_id * expert_scale_stride) /
sizeof(BScaleDataType),
1448 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1462 AElementwiseOperation,
1466 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1467 ABlockTransferThreadClusterArrangeOrder,
1470 decltype(a_grid_desc_ak0_m_ak1),
1471 decltype(a_block_desc_ak0_m_ak1),
1472 ABlockTransferSrcAccessOrder,
1474 ABlockTransferSrcVectorDim,
1476 ABlockTransferSrcScalarPerVector,
1477 ABlockTransferDstScalarPerVector_AK1,
1480 AThreadTransferSrcResetCoordinateAfterRun,
1484 BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1487 a_block_desc_ak0_m_ak1,
1493 auto b_blockwise_copy =
1495 BElementwiseOperation,
1499 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1500 BBlockTransferThreadClusterArrangeOrder,
1503 decltype(b_grid_desc_bk0_n_bk1),
1504 decltype(b_block_desc_bk0_n_bk1),
1505 BBlockTransferSrcAccessOrder,
1507 BBlockTransferSrcVectorDim,
1509 BBlockTransferSrcScalarPerVector,
1510 BBlockTransferDstScalarPerVector_BK1,
1513 BThreadTransferSrcResetCoordinateAfterRun,
1515 BlockwiseGemmPipe::GlobalBufferNum>(
1516 b_grid_desc_bk0_n_bk1,
1519 b_block_desc_bk0_n_bk1,
1525 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1528 auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1529 static_cast<ADataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1531 auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1532 reinterpret_cast<BDataType*
>(
static_cast<char*
>(p_shared) +
1533 a_block_space_size_aligned *
sizeof(ADataType)),
1534 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1540 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1542 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1543 decltype(c_thread_buf) c_thread_buf_up;
1547 c_thread_buf.num_of_v_,
1548 c_thread_buf.s_per_v,
1552 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1553 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
1557 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1558 const auto waveId_m = wave_idx[
I0];
1559 const auto waveId_n = wave_idx[
I1];
1561 auto thread_offset_shuffled =
1564 auto a_thread_offset_m = waveId_m;
1569 decltype(a_scale_grid_desc_am_ak),
1570 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1576 true>(a_scale_grid_desc_am_ak,
1582 auto b_thread_offset_n = waveId_n;
1587 decltype(b_scale_grid_desc_bn_ak),
1588 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1594 true>(b_scale_grid_desc_bn_ak,
1599 if constexpr(IsInputGemm)
1602 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1603 auto b_block_buf_up = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1604 reinterpret_cast<BDataType*
>(
static_cast<char*
>(p_shared) +
1605 a_block_space_size_aligned *
sizeof(ADataType) +
1606 b_block_space_size_aligned *
sizeof(BDataType)),
1607 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
// The "up" B operand starts half an expert stride past p_b_grid; the per-expert offset
// (expert_id * expert_stride) is then added on top, and the buffer reuses the element
// space size of the same B grid descriptor.
1609 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
1610 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1611 p_b_grid_up + expert_id * expert_stride,
1612 b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1614 auto b_blockwise_copy_up =
1616 BElementwiseOperation,
1620 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1621 BBlockTransferThreadClusterArrangeOrder,
1624 decltype(b_grid_desc_bk0_n_bk1),
1625 decltype(b_block_desc_bk0_n_bk1),
1626 BBlockTransferSrcAccessOrder,
1628 BBlockTransferSrcVectorDim,
1630 BBlockTransferSrcScalarPerVector,
1631 BBlockTransferDstScalarPerVector_BK1,
1634 BThreadTransferSrcResetCoordinateAfterRun,
1636 BlockwiseGemmPipe::GlobalBufferNum>(
1637 b_grid_desc_bk0_n_bk1,
1640 b_block_desc_bk0_n_bk1,
1644 const BScaleDataType* p_b_scale_grid_up =
1645 p_b_scale_grid + expert_scale_stride / 2 /
sizeof(BScaleDataType);
1646 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1647 p_b_scale_grid_up + expert_id * expert_scale_stride /
sizeof(BScaleDataType),
1648 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1653 decltype(b_scale_grid_desc_bn_ak),
1654 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1661 b_scale_grid_desc_bn_ak,
1666 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1668 a_grid_desc_ak0_m_ak1,
1669 a_block_desc_ak0_m_ak1,
1673 a_block_slice_copy_step,
1675 b_grid_desc_bk0_n_bk1,
1676 b_block_desc_bk0_n_bk1,
1678 b_blockwise_copy_up,
1683 b_block_slice_copy_step,
1688 a_scale_grid_desc_am_ak,
1689 a_scale_thread_copy,
1692 b_scale_grid_desc_bn_ak,
1693 b_scale_thread_copy,
1694 b_scale_thread_copy_up,
1696 b_scale_grid_buf_up,
1697 num_k_block_main_loop);
1701 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1702 a_grid_desc_ak0_m_ak1,
1703 a_block_desc_ak0_m_ak1,
1707 a_block_slice_copy_step,
1708 b_grid_desc_bk0_n_bk1,
1709 b_block_desc_bk0_n_bk1,
1713 b_block_slice_copy_step,
1715 a_scale_grid_desc_am_ak,
1716 a_scale_thread_copy,
1718 b_scale_grid_desc_bn_ak,
1719 b_scale_thread_copy,
1721 num_k_block_main_loop);
1726 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1727 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1729 static_assert(CShuffleMXdlPerWavePerShuffle %
MXdlPack == 0 &&
1730 CShuffleNXdlPerWavePerShuffle %
NXdlPack == 0,
1733 constexpr
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1734 constexpr
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1737 constexpr
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1738 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1742 constexpr
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1743 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1745 constexpr
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
1746 constexpr
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
1747 constexpr
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
1748 constexpr
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
1749 constexpr
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
1750 constexpr
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
1751 constexpr
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
1752 constexpr
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
1753 constexpr
auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I8);
1754 constexpr
auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I9);
1757 static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
1758 static_assert(M5 == 4);
1768 const index_t m_pos = block_m_id * MPerBlock +
1769 m0 * M2 * M1 * M3 * M4 * M5 +
1770 m1 * M2 * M3 * M4 * M5 +
1771 imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
1773 if constexpr(MulRoutedWeight)
1776 *c_style_pointer_cast<const vector_type<float, M5>*>(
1777 p_ds_grid[
I2] + m_pos);
1781 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1782 make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
1785 if constexpr(IsInputGemm)
1787 if constexpr(ActivationOperation ==
1790 float gate = c_thread_buf[cidx];
1791 float up = c_thread_buf_up[cidx];
1792 if constexpr(MulRoutedWeight)
1794 gate = gate * topk_weights.AsType<
float>()[m5];
1795 up = up * topk_weights.AsType<
float>()[m5];
1798 c_thread_buf_fp32(cidx) = gate * up;
1802 float gate = c_thread_buf[cidx];
1803 float up = c_thread_buf_up[cidx];
1804 if constexpr(MulRoutedWeight)
1806 gate = gate * topk_weights.AsType<
float>()[m5];
1807 up = up * topk_weights.AsType<
float>()[m5];
1810 c_thread_buf_fp32(cidx) = gate * up;
1825 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1826 if constexpr(MulRoutedWeight)
1828 c_thread_buf_fp32(cidx) =
1829 topk_weights.AsType<
float>()[m5] *
1830 c_thread_buf_fp32[cidx];
1840 constexpr
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1843 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1844 static_cast<CShuffleDataType*
>(p_shared),
1845 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1848 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1874 const auto c_thread_mtx_on_block =
1875 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
1877 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
1878 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
1880 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1886 const auto m_thread_data_on_block_idx =
1887 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1890 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1896 const auto n_thread_data_on_block_idx =
1897 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1904 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1905 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1908 CShuffleNXdlPerWavePerShuffle /
NXdlPack,
1917 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
1922 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1925 m_thread_data_on_block_idx[
I1],
1926 n_thread_data_on_block_idx[
I1],
1927 m_thread_data_on_block_idx[
I2],
1928 n_thread_data_on_block_idx[
I2],
1929 m_thread_data_on_block_idx[
I3],
1930 m_thread_data_on_block_idx[
I4],
1931 m_thread_data_on_block_idx[
I5],
1932 n_thread_data_on_block_idx[
I3]),
1935 using EDataType = CDataType;
1940 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1946 return make_dynamic_buffer<AddressSpaceEnum::Global>(
1947 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1953 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1955 {
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1960 tie(c_shuffle_block_buf),
1962 {
return ds_grid_buf[i]; },
1966 const auto idx_c_ds_block_begin =
1976 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1977 c_grid_desc_mblock_mperblock_nblock_nperblock;
1979 using CDEBlockTransferCluster =
1980 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1981 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1982 constexpr
index_t scatter_weight_idx = 3;
1987 decltype(c_ds_desc_refs),
1988 decltype(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1989 CElementwiseOperation,
1994 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1996 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
1997 CDEBlockTransferCluster,
2003 CDEShuffleBlockTransferScalarPerVectors,
2015 idx_c_ds_block_begin,
2016 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2020 auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2021 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2023 constexpr
auto sfc_c_vgpr =
2034 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2036 CShuffleNXdlPerWavePerShuffle /
NXdlPack,
2046 constexpr
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2049 constexpr
auto sfc_cde_block =
2053 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2055 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2057 static_assert(num_access == sfc_cde_block.GetNumOfAccess(),
"wrong!");
2058 constexpr
auto EMThreads =
2059 CDEBlockTransferCluster{}.At(
I0) * CDEBlockTransferCluster{}.At(
I1);
2060 constexpr
auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2061 constexpr
auto ENThreads =
2062 CDEBlockTransferCluster{}.At(
I2) * CDEBlockTransferCluster{}.At(
I3);
2067 auto dstidx = sfc_cde_block.GetIndex(access_id);
2069 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(
I1);
// Output-side counterpart of the gather computation: the low 24 bits of the sorted-token
// entry are the token offset.
2071 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2072 IndexType token_offset = fused_token & 0xffffff;
// Note the condition is IsInputGemm here (the gather side tests !IsInputGemm): the
// TopK expansion token_offset * TopK + (fused_token >> 24) is applied for the input GEMM.
2073 if constexpr(IsInputGemm)
2075 token_offset = token_offset * problem.
TopK + (fused_token >> 24);
// Offset for scatter slot m0 = token offset * N.
2077 scatter_offsets(m0) =
static_cast<IndexType
>(token_offset) * problem.
N;
2083 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2084 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2086 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2087 c_shuffle_block_buf);
2093 cde_block_copy_lds_and_global.Run(
2096 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2100 if constexpr(access_id < num_access - 1)
2102 constexpr
auto cde_lds_and_global_step =
2103 sfc_cde_block.GetForwardStep(access_id);
2107 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2108 c_ds_desc_refs, i +
I1, cde_lds_and_global_step);
2112 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2113 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2115 cde_lds_and_global_step);
2122 template <
bool HasMainKBlockLoop,
2125 __device__
static void Run_2Lds(
const index_t* p_sorted_token_ids,
2126 const index_t* p_sorted_expert_ids,
2127 const index_t* p_max_token_id,
2128 const ADataType* p_a_grid,
2129 const AScaleDataType* p_a_scale_grid,
2130 const BDataType* p_b_grid,
2131 const BScaleDataType* p_b_scale_grid,
2133 CDataType* p_c_grid,
2136 const Problem& problem,
2137 AElementwiseOperation a_element_op,
2138 BElementwiseOperation b_element_op,
2139 CElementwiseOperation c_element_op)
2143 IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2150 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2151 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2152 IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2170 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2172 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2173 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2175 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2176 if(expert_block_id * MPerBlock >= max_token_id)
2179 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2180 const auto block_mn = [&]() -> std::pair<int, int> {
2181 if constexpr(NSwizzle)
2183 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2184 const index_t prefix_block = ecnt_prefix * problem.NBlock;
2185 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2186 const index_t expert_swizzle =
2187 ecnt > 0 ? ecnt : 1;
2188 const index_t bid_new = blockIdx.x - prefix_block;
2189 const index_t nid = __builtin_amdgcn_readfirstlane(
2190 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2192 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2197 return {blockIdx.x, blockIdx.y};
2201 const index_t block_n_id = block_mn.first;
2202 const index_t block_m_id = block_mn.second;
2204 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2207 constexpr
auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
2208 constexpr
auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
2209 constexpr
auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I2);
2210 constexpr
auto AKThreads = AK0Threads * AK1Threads;
2211 constexpr
auto AMRepeats = MPerBlock / AMThreads;
2212 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
2214 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2216 StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
2217 static_for<0, AMRepeats, 1>{}([&](
auto m0) {
2218 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
2219 index_t token_offset = fused_token & 0xffffff;
2220 if constexpr(!IsInputGemm)
2222 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2224 gather_offsets(m0) =
static_cast<IndexType
>(token_offset) * problem.K;
2228 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2229 const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2233 const index_t n_block_data_idx_on_grid =
2234 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
2236 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2237 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2239 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2240 p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2242 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2243 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2244 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2245 p_b_scale_grid + (expert_id * expert_scale_stride) /
sizeof(BScaleDataType),
2246 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2255 auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
2257 AElementwiseOperation,
2260 Sequence<AK0Number, MPerBlock, AK1Number>,
2261 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2262 ABlockTransferThreadClusterArrangeOrder,
2265 decltype(a_grid_desc_ak0_m_ak1),
2266 decltype(a_block_desc_ak0_m_ak1),
2267 ABlockTransferSrcAccessOrder,
2269 ABlockTransferSrcVectorDim,
2271 ABlockTransferSrcScalarPerVector,
2272 ABlockTransferDstScalarPerVector_AK1,
2275 AThreadTransferSrcResetCoordinateAfterRun,
2279 BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
2282 a_block_desc_ak0_m_ak1,
2289 auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2290 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2291 auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2292 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2293 auto b_block_bufs =
make_tuple(b_block_buf_ping, b_block_buf_pong);
2295 auto b_blockwise_copy =
2296 ThreadwiseTensorSliceTransfer_v2<BDataType,
2298 decltype(b_grid_desc_bpreshuffled),
2299 decltype(b_block_desc_bk0_n_bk1),
2304 Number<BK1Value>{}>,
2305 Sequence<1, 2, 0, 3, 4>,
2307 BBlockTransferSrcScalarPerVector,
2308 BThreadTransferSrcResetCoordinateAfterRun,
2310 b_grid_desc_bpreshuffled,
2319 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2320 static_cast<ADataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2321 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2322 static_cast<ADataType*
>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2323 auto a_block_bufs =
make_tuple(a_block_buf_ping, a_block_buf_pong);
2326 constexpr
auto b_block_slice_copy_step =
make_multi_index(0, 0, 0, KRepeat, 0);
2329 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2331 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2332 decltype(c_thread_buf) c_thread_buf_up;
2336 c_thread_buf.num_of_v_,
2337 c_thread_buf.s_per_v,
2341 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2342 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
2346 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2347 const auto waveId_m = wave_idx[
I0];
2348 const auto waveId_n = wave_idx[
I1];
2350 auto thread_offset_shuffled =
2353 auto a_thread_offset_m = waveId_m;
2356 const index_t token_scale_pos = block_m_id * MPerBlock;
2357 if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2360 auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2363 decltype(a_scale_grid_desc_am_ak),
2364 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2370 true>(a_scale_grid_desc_am_ak,
2376 auto b_thread_offset_n = waveId_n;
2378 auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2381 decltype(b_scale_grid_desc_bn_ak),
2382 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2388 true>(b_scale_grid_desc_bn_ak,
2393 if constexpr(IsInputGemm)
2395 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 /
BPackedSize;
2396 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2397 p_b_grid_up + expert_id * expert_stride /
BPackedSize,
2398 b_grid_desc_bpreshuffled.GetElementSpaceSize());
2399 auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2402 decltype(b_grid_desc_bpreshuffled),
2403 decltype(b_block_desc_bk0_n_bk1),
2404 Sequence<Number<NXdlPerWave>{},
I1, Number<KRepeat>{}, Number<BK1Value>{}>,
2405 Sequence<1, 2, 0, 3>,
2407 BBlockTransferSrcScalarPerVector,
2408 BThreadTransferSrcResetCoordinateAfterRun,
2409 true>(b_grid_desc_bpreshuffled,
2414 const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
2415 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2416 p_b_scale_grid_up + expert_id * expert_scale_stride,
2417 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2418 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2421 decltype(b_scale_grid_desc_bn_ak),
2422 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2429 b_scale_grid_desc_bn_ak,
2434 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2435 a_grid_desc_ak0_m_ak1,
2436 a_block_desc_ak0_m_ak1,
2440 a_block_slice_copy_step,
2441 b_grid_desc_bpreshuffled,
2442 b_block_desc_bk0_n_bk1,
2444 b_blockwise_copy_up,
2448 b_block_slice_copy_step,
2451 a_scale_grid_desc_am_ak,
2452 a_scale_thread_copy,
2454 b_scale_grid_desc_bn_ak,
2455 b_scale_thread_copy,
2456 b_scale_thread_copy_up,
2458 b_scale_grid_buf_up,
2459 num_k_block_main_loop);
2463 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2464 a_grid_desc_ak0_m_ak1,
2465 a_block_desc_ak0_m_ak1,
2469 a_block_slice_copy_step,
2470 b_grid_desc_bpreshuffled,
2471 b_block_desc_bk0_n_bk1,
2475 b_block_slice_copy_step,
2477 a_scale_grid_desc_am_ak,
2478 a_scale_thread_copy,
2480 b_scale_grid_desc_bn_ak,
2481 b_scale_thread_copy,
2483 num_k_block_main_loop);
2488 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2489 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2493 constexpr
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2494 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2498 constexpr
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2499 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2501 constexpr
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
2502 constexpr
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
2503 constexpr
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
2504 constexpr
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
2505 constexpr
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
2506 constexpr
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
2507 constexpr
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
2508 constexpr
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
2512 static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
2513 static_assert(M4 == 4);
2517 vector_type<float, 4> topk_weights;
2518 static_for<0, NXdlPerWave, 1>{}([&](
auto n0) {
2519 static_for<0, MXdlPerWave, 1>{}([&](
auto m0) {
2520 static_for<0, M2, 1>{}([&](
auto m2) {
2521 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2522 m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2523 if constexpr(MulRoutedWeight)
2525 topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
2526 p_ds_grid[
I2] + m_pos);
2528 static_for<0, M4, 1>{}([&](
auto m4) {
2530 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2536 constexpr
auto cidx = Number<c_offset>{};
2538 if constexpr(IsInputGemm)
2542 float gate = c_thread_buf[cidx];
2543 float up = c_thread_buf_up[cidx];
2544 if constexpr(MulRoutedWeight)
2546 gate = gate * topk_weights.AsType<
float>()[m4];
2547 up = up * topk_weights.AsType<
float>()[m4];
2549 tensor_operation::element_wise::Silu{}(gate, gate);
2550 c_thread_buf_fp32(cidx) = gate * up;
2554 float gate = c_thread_buf[cidx];
2555 float up = c_thread_buf_up[cidx];
2556 if constexpr(MulRoutedWeight)
2558 gate = gate * topk_weights.AsType<
float>()[m4];
2559 up = up * topk_weights.AsType<
float>()[m4];
2561 tensor_operation::element_wise::Gelu{}(gate, gate);
2562 c_thread_buf_fp32(cidx) = gate * up;
2567 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2568 if constexpr(MulRoutedWeight)
2570 c_thread_buf_fp32(cidx) =
2571 topk_weights.AsType<
float>()[m4] * c_thread_buf_fp32[cidx];
2579 constexpr
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2582 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2583 static_cast<CShuffleDataType*
>(p_shared),
2584 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2587 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2590 Number<CShuffleMXdlPerWavePerShuffle>{},
2598 Number<CShuffleNXdlPerWavePerShuffle>{},
2602 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2604 Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
2608 const auto c_thread_mtx_on_block =
2609 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
2611 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
2612 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
2614 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2620 const auto m_thread_data_on_block_idx =
2621 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2624 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2630 const auto n_thread_data_on_block_idx =
2631 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2635 auto c_thread_copy_vgpr_to_lds =
2636 ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
2638 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2639 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2641 Sequence<CShuffleMXdlPerWavePerShuffle,
2642 CShuffleNXdlPerWavePerShuffle,
2649 Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2655 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2658 m_thread_data_on_block_idx[
I1],
2659 n_thread_data_on_block_idx[
I1],
2660 m_thread_data_on_block_idx[
I2],
2661 m_thread_data_on_block_idx[
I3],
2662 m_thread_data_on_block_idx[
I4],
2663 n_thread_data_on_block_idx[
I2]),
2666 using EDataType = CDataType;
2669 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2671 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2673 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2677 return make_dynamic_buffer<AddressSpaceEnum::Global>(
2678 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2680 Number<NumDTensor>{});
2684 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2686 {
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2687 Number<NumDTensor>{}));
2691 tie(c_shuffle_block_buf),
2693 {
return ds_grid_buf[i]; },
2694 Number<NumDTensor>{}));
2697 const auto idx_c_ds_block_begin =
2705 Number<NumDTensor>{}));
2707 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2708 c_grid_desc_mblock_mperblock_nblock_nperblock;
2710 using CDEBlockTransferCluster =
2711 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2712 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2713 constexpr
index_t scatter_weight_idx = 3;
2714 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2718 decltype(c_ds_desc_refs),
2719 decltype(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2720 CElementwiseOperation,
2721 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>,
2725 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2727 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
2728 CDEBlockTransferCluster,
2729 Sequence<0, 1, 2, 3>,
2730 Sequence<0, 1, 2, 3>,
2731 Sequence<0, 1, 2, 3>,
2734 CDEShuffleBlockTransferScalarPerVectors,
2746 idx_c_ds_block_begin,
2747 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2751 auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2752 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2753 constexpr
auto sfc_c_vgpr =
2754 SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
2755 Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2756 Sequence<CShuffleMXdlPerWavePerShuffle,
2757 CShuffleNXdlPerWavePerShuffle,
2765 constexpr
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2768 constexpr
auto sfc_cde_block =
2769 SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2770 Sequence<0, 2, 1, 3>,
2772 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2774 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2776 static_assert(num_access == sfc_cde_block.GetNumOfAccess(),
"wrong!");
2777 constexpr
auto EMThreads =
2778 CDEBlockTransferCluster{}.At(
I0) * CDEBlockTransferCluster{}.At(
I1);
2779 constexpr
auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2780 constexpr
auto ENThreads =
2781 CDEBlockTransferCluster{}.At(
I2) * CDEBlockTransferCluster{}.At(
I3);
2782 static_for<0, num_access, 1>{}([&](
auto access_id) {
2784 StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2786 auto dstidx = sfc_cde_block.GetIndex(access_id);
2788 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(
I1);
2789 static_for<0, EMRepeats, 1>{}([&](
auto m0) {
2790 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2791 IndexType token_offset = fused_token & 0xffffff;
2792 if constexpr(IsInputGemm)
2794 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2796 scatter_offsets(m0) =
static_cast<IndexType
>(token_offset) * problem.N;
2802 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2803 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2805 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2806 c_shuffle_block_buf);
2812 cde_block_copy_lds_and_global.Run(
2815 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2819 if constexpr(access_id < num_access - 1)
2821 constexpr
auto cde_lds_and_global_step =
2822 sfc_cde_block.GetForwardStep(access_id);
2825 static_for<0, NumDTensor, 1>{}([&](
auto i) {
2826 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2827 c_ds_desc_refs, i +
I1, cde_lds_and_global_step);
2831 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2832 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2834 cde_lds_and_global_step);
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
GemmSpecialization
Definition: gemm_specialization.hpp:11
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__device__ index_t get_warp_local_1d_id()
Definition: get_id.hpp:56
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:34
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm_bns.hpp:48
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:928
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:276
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm.hpp:87
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:132
Activation
Definition: gridwise_moe_gemm.hpp:31
@ silu_and_mul
Definition: gridwise_moe_gemm.hpp:33
@ gelu_and_mul
Definition: gridwise_moe_gemm.hpp:32
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
__host__ constexpr __device__ auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:320
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition: tuple_helper.hpp:42
constexpr bool is_same_v
Definition: type.hpp:283
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition: sequence.hpp:925
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:52
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
constexpr auto BlockGemmMXNBSPipeline_Selector()
Definition: blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp:37
integral_constant< index_t, N > Number
Definition: number.hpp:12
Definition: gridwise_moe_mx_gemm_bns.hpp:648
const ADataType * p_a_grid
Definition: gridwise_moe_mx_gemm_bns.hpp:711
const index_t * p_sorted_token_ids
Definition: gridwise_moe_mx_gemm_bns.hpp:708
const index_t * p_sorted_expert_ids
Definition: gridwise_moe_mx_gemm_bns.hpp:709
const index_t * p_max_token_id
Definition: gridwise_moe_mx_gemm_bns.hpp:710
DsGridPointer p_ds_grid
Definition: gridwise_moe_mx_gemm_bns.hpp:715
const CElementwiseOperation c_element_op
Definition: gridwise_moe_mx_gemm_bns.hpp:720
CDataType * p_c_grid
Definition: gridwise_moe_mx_gemm_bns.hpp:716
const AElementwiseOperation a_element_op
Definition: gridwise_moe_mx_gemm_bns.hpp:718
const BScaleDataType * p_b_scale_grid
Definition: gridwise_moe_mx_gemm_bns.hpp:714
const BDataType * p_b_grid
Definition: gridwise_moe_mx_gemm_bns.hpp:713
const AScaleDataType * p_a_scale_grid
Definition: gridwise_moe_mx_gemm_bns.hpp:712
__host__ Argument(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *p_a_grid_, const AScaleDataType *p_a_scale_grid_, const BDataType *p_b_grid_, const BScaleDataType *p_b_scale_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition: gridwise_moe_mx_gemm_bns.hpp:649
const BElementwiseOperation b_element_op
Definition: gridwise_moe_mx_gemm_bns.hpp:719
Definition: gridwise_moe_mx_gemm_bns.hpp:576
index_t M
Definition: gridwise_moe_mx_gemm_bns.hpp:626
index_t TopK
Definition: gridwise_moe_mx_gemm_bns.hpp:625
index_t NPadded
Definition: gridwise_moe_mx_gemm_bns.hpp:637
index_t MPadded
Definition: gridwise_moe_mx_gemm_bns.hpp:636
index_t StrideScaleB
Definition: gridwise_moe_mx_gemm_bns.hpp:632
index_t StrideScaleA
Definition: gridwise_moe_mx_gemm_bns.hpp:630
index_t MBlock
Definition: gridwise_moe_mx_gemm_bns.hpp:642
index_t StrideC
Definition: gridwise_moe_mx_gemm_bns.hpp:634
index_t AK0
Definition: gridwise_moe_mx_gemm_bns.hpp:640
index_t KPadded
Definition: gridwise_moe_mx_gemm_bns.hpp:639
index_t NBlock
Definition: gridwise_moe_mx_gemm_bns.hpp:643
__host__ Problem(index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition: gridwise_moe_mx_gemm_bns.hpp:577
index_t StrideA
Definition: gridwise_moe_mx_gemm_bns.hpp:629
index_t StrideB
Definition: gridwise_moe_mx_gemm_bns.hpp:631
index_t KBatch
Definition: gridwise_moe_mx_gemm_bns.hpp:635
index_t BK0
Definition: gridwise_moe_mx_gemm_bns.hpp:641
index_t KRead
Definition: gridwise_moe_mx_gemm_bns.hpp:638
__host__ void Print() const
Definition: gridwise_moe_mx_gemm_bns.hpp:612
index_t K
Definition: gridwise_moe_mx_gemm_bns.hpp:628
index_t N
Definition: gridwise_moe_mx_gemm_bns.hpp:627
index_t NumTokens
Definition: gridwise_moe_mx_gemm_bns.hpp:624
std::array< index_t, NumDTensor > StrideDs
Definition: gridwise_moe_mx_gemm_bns.hpp:633
Definition: gridwise_moe_mx_gemm_bns.hpp:724
index_t a_k_split_offset
Definition: gridwise_moe_mx_gemm_bns.hpp:778
index_t b_k_split_offset
Definition: gridwise_moe_mx_gemm_bns.hpp:779
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition: gridwise_moe_mx_gemm_bns.hpp:725
index_t b_scale_k_split_offset
Definition: gridwise_moe_mx_gemm_bns.hpp:781
index_t a_scale_k_split_offset
Definition: gridwise_moe_mx_gemm_bns.hpp:780
Definition: gridwise_moe_mx_gemm_bns.hpp:173
static constexpr auto is_scale_mfma
Definition: gridwise_moe_mx_gemm_bns.hpp:206
static constexpr auto I1
Definition: gridwise_moe_mx_gemm_bns.hpp:178
static constexpr index_t scale_pack_size_a
Definition: gridwise_moe_mx_gemm_bns.hpp:1315
static constexpr index_t NumDTensor
Definition: gridwise_moe_mx_gemm_bns.hpp:196
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_moe_mx_gemm_bns.hpp:244
static constexpr auto MXdlPack
Definition: gridwise_moe_mx_gemm_bns.hpp:198
static constexpr auto BK0Number
Definition: gridwise_moe_mx_gemm_bns.hpp:192
static constexpr auto BK1Number
Definition: gridwise_moe_mx_gemm_bns.hpp:194
ADataType LDSTypeA
Definition: gridwise_moe_mx_gemm_bns.hpp:174
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_moe_mx_gemm_bns.hpp:1281
static constexpr bool is_single_rate_mfma
Definition: gridwise_moe_mx_gemm_bns.hpp:205
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_moe_mx_gemm_bns.hpp:232
static constexpr auto I5
Definition: gridwise_moe_mx_gemm_bns.hpp:182
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_moe_mx_gemm_bns.hpp:397
BDataType LDSTypeB
Definition: gridwise_moe_mx_gemm_bns.hpp:175
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_moe_mx_gemm_bns.hpp:1288
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_moe_mx_gemm_bns.hpp:230
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
Definition: gridwise_moe_mx_gemm_bns.hpp:315
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_mx_gemm_bns.hpp:1103
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_moe_mx_gemm_bns.hpp:552
static __device__ void Run(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: gridwise_moe_mx_gemm_bns.hpp:1325
static constexpr auto I4
Definition: gridwise_moe_mx_gemm_bns.hpp:181
static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_moe_mx_gemm_bns.hpp:1065
static constexpr index_t APackedSize
Definition: gridwise_moe_mx_gemm_bns.hpp:202
static constexpr auto MakeDsGridPointer()
Definition: gridwise_moe_mx_gemm_bns.hpp:219
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_moe_mx_gemm_bns.hpp:564
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_moe_mx_gemm_bns.hpp:249
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm_bns.hpp:259
__host__ static constexpr __device__ auto MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1 &)
Definition: gridwise_moe_mx_gemm_bns.hpp:498
static constexpr auto I8
Definition: gridwise_moe_mx_gemm_bns.hpp:185
static constexpr auto I7
Definition: gridwise_moe_mx_gemm_bns.hpp:184
static constexpr auto KXdlPack
Definition: gridwise_moe_mx_gemm_bns.hpp:200
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_moe_mx_gemm_bns.hpp:1296
static constexpr auto I2
Definition: gridwise_moe_mx_gemm_bns.hpp:179
__host__ static constexpr __device__ auto MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1 &)
Definition: gridwise_moe_mx_gemm_bns.hpp:488
static constexpr index_t scale_pack_size_b
Definition: gridwise_moe_mx_gemm_bns.hpp:1316
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm_bns.hpp:277
__host__ static constexpr __device__ auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition: gridwise_moe_mx_gemm_bns.hpp:299
remove_cvref_t< decltype(BlockGemmMXNBSPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ScaleBlockSize, ADataType, AScaleDataType, BDataType, BScaleDataType, ComputeTypeA, AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, IsInputGemm >())> BlockwiseGemmPipe
Definition: gridwise_moe_mx_gemm_bns.hpp:1063
static constexpr auto I6
Definition: gridwise_moe_mx_gemm_bns.hpp:183
static constexpr index_t KPack
Definition: gridwise_moe_mx_gemm_bns.hpp:213
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm_bns.hpp:265
static constexpr auto I3
Definition: gridwise_moe_mx_gemm_bns.hpp:180
static constexpr auto AK1Number
Definition: gridwise_moe_mx_gemm_bns.hpp:193
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_moe_mx_gemm_bns.hpp:904
static constexpr index_t SortedTileSize
Definition: gridwise_moe_mx_gemm_bns.hpp:217
static constexpr auto I9
Definition: gridwise_moe_mx_gemm_bns.hpp:186
__host__ static __device__ auto MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_moe_mx_gemm_bns.hpp:531
static constexpr auto AK0Number
Definition: gridwise_moe_mx_gemm_bns.hpp:191
static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_moe_mx_gemm_bns.hpp:1020
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_moe_mx_gemm_bns.hpp:284
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock
Definition: gridwise_moe_mx_gemm_bns.hpp:188
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_moe_mx_gemm_bns.hpp:254
static constexpr auto NXdlPack
Definition: gridwise_moe_mx_gemm_bns.hpp:199
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_moe_mx_gemm_bns.hpp:289
static constexpr index_t BPackedSize
Definition: gridwise_moe_mx_gemm_bns.hpp:203
__host__ static __device__ auto MakeCGridDescriptor_M_N(IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
Definition: gridwise_moe_mx_gemm_bns.hpp:507
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm_bns.hpp:271
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_moe_mx_gemm_bns.hpp:784
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_moe_mx_gemm_bns.hpp:234
static constexpr auto I0
Definition: gridwise_moe_mx_gemm_bns.hpp:177
Definition: xdlops_gemm.hpp:1126
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1647
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Definition: static_buffer.hpp:75
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1_gather.hpp:48
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v7r3_scatter.hpp:51
Definition: threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition: threadwise_tensor_slice_transfer.hpp:234
Definition: tuple.hpp:117
Unsigned representation of a conventional biased Float32 exponent.
Definition: e8m0.hpp:25
Definition: data_type.hpp:41
Definition: integral_constant.hpp:20
Definition: data_type.hpp:186
Definition: functional2.hpp:33
Definition: device_base.hpp:51
Definition: unary_element_wise_operation.hpp:1007
Definition: unary_element_wise_operation.hpp:334
Definition: unary_element_wise_operation.hpp:1049
Definition: dtype_vector.hpp:10
#define CK_ENV(name)
Definition: env.hpp:129