template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          // ...
#if CK_USE_LAUNCH_BOUNDS
// ...
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
    {
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_sorted_token_ids,
            karg.p_sorted_expert_ids,
            // ...
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
            // ...
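        // Split-K: blockIdx.z selects the K batch, and SplitKBatchOffset shifts the
        // A/B (and scale) base pointers to that batch's K slice before calling Run.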
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          // ...
#if CK_USE_LAUNCH_BOUNDS
// ...
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
    {
        __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
        __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_sorted_token_ids,
            karg.p_sorted_expert_ids,
            // ...
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
            // ...
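        // Unlike kernel_moe_mxgemm above, this entry point carves two LDS tiles
        // (p_shared_0 / p_shared_1) so Run_2Lds can ping-pong between them: one buffer
        // is filled from global memory while the other is being consumed.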
template <typename ALayout,
          // ...
          typename AScaleDataType,
          // ...
          typename BScaleDataType,
          typename AccDataType,
          typename CShuffleDataType,
          // ...
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          // ...
          index_t ActivationOperation = 0,
          bool NSwizzle               = false,
          bool IsInputGemm            = true,
          bool MulRoutedWeight        = true,
          // ...
          typename ComputeTypeA = ADataType,
          typename ComputeTypeB = BDataType>
    // ...
        CDEShuffleBlockTransferScalarPerVectors{}[I0];
    // ...
            return static_cast<const DDataType*>(nullptr);
    // ...
        const index_t gridx = NSwizzle ? nblock * mblock : nblock;
        const index_t gridy = NSwizzle ? 1 : mblock;
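        // With NSwizzle enabled the whole M x N block grid is flattened onto gridDim.x
        // (gridx = nblock * mblock, gridy = 1); otherwise x walks N blocks and y walks
        // M blocks.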
    // ...
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
    // ...
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
    // ...
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;
    // ...
        auto K_t = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
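    // Illustrative numbers (not from this file): with K = 1000, K_Batch = 2,
    // KPerBlock = 128 and AK1Value = 16, K_t = 256, so
    // AK0 = ceil(1000 / 256) * (128 / 16) = 4 * 8 = 32 and the per-split padded K is
    // ceil(1000 / 256) * 128 = 512.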
    template <index_t MNXdlPerWave,
              // ...
              typename TileDesc_K0_MN_K1>
    // ...

    __host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(
        IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            // ...
        }();

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // ...
            const auto a_grid_desc_m_k = /* ... */;
            // ...
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // ... (transforms built from a_grid_desc_mraw_kraw, then
            //      a_grid_desc_ak0_m_ak1, then a_grid_desc_permuted)
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // ... (transforms built from a_grid_desc_mraw_kraw)
            return a_grid_desc_ak0_m_ak1;
        }
        // ... (default path: transforms built from a_grid_desc_mraw_kraw, then
        //      a_grid_desc_ak0_m_ak1, then a_grid_desc_permuted)
        const auto b_grid_desc_nraw_kraw = [&]() {
        // ...
        static_assert(/* ... */ GemmSpec != GemmSpecialization::Default),
                      "pk_i4_t does not support padding");
        static_assert(/* ... */ (GemmSpec != GemmSpecialization::Default &&
                                 GemmSpec != GemmSpecialization::MPadding)),
                      "f4x2_pk_t does not support K padding");

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // ...
            const auto b_grid_desc_n_k = /* ... */;
            // ...
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // ... (transforms built from b_grid_desc_nraw_kraw)
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // ... (transforms built from b_grid_desc_nraw_kraw)
            return b_grid_desc_bk0_n_bk1;
        }
        // ... (default path: transforms built from b_grid_desc_nraw_kraw, then
        //      b_grid_desc_bk0_n_bk1, then b_grid_desc_permuted)
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeAMmaTileDescriptor_M0_M1_M2_M3_K(/* ... */)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl>(
            ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    MakeBMmaTileDescriptor_N0_N1_N2_N3_K(/* ... */)
    {
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl>(
            BBlockDesc_BK0_N_BK1{});
    }
    template <typename ELayout>
    // ...
        IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
    // ...
        const auto c_grid_desc_mraw_nraw = [&]() {
        // ...

    template <typename DLayout>
    __host__ __device__ static auto
    // ...
        const auto c_grid_desc_mraw_nraw = [&]() {
        // ...
            return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
        // ...

    template <typename DsGridDesc>
    // ...
        const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
    // ...
            ds_grid_desc_m_n[i], MBlock, NBlock);
        // ...
        std::array<index_t, NumDTensor> StrideDs_,
        // ...

        __host__ void Print() const
        {
            std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
                      << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
                      // ...
                      << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
                      << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
                      << ", " << "NBlock: " << NBlock << "}" << std::endl;
        }
        // ...
        const index_t* p_sorted_expert_ids_,
        const index_t* p_max_token_id_,
        const ADataType* p_a_grid_,
        const AScaleDataType* p_a_scale_grid_,
        const BDataType* p_b_grid_,
        const BScaleDataType* p_b_scale_grid_,
        std::array<const void*, NumDTensor> p_ds_grid_,
        CDataType* p_c_grid_,
        // ...
        std::array<index_t, NumDTensor> StrideDs_,
        // ...
        AElementwiseOperation a_element_op_,
        BElementwiseOperation b_element_op_,
        CElementwiseOperation c_element_op_)
    // ...
            p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            // ...
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            // ...
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            // ...
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            // ...
            if(k_id < karg.KBatch - 1)
            // ...
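            // Only the branch headers survive in this listing; presumably the
            // layout-specific branches compute a_k_split_offset / b_k_split_offset (and
            // the scale offsets) for this k_id, with the first KBatch-1 splits taking a
            // full KRead chunk and the last split the K remainder.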
        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave    = NPerBlock / (NXdlPerWave * NPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);
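        // The workgroup is viewed as an MWave x NWave grid of waves; WaveSize is the
        // lane count per wave implied by BlockSize.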
        // ...
        constexpr auto a_lds_block_desc = /* ... */;
        // ...
            return a_lds_block_desc_permuted;
        // ...
            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = WaveSize / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(ADataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
            // ...
                Number<kfold * M0 / mpair>{},
            // ...
                a_lds_block_desc_permuted,
            // ...
                a_lds_block_desc_unmerged,
            // ...
                Number<KThreadWrite / kfold / KThreadReadPerm>{},
            // ...
            return a_lds_block_desc_ak0_m_ak1;
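        // The kfold / mpair factors bound how many rows share a 128-byte chunk, and the
        // "permuted"/"unmerged" descriptors then reorder the LDS layout; as in other CK
        // XDL kernels this appears to be the swizzle that keeps the write and read
        // access patterns free of LDS bank conflicts.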
        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave    = NPerBlock / (NXdlPerWave * NPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);
        // ...
        constexpr auto b_lds_block_desc = /* ... */;
        // ...
            return b_lds_block_desc_permuted;
        // ...
            constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
            constexpr auto N1 = NPerBlock / N0;

            constexpr auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
            constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
            constexpr auto KThreadRead      = WaveSize / NPerXdl;
            constexpr auto K0PerThreadRead  = BK0Number / KThreadRead;

            constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
                                       ? 1
                                       : 128 / (BK1Number * N0 * sizeof(BDataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
                                       ? 1
                                       : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
                                              ? N0
                                              : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
            // ...
                Number<kfold * N0 / npair>{},
            // ...
                b_lds_block_desc_permuted,
            // ...
                b_lds_block_desc_unmerged,
            // ...
                Number<KThreadWrite / kfold / KThreadReadPerm>{},
            // ...
            return b_lds_block_desc_bk0_n_bk1;
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
        // ...
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        // ...
        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    // ...
        ABlockTransferSrcScalarPerVector,
        BBlockTransferSrcScalarPerVector,
    // ...
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
    // ...
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
    // ...
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        // ...
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
        if constexpr(IsInputGemm)
        {
            return math::max(a_block_space_size_aligned * sizeof(ADataType) +
                                 b_block_space_size_aligned * sizeof(BDataType) * 2,
                             c_block_size * sizeof(CShuffleDataType));
        }
        else
        {
            return math::max((a_block_space_size_aligned * sizeof(ADataType) +
                              b_block_space_size_aligned * sizeof(BDataType)),
                             c_block_size * sizeof(CShuffleDataType));
        }
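        // In the gated (IsInputGemm) path the B tile is kept twice in LDS (gate and up
        // projections), so the requirement is max(A tile + 2 * B tile, C-shuffle tile);
        // otherwise a single B tile suffices.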
    // ...
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
                      "KPerBlock should be multiple of ScaleBlockSize");
        // ...
            if(!(karg.M % MPerBlock == 0))
            // ...
                std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          // ...
        // ...
            if(!(karg.N % NPerBlock == 0))
            // ...
                std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          // ...
        // ...
            auto K_t = karg.KBatch * KPerBlock;
            if(!(karg.K % K_t == 0))
            // ...
                std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                          << karg.K << " " << __FILE__ << ":" << __LINE__
                          << ", in function: " << __func__ << std::endl;
        // ...
        auto K_t = karg.KBatch * KReadVec;
        // ...
        if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
        // ...
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
            // ...
                std::cout << "Arg K (" << karg.K
                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
        // ...
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
            // ...
                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
        // ...
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
            // ...
                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
        // ...
            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
            // ...
                std::cout << "Arg K (" << karg.K
                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
        // ...
                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          // ...
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          // ...
        // ...
                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          // ...
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          // ...
        // ...
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
        // ...
        if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
        // ...

        const index_t num_loop = K / KPerBlock;
        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
    // ...
        const index_t num_loop = K / KPerBlock;
        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
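    // Host-side usage sketch (assumed, not shown in this listing): an invoker would
    // typically gate the launch on these helpers before dispatching the kernel, e.g.
    //   if(!GridwiseGemm::CheckValidity(karg)) { /* reject this tuning instance */ }
    //   const bool has_main_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_per_split);
    //   const auto  tail         = GridwiseGemm::CalculateKBlockLoopTailNum(K_per_split);
    // and then select the Run / Run_2Lds template instantiation on has_main_loop / tail.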
    template <typename CGridDesc>
    // ...
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
    // ...
        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    // ...
                  "A scale pack data type too large!");
    // ...
                  "B scale pack data type too large!");

    static_assert(is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
                      is_same_v<BElementwiseOperation, tensor_operation::element_wise::PassThrough>,
                  "A/B ElementwiseOperation should be PassThrough as load_to_lds is used!");
    template <bool HasMainKBlockLoop,
              // ...
    __device__ static void Run(const index_t* p_sorted_token_ids,
                               const index_t* p_sorted_expert_ids,
                               const index_t* p_max_token_id,
                               const ADataType* p_a_grid,
                               const AScaleDataType* p_a_scale_grid,
                               const BDataType* p_b_grid,
                               const BScaleDataType* p_b_scale_grid,
                               // ...
                               CDataType* p_c_grid,
                               // ...
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CElementwiseOperation c_element_op)
    {
        // ...
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
        // ...
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        // ...

        const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
        // ...
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
        // ...
        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
            {
                const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
                // ...
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                // ...
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
                // ...
            }
            else
            {
                return {blockIdx.x, blockIdx.y};
            }
        }();
        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
        // ...
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);

        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
        // ...
        if(token_pos >= max_token_id || token0 >= problem.NumTokens)
        // ...
        StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
        static_for<0, AMRepeats, 1>{}([&](auto m0) {
            const index_t fused_token = p_sorted_token_ids[token_pos + m0];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            gather_offsets(m0) = static_cast<IndexType>(token_offset);
        // ...
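            // Each entry of p_sorted_token_ids packs the token id in its low 24 bits and
            // the top-k slot in the high bits; for the non-gate GEMM (!IsInputGemm) the
            // gathered row index becomes token_id * TopK + topk_slot.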
        // ...
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
            problem.N * (IsInputGemm ? 2 : 1) * /* ... */);
        // ...
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
        // ...
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        // ...
        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
            // ...
            Sequence<AK0Number, MPerBlock, AK1Number>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            // ...
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorDim,
            // ...
            ABlockTransferSrcScalarPerVector,
            // ...
            1>(a_grid_desc_ak0_m_ak1,
               // ...
               a_block_desc_ak0_m_ak1,
               // ...

        auto b_blockwise_copy =
            // ...
                Sequence<BK0Number, NPerBlock, BK1Number>,
                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                BBlockTransferThreadClusterArrangeOrder,
                // ...
                decltype(b_grid_desc_bk0_n_bk1),
                decltype(b_block_desc_bk0_n_bk1),
                BBlockTransferSrcAccessOrder,
                BBlockTransferSrcVectorDim,
                // ...
                BBlockTransferSrcScalarPerVector>(
                b_grid_desc_bk0_n_bk1,
                // ...
                b_block_desc_bk0_n_bk1,
                // ...

        // ...
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        // ...
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
                                         a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        // ...
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        // ...
        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;
        // ...
            c_thread_buf.num_of_v_,
            c_thread_buf.s_per_v,
        // ...
        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            /* ... */);
        // ...
        const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];
        // ...
        auto thread_offset_shuffled =
        // ...
        auto a_thread_offset_m = waveId_m;

        auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
            // ...
            decltype(a_scale_grid_desc_am_ak),
            decltype(BlockwiseGemmPipe::a_scale_thread_desc),
            // ...
            true>(a_scale_grid_desc_am_ak,
                  // ...

        auto b_thread_offset_n = waveId_n;

        auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
            // ...
            decltype(b_scale_grid_desc_bn_ak),
            decltype(BlockwiseGemmPipe::b_scale_thread_desc),
            // ...
            true>(b_scale_grid_desc_bn_ak,
                  // ...

        if constexpr(IsInputGemm)
        {
            // ...
                b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
            auto b_block_buf_up = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
                                             a_block_space_size_aligned * sizeof(ADataType) +
                                             b_block_space_size_aligned * sizeof(BDataType)),
                b_block_desc_bk0_n_bk1.GetElementSpaceSize());

            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid_up + expert_id * expert_stride,
                b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

            auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
                // ...
                Sequence<BK0Number, NPerBlock, BK1Number>,
                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                BBlockTransferThreadClusterArrangeOrder,
                // ...
                decltype(b_grid_desc_bk0_n_bk1),
                decltype(b_block_desc_bk0_n_bk1),
                BBlockTransferSrcAccessOrder,
                BBlockTransferSrcVectorDim,
                // ...
                BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
                                                  // ...
                                                  b_block_desc_bk0_n_bk1,
                                                  // ...

            const BScaleDataType* p_b_scale_grid_up =
                p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());

            auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
                // ...
                decltype(b_scale_grid_desc_bn_ak),
                decltype(BlockwiseGemmPipe::b_scale_thread_desc),
                // ...
                b_scale_grid_desc_bn_ak,
                // ...

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                // ...
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                // ...
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,
                // ...
                b_blockwise_copy_up,
                // ...
                b_block_slice_copy_step,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                // ...
                b_scale_grid_buf_up,
                num_k_block_main_loop);
        }
        else
        {
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,
                // ...
                b_block_slice_copy_step,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                // ...
                num_k_block_main_loop);
        }
        // ...
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                      // ...
        static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
                          CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
                      // ...

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
        // ...
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
        // ...
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
        constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
        constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
        // ...
        static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
        static_assert(M5 == 4);
        // ...
        vector_type<float, 4> topk_weights;
        static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
            static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
                static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) {
                    static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
                        static_for<0, M3, 1>{}([&](auto m3) {
                            const index_t m_pos = block_m_id * MPerBlock +
                                                  m0 * M2 * M1 * M3 * M4 * M5 +
                                                  m1 * M2 * M3 * M4 * M5 +
                                                  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
                            if constexpr(MulRoutedWeight)
                            {
                                // ...
                                    *c_style_pointer_cast<const vector_type<float, M5>*>(
                                        p_ds_grid[I2] + m_pos);
                            }
                            // ...
                            static_for<0, M5, 1>{}([&](auto m5) {
                                // ...
                                    blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                                        make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
                                constexpr auto cidx = Number<c_offset>{};

                                if constexpr(IsInputGemm)
                                {
                                    if constexpr(ActivationOperation == /* ... */)
                                    {
                                        // ...
                                        float gate = c_thread_buf[cidx];
                                        float up   = c_thread_buf_up[cidx];
                                        if constexpr(MulRoutedWeight)
                                        {
                                            gate = gate * topk_weights.AsType<float>()[m5];
                                            up   = up * topk_weights.AsType<float>()[m5];
                                        }
                                        tensor_operation::element_wise::Silu{}(gate, gate);
                                        c_thread_buf_fp32(cidx) = gate * up;
                                    }
                                    // ...
                                        float gate = c_thread_buf[cidx];
                                        float up   = c_thread_buf_up[cidx];
                                        if constexpr(MulRoutedWeight)
                                        {
                                            gate = gate * topk_weights.AsType<float>()[m5];
                                            up   = up * topk_weights.AsType<float>()[m5];
                                        }
                                        tensor_operation::element_wise::Gelu{}(gate, gate);
                                        c_thread_buf_fp32(cidx) = gate * up;
                                    // ...
                                }
                                // ...
                                    c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
                                    if constexpr(MulRoutedWeight)
                                    {
                                        c_thread_buf_fp32(cidx) =
                                            topk_weights.AsType<float>()[m5] *
                                            c_thread_buf_fp32[cidx];
                                    }
                                // ...
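                                // Gated path: the main accumulator holds the gate
                                // projection and c_thread_buf_up the up projection; the
                                // gate is optionally scaled by the routed top-k weight,
                                // passed through SiLU or GeLU, and multiplied with up.
                                // The non-gated path just copies (and optionally
                                // weights) the accumulator into c_thread_buf_fp32.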
        // ...
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        // ...
        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        // ...
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            // ...
                Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{},
                // ...
                Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{},
                // ...
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            // ...
                       Sequence<0, 2, 4, 6, 7, 8>{},
                       // ...
                       Sequence<1, 3, 5, 9>{}));
        // ...
        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
        // ...
        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
        // ...
        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
        // ...
        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
        // ...
        auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
            // ...
            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            // ...
            Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
                     CShuffleNXdlPerWavePerShuffle / NXdlPack,
                     // ...
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
            // ...
            true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                  // ...
                             m_thread_data_on_block_idx[I1],
                             n_thread_data_on_block_idx[I1],
                             m_thread_data_on_block_idx[I2],
                             n_thread_data_on_block_idx[I2],
                             m_thread_data_on_block_idx[I3],
                             m_thread_data_on_block_idx[I4],
                             m_thread_data_on_block_idx[I5],
                             n_thread_data_on_block_idx[I3]),
                  // ...

        using EDataType = CDataType;
        // ...
        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
        // ...
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
            // ...
            Number<NumDTensor>{});
        // ...
            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            // ...
                { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                Number<NumDTensor>{}));
        // ...
            tie(c_shuffle_block_buf),
            // ...
                { return ds_grid_buf[i]; },
                Number<NumDTensor>{}));
        // ...
        const auto idx_c_ds_block_begin =
        // ...
                Number<NumDTensor>{}));
        // ...
        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            c_grid_desc_mblock_mperblock_nblock_nperblock;
        // ...
        using CDEBlockTransferCluster =
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
        const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
        constexpr index_t scatter_weight_idx  = 3;

        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
            // ...
            decltype(c_ds_desc_refs),
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CElementwiseOperation,
            Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>,
            // ...
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     // ...
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
            CDEBlockTransferCluster,
            Sequence<0, 1, 2, 3>,
            Sequence<0, 1, 2, 3>,
            Sequence<0, 1, 2, 3>,
            // ...
            CDEShuffleBlockTransferScalarPerVectors,
            // ...
            idx_c_ds_block_begin,
            tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
            // ...

        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        constexpr auto sfc_c_vgpr =
            SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
                              // ...
                              Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                              Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
                                       CShuffleNXdlPerWavePerShuffle / NXdlPack,
                                       // ...

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
        // ...
        constexpr auto sfc_cde_block =
            SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                              Sequence<0, 2, 1, 3>,
                              // ...
                              CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                              // ...
                              CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

        constexpr auto EMThreads =
            CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
        constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
        constexpr auto ENThreads =
            CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);

        static_for<0, num_access, 1>{}([&](auto access_id) {
            // ...
            StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
            // ...
            auto dstidx = sfc_cde_block.GetIndex(access_id);
            // ...
                block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
            static_for<0, EMRepeats, 1>{}([&](auto m0) {
                const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                IndexType token_offset    = fused_token & 0xffffff;
                if constexpr(IsInputGemm)
                {
                    token_offset = token_offset * problem.TopK + (fused_token >> 24);
                }
                scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
            // ...

            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          // ...
                                          c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          c_shuffle_block_buf);
            // ...
            cde_block_copy_lds_and_global.Run(
                // ...
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                // ...

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_lds_and_global_step =
                    sfc_cde_block.GetForwardStep(access_id);
                // ...
                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                // ...
                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    // ...
                    cde_lds_and_global_step);
    template <bool HasMainKBlockLoop,
              // ...
    __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
                                    const index_t* p_sorted_expert_ids,
                                    const index_t* p_max_token_id,
                                    const ADataType* p_a_grid,
                                    const AScaleDataType* p_a_scale_grid,
                                    const BDataType* p_b_grid,
                                    const BScaleDataType* p_b_scale_grid,
                                    // ...
                                    CDataType* p_c_grid,
                                    // ...
                                    AElementwiseOperation a_element_op,
                                    BElementwiseOperation b_element_op,
                                    CElementwiseOperation c_element_op)
    {
        // ...
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
        // ...
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        // ...
        const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
        // ...
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
            {
                const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
                // ...
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                // ...
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
                // ...
            }
            else
            {
                return {blockIdx.x, blockIdx.y};
            }
        }();
        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
        // ...
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);

        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads;
        // ...
        if(token_pos >= max_token_id || token0 >= problem.NumTokens)
        // ...
            const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
        // ...
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
            problem.N * (IsInputGemm ? 2 : 1) * /* ... */);
        // ...
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
        // ...
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        // ...
        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
        // ...
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            // ...
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorDim,
            // ...
            ABlockTransferSrcScalarPerVector,
            // ...
            1>(a_grid_desc_ak0_m_ak1,
               // ...
               a_block_desc_ak0_m_ak1,
               // ...

        auto b_blockwise_copy =
            // ...
                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                BBlockTransferThreadClusterArrangeOrder,
                // ...
                decltype(b_grid_desc_bk0_n_bk1),
                decltype(b_block_desc_bk0_n_bk1),
                BBlockTransferSrcAccessOrder,
                BBlockTransferSrcVectorDim,
                // ...
                BBlockTransferSrcScalarPerVector>(
                b_grid_desc_bk0_n_bk1,
                // ...
                b_block_desc_bk0_n_bk1,
                // ...

        // ...
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        // ...
        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
                                 a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
                                 a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
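        // Ping/pong pairs: the pipeline alternates between the p_shared_0 and p_shared_1
        // tiles, so the global-to-LDS loads for iteration i+1 can overlap the MFMA work
        // consuming iteration i.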
        // ...
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        // ...
        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;
        // ...
            c_thread_buf.num_of_v_,
            c_thread_buf.s_per_v,
        // ...
        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            /* ... */);
        // ...
        const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];
        // ...
        auto thread_offset_shuffled =
        // ...
        auto a_thread_offset_m = waveId_m;
        // ...
        const index_t token_scale_pos = block_m_id * MPerBlock;
        if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
        // ...
            decltype(a_scale_grid_desc_am_ak),
            decltype(BlockwiseGemmPipe::a_scale_thread_desc),
            // ...
            true>(a_scale_grid_desc_am_ak,
                  // ...

        auto b_thread_offset_n = waveId_n;
        // ...
            decltype(b_scale_grid_desc_bn_ak),
            decltype(BlockwiseGemmPipe::b_scale_thread_desc),
            // ...
            true>(b_scale_grid_desc_bn_ak,
                  // ...

        if constexpr(IsInputGemm)
        {
            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid_up + expert_id * expert_stride,
                b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
            // ...
                b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
            auto b_block_buf_up_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
                                     a_block_space_size_aligned * sizeof(ADataType) +
                                     b_block_space_size_aligned * sizeof(BDataType)),
                b_block_desc_bk0_n_bk1.GetElementSpaceSize());
            auto b_block_buf_up_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
                                     a_block_space_size_aligned * sizeof(ADataType) +
                                     b_block_space_size_aligned * sizeof(BDataType)),
                b_block_desc_bk0_n_bk1.GetElementSpaceSize());

            auto b_block_bufs_up = make_tuple(b_block_buf_up_ping, b_block_buf_up_pong);
            // ...
                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                BBlockTransferThreadClusterArrangeOrder,
                // ...
                decltype(b_grid_desc_bk0_n_bk1),
                decltype(b_block_desc_bk0_n_bk1),
                BBlockTransferSrcAccessOrder,
                BBlockTransferSrcVectorDim,
                // ...
                BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
                                                  // ...
                                                  b_block_desc_bk0_n_bk1,
                                                  // ...

            const BScaleDataType* p_b_scale_grid_up =
                p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());
            // ...
                decltype(b_scale_grid_desc_bn_ak),
                decltype(BlockwiseGemmPipe::b_scale_thread_desc),
                // ...
                b_scale_grid_desc_bn_ak,
                // ...

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                // ...
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                // ...
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,
                // ...
                b_blockwise_copy_up,
                // ...
                b_block_slice_copy_step,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                // ...
                b_scale_grid_buf_up,
                num_k_block_main_loop);
        }
        else
        {
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,
                // ...
                b_block_slice_copy_step,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                // ...
                num_k_block_main_loop);
        }
        // ...
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                      // ...
        static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
                          CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
                      // ...

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
        // ...
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
        // ...
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
        constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
        constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
        // ...
        static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
        static_assert(M5 == 4);
        // ...
                            const index_t m_pos = block_m_id * MPerBlock +
                                                  m0 * M2 * M1 * M3 * M4 * M5 +
                                                  m1 * M2 * M3 * M4 * M5 +
                                                  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
                            if constexpr(MulRoutedWeight)
                            {
                                // ...
                                    *c_style_pointer_cast<const vector_type<float, M5>*>(
                                        p_ds_grid[I2] + m_pos);
                            }
                            // ...
                                    blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                                        make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
                                // ...
                                if constexpr(IsInputGemm)
                                {
                                    if constexpr(ActivationOperation == /* ... */)
                                    {
                                        // ...
                                        float gate = c_thread_buf[cidx];
                                        float up   = c_thread_buf_up[cidx];
                                        if constexpr(MulRoutedWeight)
                                        {
                                            gate = gate * topk_weights.AsType<float>()[m5];
                                            up   = up * topk_weights.AsType<float>()[m5];
                                        }
                                        // ...
                                        c_thread_buf_fp32(cidx) = gate * up;
                                    }
                                    // ...
                                        float gate = c_thread_buf[cidx];
                                        float up   = c_thread_buf_up[cidx];
                                        if constexpr(MulRoutedWeight)
                                        {
                                            gate = gate * topk_weights.AsType<float>()[m5];
                                            up   = up * topk_weights.AsType<float>()[m5];
                                        }
                                        // ...
                                        c_thread_buf_fp32(cidx) = gate * up;
                                    // ...
                                }
                                // ...
                                    c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
                                    if constexpr(MulRoutedWeight)
                                    {
                                        c_thread_buf_fp32(cidx) =
                                            topk_weights.AsType<float>()[m5] *
                                            c_thread_buf_fp32[cidx];
                                    }
                                // ...

        // ...
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        // ...
        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared_0),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        // ...
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            // ...
        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
        // ...
        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
        // ...
        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
        // ...
        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
        // ...
            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            // ...
                     CShuffleNXdlPerWavePerShuffle / NXdlPack,
                     // ...
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
            // ...
            true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                  // ...
                             m_thread_data_on_block_idx[I1],
                             n_thread_data_on_block_idx[I1],
                             m_thread_data_on_block_idx[I2],
                             n_thread_data_on_block_idx[I2],
                             m_thread_data_on_block_idx[I3],
                             m_thread_data_on_block_idx[I4],
                             m_thread_data_on_block_idx[I5],
                             n_thread_data_on_block_idx[I3]),
                  // ...

        using EDataType = CDataType;
        // ...
        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
        // ...
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
        // ...
            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            // ...
                { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
            // ...
            tie(c_shuffle_block_buf),
            // ...
                { return ds_grid_buf[i]; },
        // ...
        const auto idx_c_ds_block_begin =
        // ...
        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            c_grid_desc_mblock_mperblock_nblock_nperblock;
        // ...
        using CDEBlockTransferCluster =
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
        const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
        constexpr index_t scatter_weight_idx  = 3;
        // ...
            decltype(c_ds_desc_refs),
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CElementwiseOperation,
            // ...
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     // ...
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
            CDEBlockTransferCluster,
            // ...
            CDEShuffleBlockTransferScalarPerVectors,
            // ...
            idx_c_ds_block_begin,
            tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
            // ...

        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        constexpr auto sfc_c_vgpr =
            // ...
                              Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                              // ...
                                       CShuffleNXdlPerWavePerShuffle / NXdlPack,
                                       // ...

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
        // ...
        constexpr auto sfc_cde_block =
            // ...
                              CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                              // ...
                              CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

        constexpr auto EMThreads =
            CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
        constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
        constexpr auto ENThreads =
            CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
        // ...
            auto dstidx = sfc_cde_block.GetIndex(access_id);
            // ...
                block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
            // ...
                const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                IndexType token_offset    = fused_token & 0xffffff;
                if constexpr(IsInputGemm)
                {
                    token_offset = token_offset * problem.TopK + (fused_token >> 24);
                }
                scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
            // ...

            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          // ...
                                          c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          c_shuffle_block_buf);
            // ...
            cde_block_copy_lds_and_global.Run(
                // ...
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                // ...

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_lds_and_global_step =
                    sfc_cde_block.GetForwardStep(access_id);
                // ...
                    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                // ...
                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    // ...
                    cde_lds_and_global_step);
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition: device_base.hpp:178
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
GemmSpecialization
Definition: gemm_specialization.hpp:11
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__device__ index_t get_warp_local_1d_id()
Definition: get_id.hpp:45
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:34
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm_bns.hpp:48
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:928
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
constexpr auto BlockGemmMXPipeline_Selector()
Definition: blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp:36
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:277
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm.hpp:90
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:185
Activation
Definition: gridwise_moe_gemm.hpp:31
@ silu_and_mul
Definition: gridwise_moe_gemm.hpp:33
@ gelu_and_mul
Definition: gridwise_moe_gemm.hpp:32
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
__host__ constexpr __device__ auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:320
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition: tuple_helper.hpp:42
constexpr bool is_same_v
Definition: type.hpp:283
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition: sequence.hpp:925
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:299
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
Definition: gridwise_moe_mx_gemm.hpp:721
const index_t * p_max_token_id
Definition: gridwise_moe_mx_gemm.hpp:783
CDataType * p_c_grid
Definition: gridwise_moe_mx_gemm.hpp:789
const AElementwiseOperation a_element_op
Definition: gridwise_moe_mx_gemm.hpp:791
const index_t * p_sorted_expert_ids
Definition: gridwise_moe_mx_gemm.hpp:782
const CElementwiseOperation c_element_op
Definition: gridwise_moe_mx_gemm.hpp:793
const BDataType * p_b_grid
Definition: gridwise_moe_mx_gemm.hpp:786
const BScaleDataType * p_b_scale_grid
Definition: gridwise_moe_mx_gemm.hpp:787
DsGridPointer p_ds_grid
Definition: gridwise_moe_mx_gemm.hpp:788
__host__ Argument(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *p_a_grid_, const AScaleDataType *p_a_scale_grid_, const BDataType *p_b_grid_, const BScaleDataType *p_b_scale_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition: gridwise_moe_mx_gemm.hpp:722
const index_t * p_sorted_token_ids
Definition: gridwise_moe_mx_gemm.hpp:781
const BElementwiseOperation b_element_op
Definition: gridwise_moe_mx_gemm.hpp:792
const AScaleDataType * p_a_scale_grid
Definition: gridwise_moe_mx_gemm.hpp:785
const ADataType * p_a_grid
Definition: gridwise_moe_mx_gemm.hpp:784
Problem (the GEMM problem-size and stride descriptor)  // Definition: gridwise_moe_mx_gemm.hpp:649
index_t MBlock  // Definition: gridwise_moe_mx_gemm.hpp:715
index_t NPadded  // Definition: gridwise_moe_mx_gemm.hpp:710
index_t K  // Definition: gridwise_moe_mx_gemm.hpp:701
index_t N  // Definition: gridwise_moe_mx_gemm.hpp:700
index_t NumTokens  // Definition: gridwise_moe_mx_gemm.hpp:697
index_t M  // Definition: gridwise_moe_mx_gemm.hpp:699
index_t StrideA  // Definition: gridwise_moe_mx_gemm.hpp:702
index_t StrideScaleB  // Definition: gridwise_moe_mx_gemm.hpp:705
index_t KRead  // Definition: gridwise_moe_mx_gemm.hpp:711
index_t NBlock  // Definition: gridwise_moe_mx_gemm.hpp:716
index_t StrideC  // Definition: gridwise_moe_mx_gemm.hpp:707
index_t StrideB  // Definition: gridwise_moe_mx_gemm.hpp:704
__host__ void Print() const  // Definition: gridwise_moe_mx_gemm.hpp:685
index_t BK0  // Definition: gridwise_moe_mx_gemm.hpp:714
index_t StrideScaleA  // Definition: gridwise_moe_mx_gemm.hpp:703
std::array< index_t, NumDTensor > StrideDs  // Definition: gridwise_moe_mx_gemm.hpp:706
index_t MPadded  // Definition: gridwise_moe_mx_gemm.hpp:709
index_t KBatch  // Definition: gridwise_moe_mx_gemm.hpp:708
index_t KPadded  // Definition: gridwise_moe_mx_gemm.hpp:712
index_t TopK  // Definition: gridwise_moe_mx_gemm.hpp:698
index_t AK0  // Definition: gridwise_moe_mx_gemm.hpp:713
__host__ Problem(index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)  // Definition: gridwise_moe_mx_gemm.hpp:650
SplitKBatchOffset (per split-K batch offsets into the A/B data and scale grids)  // Definition: gridwise_moe_mx_gemm.hpp:797
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)  // Definition: gridwise_moe_mx_gemm.hpp:798
index_t b_k_split_offset  // Definition: gridwise_moe_mx_gemm.hpp:852
index_t b_scale_k_split_offset  // Definition: gridwise_moe_mx_gemm.hpp:854
index_t a_k_split_offset  // Definition: gridwise_moe_mx_gemm.hpp:851
index_t a_scale_k_split_offset  // Definition: gridwise_moe_mx_gemm.hpp:853
The gridwise MoE MX GEMM struct itself (the subject of this file)  // Definition: gridwise_moe_mx_gemm.hpp:179
static __host__ auto CalculateGridSize(index_t M, index_t N)  // Definition: gridwise_moe_mx_gemm.hpp:248
static constexpr index_t APackedSize  // Definition: gridwise_moe_mx_gemm.hpp:218
static constexpr auto MakeDsGridPointer()  // Definition: gridwise_moe_mx_gemm.hpp:233
static constexpr __host__ bool CheckValidity(const Argument &karg)  // Definition: gridwise_moe_mx_gemm.hpp:1179
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)  // Definition: gridwise_moe_mx_gemm.hpp:625
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)  // Definition: gridwise_moe_mx_gemm.hpp:637
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)  // Definition: gridwise_moe_mx_gemm.hpp:279
static constexpr auto AK0Number  // Definition: gridwise_moe_mx_gemm.hpp:197
static constexpr auto I7  // Definition: gridwise_moe_mx_gemm.hpp:190
static constexpr index_t BPackedSize  // Definition: gridwise_moe_mx_gemm.hpp:219
remove_cvref_t< decltype(BlockGemmMXPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ScaleBlockSize, ADataType, AScaleDataType, BDataType, BScaleDataType, ComputeTypeA, AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, IsInputGemm >())> BlockwiseGemmPipe  // Definition: gridwise_moe_mx_gemm.hpp:1138
static __host__ auto CalculateMPadded(index_t M)  // Definition: gridwise_moe_mx_gemm.hpp:258
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)  // Definition: gridwise_moe_mx_gemm.hpp:285
static constexpr auto KXdlPack  // Definition: gridwise_moe_mx_gemm.hpp:210
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()  // Definition: gridwise_moe_mx_gemm.hpp:978
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)  // Definition: gridwise_moe_mx_gemm.hpp:337  (usage sketch below)
static constexpr auto I6  // Definition: gridwise_moe_mx_gemm.hpp:189
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)  // Definition: gridwise_moe_mx_gemm.hpp:1364
static constexpr auto AK1Number  // Definition: gridwise_moe_mx_gemm.hpp:199
static constexpr auto I9  // Definition: gridwise_moe_mx_gemm.hpp:192
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)  // Definition: gridwise_moe_mx_gemm.hpp:291
decltype(MakeDsGridPointer()) DsGridPointer  // Definition: gridwise_moe_mx_gemm.hpp:244
static constexpr auto I8  // Definition: gridwise_moe_mx_gemm.hpp:191
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)  // Definition: gridwise_moe_mx_gemm.hpp:452  (usage sketch below)
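The A/B grid-descriptor factories and the padded-size helpers listed in this cross-reference are normally used together when a problem is set up on the host. A minimal sketch, assuming GridwiseGemm is a concrete instantiation of this struct and M, N, K, StrideA, StrideB, k_batch are index_t values; variable names are illustrative:

// Padded extents and K0 counts for the chosen split-K factor.
const auto MPad = GridwiseGemm::CalculateMPadded(M);
const auto NPad = GridwiseGemm::CalculateNPadded(N);
const auto KPad = GridwiseGemm::CalculateKPadded(K, k_batch);
const auto AK0  = GridwiseGemm::CalculateAK0Padded(K, k_batch);
const auto BK0  = GridwiseGemm::CalculateBK0Padded(K, k_batch);

// Grid descriptors consumed by the block-wise pipeline.
const auto a_grid_desc_ak0_m_ak1 =
    GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(M, MPad, K, KPad, StrideA, AK0);
const auto b_grid_desc_bk0_n_bk1 =
    GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(K, KPad, N, NPad, StrideB, BK0);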
static constexpr auto I0  // Definition: gridwise_moe_mx_gemm.hpp:183
BDataType LDSTypeB  // Definition: gridwise_moe_mx_gemm.hpp:181
__host__ static constexpr __device__ auto MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1 &)  // Definition: gridwise_moe_mx_gemm.hpp:561
static constexpr auto NXdlPack  // Definition: gridwise_moe_mx_gemm.hpp:209
__host__ static __device__ auto MakeCGridDescriptor_M_N(IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)  // Definition: gridwise_moe_mx_gemm.hpp:580
static constexpr auto I3  // Definition: gridwise_moe_mx_gemm.hpp:186
static constexpr index_t scale_pack_size_b  // Definition: gridwise_moe_mx_gemm.hpp:1392
static constexpr index_t scale_pack_size_a  // Definition: gridwise_moe_mx_gemm.hpp:1391
static __host__ auto CalculateMBlock(index_t M)  // Definition: gridwise_moe_mx_gemm.hpp:298
static __host__ auto CalculateKPadded(index_t K)  // Definition: gridwise_moe_mx_gemm.hpp:268
static constexpr index_t NumDTensor  // Definition: gridwise_moe_mx_gemm.hpp:206
static constexpr index_t SortedTileSize  // Definition: gridwise_moe_mx_gemm.hpp:231
static __host__ auto CalculateNBlock(index_t N)  // Definition: gridwise_moe_mx_gemm.hpp:303
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()  // Definition: gridwise_moe_mx_gemm.hpp:857
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock  // Definition: gridwise_moe_mx_gemm.hpp:194
static __device__ void Run_2Lds(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)  // Definition: gridwise_moe_mx_gemm.hpp:2169
ADataType LDSTypeA  // Definition: gridwise_moe_mx_gemm.hpp:180
__host__ static constexpr __device__ auto MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1 &)  // Definition: gridwise_moe_mx_gemm.hpp:571
ThisThreadBlock< BlockSize > ThisThreadBlock  // Definition: gridwise_moe_mx_gemm.hpp:246
static constexpr auto lcm_AK1_BK1  // Definition: gridwise_moe_mx_gemm.hpp:202
static constexpr auto I1  // Definition: gridwise_moe_mx_gemm.hpp:184
static constexpr auto BK1Number  // Definition: gridwise_moe_mx_gemm.hpp:200
static constexpr auto is_scale_mfma  // Definition: gridwise_moe_mx_gemm.hpp:204
__host__ static __device__ auto MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)  // Definition: gridwise_moe_mx_gemm.hpp:604
static __host__ auto CalculateNPadded(index_t N)  // Definition: gridwise_moe_mx_gemm.hpp:263
static constexpr __device__ index_t GetSharedMemoryNumberOfByte()  // Definition: gridwise_moe_mx_gemm.hpp:1140
static constexpr auto I5  // Definition: gridwise_moe_mx_gemm.hpp:188
static constexpr index_t KPack  // Definition: gridwise_moe_mx_gemm.hpp:227
static constexpr auto I4  // Definition: gridwise_moe_mx_gemm.hpp:187
static constexpr auto I2  // Definition: gridwise_moe_mx_gemm.hpp:185
__host__ static constexpr __device__ auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)  // Definition: gridwise_moe_mx_gemm.hpp:313
static constexpr auto MXdlPack  // Definition: gridwise_moe_mx_gemm.hpp:208
static constexpr bool is_single_rate_mfma  // Definition: gridwise_moe_mx_gemm.hpp:203
static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()  // Definition: gridwise_moe_mx_gemm.hpp:1095
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)  // Definition: gridwise_moe_mx_gemm.hpp:1372
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)  // Definition: gridwise_moe_mx_gemm.hpp:273
static constexpr auto BK0Number  // Definition: gridwise_moe_mx_gemm.hpp:198
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)  // Definition: gridwise_moe_mx_gemm.hpp:1357  (dispatch sketch below)
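Before launching one of the kernel entry points, a host-side dispatcher would typically run the validity and loop-shape queries listed in this cross-reference. A minimal sketch, assuming GridwiseGemm is a concrete instantiation of this struct and karg is a populated Argument; exactly which K value (the full K or the per-split-K slice) feeds the loop queries follows the library's own dispatch code, which this sketch does not reproduce:

// Host-side pre-launch queries (illustrative only).
if(!GridwiseGemm::CheckValidity(karg))
    return; // argument combination not supported by this specialization

const auto grid_size     = GridwiseGemm::CalculateGridSize(M, N);        // x/y block counts
const bool has_main_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);  // K loop has a main body
const auto tail_num      = GridwiseGemm::CalculateKBlockLoopTailNum(K);  // TailNumber for the epilogue

// has_main_loop and tail_num select the HasMainKBlockLoop / TailNum template
// arguments of the kernel wrappers shown earlier in this file.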
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...  // Definition: xdlops_gemm.hpp:1208
static constexpr auto selected_mfma  // Definition: xdlops_gemm.hpp:1757
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Definition: static_buffer.hpp:75
Definition: thread_group_tensor_slice_transfer_direct_load.hpp:55
Definition: thread_group_tensor_slice_transfer_gather_direct_load.hpp:57
Definition: thread_group_tensor_slice_transfer_v7r3_scatter.hpp:51
Definition: threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.  // Definition: threadwise_tensor_slice_transfer.hpp:234
Definition: tuple.hpp:117
Unsigned representation of a conventional biased Float32 exponent.  // Definition: e8m0.hpp:26  (decoding sketch below)
Definition: data_type.hpp:42
Definition: integral_constant.hpp:20
Definition: data_type.hpp:187
Definition: functional2.hpp:33
Definition: device_base.hpp:197
Definition: unary_element_wise_operation.hpp:1007
Definition: unary_element_wise_operation.hpp:334
Definition: unary_element_wise_operation.hpp:1049
Definition: dtype_vector.hpp:10
#define CK_ENV(name)  // Definition: env.hpp:129
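The e8m0 entry above is the scale type used for the A/B microscaling (MX) scale grids: an 8-bit biased exponent with no sign and no mantissa bits. For reference, under the standard OCP MX E8M0 convention such a byte decodes to a power of two, with 0xFF reserved for NaN; the sketch below shows that generic encoding rule, not this header's implementation:

#include <cmath>
#include <cstdint>

// Decode an OCP MX E8M0 scale byte: value = 2^(e - 127), with 0xFF reserved for NaN.
inline float decode_e8m0(std::uint8_t e)
{
    if(e == 0xFF)
        return std::nanf("");
    return std::ldexp(1.0f, static_cast<int>(e) - 127);
}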