template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          // ...
#if CK_USE_LAUNCH_BOUNDS
// ...
#endif
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
{
    // ...
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_sorted_token_ids,
        karg.p_sorted_expert_ids,
        // ...
        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
        karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
        karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
        // ...
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          // ...
#if CK_USE_LAUNCH_BOUNDS
// ...
#endif
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
{
    // ...
    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

    GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_sorted_token_ids,
        karg.p_sorted_expert_ids,
        // ...
        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
        karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
        karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
        // ...
template <typename ALayout,
          // ...
          typename AScaleDataType,
          // ...
          typename BScaleDataType,
          typename AccDataType,
          typename CShuffleDataType,
          // ...
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          // ...
          index_t ActivationOperation = 0,
          bool NSwizzle               = false,
          bool IsInputGemm            = true,
          bool MulRoutedWeight        = true,
          // ...
          typename ComputeTypeA = ADataType,
          typename ComputeTypeB = BDataType>
// ...
    // ...
        CDEShuffleBlockTransferScalarPerVectors{}[I0];
    // ...

    static constexpr auto MakeDsGridPointer()
    {
        // ...
                return static_cast<const DDataType*>(nullptr);
        // ...
    }

    static __host__ auto CalculateGridSize(index_t M, index_t N)
    {
        // ...
        const index_t gridx = NSwizzle ? nblock * mblock : nblock;
        const index_t gridy = NSwizzle ? 1 : mblock;
        // ...
    }

    static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
    }

    static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
    }

    static __host__ auto CalculateKPadded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;
    }

    static __host__ auto CalculateKRead(index_t K, index_t K_Batch = 1)
    {
        // ...
        auto K_t = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
    }
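    // Illustrative sketch, not part of this header: the Calculate*Padded helpers above round K
    // up so that each of the K_Batch split-K slices covers a whole number of KPerBlock tiles.
    // The same arithmetic with plain ints (function name is hypothetical):
    //
    //   int CalculateAK0PaddedSketch(int K, int K_Batch, int KPerBlock, int AK1Value)
    //   {
    //       const int K_t = K_Batch * KPerBlock;                   // K covered by one tile from every batch
    //       return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);   // AK0 per split-K slice, rounded up
    //   }
    //
    //   // e.g. K = 1000, K_Batch = 2, KPerBlock = 128, AK1Value = 16:
    //   // K_t = 256, ceil(1000/256) = 4, AK0 = 4 * (128/16) = 32.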
    template <index_t MNXdlPerWave, /* ... */ typename TileDesc_K0_MN_K1>
    // ...

    __host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(
        IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            // ...
        }();

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            const auto a_grid_desc_m_k =
            // ...
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M only: a_grid_desc_mraw_kraw -> a_grid_desc_ak0_m_ak1 -> a_grid_desc_permuted -> ...
            // ...
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K only: a_grid_desc_mraw_kraw -> ...
            // ...
            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // no padding: a_grid_desc_mraw_kraw -> a_grid_desc_ak0_m_ak1 -> a_grid_desc_permuted -> ...
            // ...
        }
    }

    __host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(
        index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            // ...
        }();

        static_assert(!(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> &&
                        GemmSpec != GemmSpecialization::Default),
                      "pk_i4_t does not support padding");
        static_assert(!(is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t> &&
                        (GemmSpec != GemmSpecialization::Default &&
                         GemmSpec != GemmSpecialization::MPadding)),
                      "f4x2_pk_t does not support K padding");

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            const auto b_grid_desc_n_k =
            // ...
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N only
            // ...
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K only
            // ...
            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // no padding: b_grid_desc_nraw_kraw -> b_grid_desc_bk0_n_bk1 -> b_grid_desc_permuted -> ...
            // ...
        }
    }
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl>(
            ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
    {
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl>(
            BBlockDesc_BK0_N_BK1{});
    }

    template <typename ELayout>
    __host__ static __device__ auto MakeCGridDescriptor_M_N(
        IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            // ...
        }();
        // ...
    }

    template <typename DLayout>
    __host__ static __device__ auto MakeDGridDescriptor_M_N(/* ... */)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            // ...
        }();
        // ...
    }

    __host__ static __device__ auto MakeDsGridDescriptor_M_N(
        index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
    {
        // one descriptor per D tensor
        // ...
            return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
        // ...
    }

    template <typename DsGridDesc>
    static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        // ...
                ds_grid_desc_m_n[i], MBlock, NBlock);
        // ...
    }
    // Problem constructor
    __host__ Problem(index_t NumTokens_,
                     index_t TopK_,
                     index_t M_,
                     index_t N_,
                     index_t K_,
                     index_t StrideA_,
                     index_t StrideScaleA_,
                     index_t StrideB_,
                     index_t StrideScaleB_,
                     std::array<index_t, NumDTensor> StrideDs_,
                     index_t StrideC_,
                     index_t KBatch_)
    // ...

    __host__ void Print() const
    {
        std::cout << "problem {"
                  << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
                  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
                  // ...
                  << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
                  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
                  << ", " << "NBlock: " << NBlock << "}" << std::endl;
    }
    // Argument constructor
    __host__ Argument(const index_t* p_sorted_token_ids_,
                      const index_t* p_sorted_expert_ids_,
                      const index_t* p_max_token_id_,
                      const ADataType* p_a_grid_,
                      const AScaleDataType* p_a_scale_grid_,
                      const BDataType* p_b_grid_,
                      const BScaleDataType* p_b_scale_grid_,
                      std::array<const void*, NumDTensor> p_ds_grid_,
                      CDataType* p_c_grid_,
                      // ...
                      std::array<index_t, NumDTensor> StrideDs_,
                      // ...
                      AElementwiseOperation a_element_op_,
                      BElementwiseOperation b_element_op_,
                      CElementwiseOperation c_element_op_)
    // ...
            p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
    // ...

    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
        {
            // offset into A for this split-K batch
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            // ...

            // offset into B for this split-K batch
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            // ...

            // offsets into the A/B scale tensors
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            // ...

            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            // ...
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            // ...

            if(k_id < karg.KBatch - 1)
            // ...
        }
    };
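    // Illustrative sketch, not part of this header: conceptually, SplitKBatchOffset advances each
    // split-K batch's A/B pointers by k_id slices of KRead elements, scaled by the tensor's stride
    // when K is not the contiguous dimension. A simplified host-side model; the names, the stride
    // handling, and the last-batch adjustment are assumptions, not the kernel's exact code:
    //
    //   long SplitKOffsetSketch(int k_id, long KRead, long Stride, bool k_is_contiguous)
    //   {
    //       // contiguous K: step k_id * KRead elements; otherwise each K step jumps Stride elements
    //       return k_is_contiguous ? long(k_id) * KRead : long(k_id) * KRead * Stride;
    //   }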
    __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave    = NPerBlock / (NXdlPerWave * NPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);
        // ...
            constexpr auto a_lds_block_desc =
            // ...
            return a_lds_block_desc_permuted;
        // ...
            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = WaveSize / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(ADataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));

            // ...
                Number<kfold * M0 / mpair>{},
            // ...
                a_lds_block_desc_permuted,
            // ...
                a_lds_block_desc_unmerged,
                // ...
                    Number<KThreadWrite / kfold / KThreadReadPerm>{},
            // ...
            return a_lds_block_desc_ak0_m_ak1;
    }

    __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave    = NPerBlock / (NXdlPerWave * NPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);
        // ...
            constexpr auto b_lds_block_desc =
            // ...
            return b_lds_block_desc_permuted;
        // ...
            constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
            constexpr auto N1 = NPerBlock / N0;

            constexpr auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
            constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
            constexpr auto KThreadRead      = WaveSize / NPerXdl;
            constexpr auto K0PerThreadRead  = BK0Number / KThreadRead;

            constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
                                       ? 1
                                       : 128 / (BK1Number * N0 * sizeof(BDataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
                                       ? 1
                                       : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
                                              ? N0
                                              : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));

            // ...
                Number<kfold * N0 / npair>{},
            // ...
                b_lds_block_desc_permuted,
            // ...
                b_lds_block_desc_unmerged,
                // ...
                    Number<KThreadWrite / kfold / KThreadReadPerm>{},
            // ...
            return b_lds_block_desc_bk0_n_bk1;
    }

    __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
        // ...
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        // ...
        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }
    // BlockwiseGemmPipe is chosen by BlockGemmMXPipeline_Selector<...,
    //     ABlockTransferSrcScalarPerVector,
    //     BBlockTransferSrcScalarPerVector,
    //     ...>
    // ...

    __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // ...
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        // ...
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        // ...
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        if constexpr(IsInputGemm)
        {
            // gate + up projections need two B tiles in LDS
            return math::max(a_block_space_size_aligned * sizeof(ADataType) +
                                 b_block_space_size_aligned * sizeof(BDataType) * 2,
                             c_block_size * sizeof(CShuffleDataType));
        }
        else
        {
            return math::max((a_block_space_size_aligned * sizeof(ADataType) +
                              b_block_space_size_aligned * sizeof(BDataType)),
                             c_block_size * sizeof(CShuffleDataType));
        }
    }
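    // Illustrative sketch, not part of this header: the LDS requirement above is the larger of
    // "A tile + B tile(s)" used by the main loop and the C-shuffle tile used by the epilogue,
    // because the two phases reuse the same scratch memory. With plain integers (names hypothetical):
    //
    //   size_t SharedMemoryBytesSketch(size_t a_elems, size_t a_elem_size,
    //                                  size_t b_elems, size_t b_elem_size,
    //                                  size_t c_elems, size_t c_elem_size,
    //                                  bool   is_input_gemm /* gate+up => two B tiles */)
    //   {
    //       const size_t ab = a_elems * a_elem_size + b_elems * b_elem_size * (is_input_gemm ? 2 : 1);
    //       const size_t c  = c_elems * c_elem_size;
    //       return ab > c ? ab : c;
    //   }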
    static constexpr __host__ bool CheckValidity(const Argument& karg)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
                      "KPerBlock should be multiple of ScaleBlockSize");

        // ...
        if(!(karg.M % MPerBlock == 0))
        {
            // ...
                std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          // ...
            return false;
        }
        // ...
        if(!(karg.N % NPerBlock == 0))
        {
            // ...
                std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          // ...
            return false;
        }
        // ...
        {
            auto K_t = karg.KBatch * KPerBlock;
            if(!(karg.K % K_t == 0))
            {
                // ...
                    std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                              << karg.K << " " << __FILE__ << ":" << __LINE__
                              << ", in function: " << __func__ << std::endl;
                return false;
            }
        }
        // ...
        {
            auto K_t = karg.KBatch * KReadVec;
            // ...
            if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
            {
                return false;
            }
        }
        // ...
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
            {
                // ...
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                return false;
            }
        // ...
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
            {
                // ...
                    std::cout << "Arg M (" << karg.M
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                return false;
            }
        // ...
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
            {
                // ...
                    std::cout << "Arg N (" << karg.N
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                return false;
            }
        // ...
            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
            {
                // ...
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                return false;
            }
        // ...
                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          // ...
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          // ...
        // ...
                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          // ...
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          // ...
        // ...
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
        // ...
        if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
        {
            return false;
        }
        // ...
    }
    __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
    }

    static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
    }
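    // Illustrative sketch, not part of this header: the pipeline only has a "hot" main loop when
    // there are more K-block iterations than it needs to fill its software-pipelined prefetch
    // stages. A minimal model of that decision (the stage count is an assumption for the example):
    //
    //   bool HasMainKBlockLoopSketch(int K, int KPerBlock, int prefetch_stages)
    //   {
    //       const int num_loop = K / KPerBlock;   // number of K-block iterations
    //       return num_loop > prefetch_stages;    // otherwise only prologue/epilogue iterations remain
    //   }
    //
    //   // e.g. K = 512, KPerBlock = 128, prefetch_stages = 2 -> 4 > 2, so the main loop runs.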
    template <typename CGridDesc>
    static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        // ...
        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }
    // ...
                  "A scale pack data type too large!");
    // ...
                  "B scale pack data type too large!");

    static_assert(is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
                      is_same_v<BElementwiseOperation, tensor_operation::element_wise::PassThrough>,
                  "A/B ElementwiseOperation should be PassThrough as load_to_lds is used!");

    template <bool HasMainKBlockLoop,
              // ...
    __device__ static void Run(const index_t* p_sorted_token_ids,
                               const index_t* p_sorted_expert_ids,
                               const index_t* p_max_token_id,
                               const ADataType* p_a_grid,
                               const AScaleDataType* p_a_scale_grid,
                               const BDataType* p_b_grid,
                               const BScaleDataType* p_b_scale_grid,
                               // ...
                               CDataType* p_c_grid,
                               // ...
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CElementwiseOperation c_element_op)
    {
        // ...
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
        // ...
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        // ...

        const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
            return;
        const index_t expert_id =
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);

        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
            {
                const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
                // ...
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                const index_t mid =
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
                return {nid, mid};
            }
            else
            {
                return {blockIdx.x, blockIdx.y};
            }
        }();

        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
        const index_t token0 =
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
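        // Illustrative sketch, not part of this header: with NSwizzle enabled the 1-D block id is
        // re-mapped so that groups of 8 consecutive N-tiles of one expert land on nearby blocks,
        // which is what the nid/mid arithmetic above computes (the group width of 8 is taken from
        // the code above; the function name is hypothetical):
        //
        //   std::pair<int, int> SwizzledBlockIdSketch(int bid_new, int expert_swizzle, int ecnt_prefix)
        //   {
        //       const int nid = bid_new % 8 + bid_new / (8 * expert_swizzle) * 8; // N-tile index
        //       const int mid = ecnt_prefix + bid_new / 8 % expert_swizzle;       // M-tile (token block) index
        //       return {nid, mid};
        //   }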
        // gather offsets for the A (token) rows handled by this thread
        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;

        if(token_pos >= max_token_id || token0 >= problem.NumTokens)
            return;
        StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
        static_for<0, AMRepeats, 1>{}([&](auto m0) {
            const index_t fused_token = p_sorted_token_ids[token_pos + m0];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            gather_offsets(m0) = static_cast<IndexType>(token_offset);
        });
        const index_t expert_stride =
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
            problem.N * (IsInputGemm ? 2 : 1) *
            // ...

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
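        // Illustrative sketch, not part of this header: each entry of p_sorted_token_ids packs the
        // token index in the low 24 bits and the top-k slot in the high 8 bits; the gather offsets
        // above are decoded from it. A standalone rendering of that decoding (names hypothetical):
        //
        //   int DecodeGatherRowSketch(unsigned fused_token, int top_k, bool is_input_gemm)
        //   {
        //       const int token_id  = fused_token & 0xffffff; // low 24 bits: sorted token index
        //       const int topk_slot = fused_token >> 24;      // high 8 bits: which expert slot of the token
        //       // The down/output GEMM reads one row per (token, slot) pair; the input GEMM reads
        //       // the token row directly.
        //       return is_input_gemm ? token_id : token_id * top_k + topk_slot;
        //   }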
        // global buffers
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        // ...
        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
        // ...

        // A: gathered direct-to-LDS copy
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
            // ...
            Sequence<AK0Number, MPerBlock, AK1Number>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            // ...
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorDim,
            // ...
            ABlockTransferSrcScalarPerVector,
            // ...
            1>(a_grid_desc_ak0_m_ak1,
               // ...
               a_block_desc_ak0_m_ak1,
               // ...

        // B: direct-to-LDS copy
        auto b_blockwise_copy =
            // ...
                Sequence<BK0Number, NPerBlock, BK1Number>,
                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                BBlockTransferThreadClusterArrangeOrder,
                // ...
                decltype(b_grid_desc_bk0_n_bk1),
                decltype(b_block_desc_bk0_n_bk1),
                BBlockTransferSrcAccessOrder,
                BBlockTransferSrcVectorDim,
                // ...
                BBlockTransferSrcScalarPerVector>(
                b_grid_desc_bk0_n_bk1,
                // ...
                b_block_desc_bk0_n_bk1,
                // ...

        // LDS allocation: A tile first, B tile right after it
        // ...
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
                                         a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        // ...

        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        // ...
        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;
        // fp32 staging buffer for the activated results (c_thread_buf_fp32)
        // ...
            c_thread_buf.num_of_v_,
            c_thread_buf.s_per_v,
        // ...

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            // ...

        // per-wave scale reads
        const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        auto thread_offset_shuffled =
        // ...
        auto a_thread_offset_m = waveId_m;

        auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
            // ...
            decltype(a_scale_grid_desc_am_ak),
            decltype(BlockwiseGemmPipe::a_scale_thread_desc),
            // ...
            true>(a_scale_grid_desc_am_ak,
                  // ...

        auto b_thread_offset_n = waveId_n;

        auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
            // ...
            decltype(b_scale_grid_desc_bn_ak),
            decltype(BlockwiseGemmPipe::b_scale_thread_desc),
            // ...
            true>(b_scale_grid_desc_bn_ak,
                  // ...
        if constexpr(IsInputGemm)
        {
            // gate + up projection: second B tile placed after the first one in LDS
            // ...
                b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
            auto b_block_buf_up = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
                                             a_block_space_size_aligned * sizeof(ADataType) +
                                             b_block_space_size_aligned * sizeof(BDataType)),
                b_block_desc_bk0_n_bk1.GetElementSpaceSize());

            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid_up + expert_id * expert_stride,
                b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

            auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
                // ...
                Sequence<BK0Number, NPerBlock, BK1Number>,
                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                BBlockTransferThreadClusterArrangeOrder,
                // ...
                decltype(b_grid_desc_bk0_n_bk1),
                decltype(b_block_desc_bk0_n_bk1),
                BBlockTransferSrcAccessOrder,
                BBlockTransferSrcVectorDim,
                // ...
                BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
                                                  // ...
                                                  b_block_desc_bk0_n_bk1,
                                                  // ...

            const BScaleDataType* p_b_scale_grid_up =
                p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());

            auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
                // ...
                decltype(b_scale_grid_desc_bn_ak),
                decltype(BlockwiseGemmPipe::b_scale_thread_desc),
                // ...
                >(
                b_scale_grid_desc_bn_ak,
                // ...

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                // ...
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                // ...
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,
                // ...
                b_blockwise_copy_up,
                // ...
                b_block_slice_copy_step,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                // ...
                b_scale_grid_buf_up,
                num_k_block_main_loop);
        }
        else
        {
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,
                // ...
                b_block_slice_copy_step,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                // ...
                num_k_block_main_loop);
        }
        // shuffle C and write out
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                      // ...
        static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
                          CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
                      // ...

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        // ...
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();

        // ...
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
        constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
        constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);

        // ...
        static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
        static_assert(M5 == 4);
        // ...

        // apply activation (and optionally the routed weight) on the accumulator registers
        vector_type<float, 4> topk_weights; // routing weights for this thread's rows
        static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
            static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
                static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) {
                    static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
                        static_for<0, M3, 1>{}([&](auto m3) {
                            // m1, m4 are defined in elided code above
                            const index_t m_pos = block_m_id * MPerBlock +
                                                  m0 * M2 * M1 * M3 * M4 * M5 +
                                                  m1 * M2 * M3 * M4 * M5 +
                                                  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
                            if constexpr(MulRoutedWeight)
                            {
                                topk_weights =
                                    *c_style_pointer_cast<const vector_type<float, M5>*>(
                                        p_ds_grid[I2] + m_pos);
                            }
                            static_for<0, M5, 1>{}([&](auto m5) {
                                constexpr index_t c_offset =
                                    blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                                        make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
                                constexpr auto cidx = Number<c_offset>{};

                                if constexpr(IsInputGemm) // gate/up fusion
                                {
                                    if constexpr(ActivationOperation ==
                                                 // ... Activation::silu_and_mul)
                                    {
                                        float gate = c_thread_buf[cidx];
                                        float up   = c_thread_buf_up[cidx];
                                        if constexpr(MulRoutedWeight)
                                        {
                                            gate = gate * topk_weights.AsType<float>()[m5];
                                            up   = up * topk_weights.AsType<float>()[m5];
                                        }
                                        tensor_operation::element_wise::Silu{}(gate, gate);
                                        c_thread_buf_fp32(cidx) = gate * up;
                                    }
                                    // ... else: gelu_and_mul
                                    {
                                        float gate = c_thread_buf[cidx];
                                        float up   = c_thread_buf_up[cidx];
                                        if constexpr(MulRoutedWeight)
                                        {
                                            gate = gate * topk_weights.AsType<float>()[m5];
                                            up   = up * topk_weights.AsType<float>()[m5];
                                        }
                                        tensor_operation::element_wise::Gelu{}(gate, gate);
                                        c_thread_buf_fp32(cidx) = gate * up;
                                    }
                                }
                                else
                                {
                                    // ...
                                    c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
                                    if constexpr(MulRoutedWeight)
                                    {
                                        c_thread_buf_fp32(cidx) =
                                            topk_weights.AsType<float>()[m5] *
                                            c_thread_buf_fp32[cidx];
                                    }
                                }
                            });
                        });
                    });
                });
            });
        });
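        // Illustrative sketch, not part of this header: for the input (gate/up) GEMM, the two
        // halves of the fused weight produce "gate" and "up" partials, and the epilogue computes
        // act(gate) * up, optionally scaled by the token's routing weight. A scalar model of the
        // silu_and_mul path above (function name hypothetical):
        //
        //   #include <cmath>
        //   float SiluAndMulSketch(float gate, float up, float routed_weight, bool mul_routed_weight)
        //   {
        //       if(mul_routed_weight)
        //       {
        //           gate *= routed_weight;
        //           up   *= routed_weight;
        //       }
        //       const float silu = gate / (1.0f + std::exp(-gate)); // SiLU(x) = x * sigmoid(x)
        //       return silu * up;
        //   }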
        // LDS tile for the shuffled C results (reuses the same p_shared as A/B)
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        // ...
        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // view the shuffle tile as (M0, N0, M1, N1, M2, N2, M3, M4, M5, N3)
        // ...
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            // ...
                Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{},
                // ...
                Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{},
                // ...
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0, 2, 4, 6, 7, 8>{},
                       // ...
                       Sequence<1, 3, 5, 9>{}));

        // ...
        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
        // ...
        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
            // ...
        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
        // ...
        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
            // ...

        // VGPR accumulators -> LDS shuffle tile
        auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
            // ...
            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            // ...
            Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
                     CShuffleNXdlPerWavePerShuffle / NXdlPack,
                     // ...
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
            // ...
            true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                  // ...
                             m_thread_data_on_block_idx[I1],
                             n_thread_data_on_block_idx[I1],
                             m_thread_data_on_block_idx[I2],
                             n_thread_data_on_block_idx[I2],
                             m_thread_data_on_block_idx[I3],
                             m_thread_data_on_block_idx[I4],
                             m_thread_data_on_block_idx[I5],
                             n_thread_data_on_block_idx[I3]),
                  // ...
        using EDataType = CDataType;
        // ...
        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
        // ...
        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        // ...
        const auto c_ds_desc_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            generate_tie(
                [&](auto i) -> const auto& {
                    return ds_grid_desc_mblock_mperblock_nblock_nperblock[i];
                },
                Number<NumDTensor>{}));
        // ...
        const auto c_ds_buf_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_buf),
            generate_tie(
                [&](auto i) -> const auto& { return ds_grid_buf[i]; },
                Number<NumDTensor>{}));

        // ...
        const auto idx_c_ds_block_begin =
        // ...
            Number<NumDTensor>{}));

        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            c_grid_desc_mblock_mperblock_nblock_nperblock;

        using CDEBlockTransferCluster =
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
        const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
        constexpr index_t scatter_weight_idx  = 3;

        // LDS shuffle tile + D tensors -> E (C) in global memory, scattered per token
        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
            // ...
            decltype(c_ds_desc_refs),
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CElementwiseOperation,
            Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>,
            // ...
            Sequence<1,
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     1,
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
            CDEBlockTransferCluster,
            Sequence<0, 1, 2, 3>,
            Sequence<0, 1, 2, 3>,
            Sequence<0, 1, 2, 3>,
            // ...
            CDEShuffleBlockTransferScalarPerVectors,
            // ...
            >{
                // ...
                idx_c_ds_block_begin,
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                // ...

        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // space-filling curve over the per-thread accumulator tile
        constexpr auto sfc_c_vgpr =
            SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
                                       // ...
                              Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                              Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
                                       CShuffleNXdlPerWavePerShuffle / NXdlPack,
                                       // ...
        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        // matching space-filling curve over the block tile written per shuffle round
        constexpr auto sfc_cde_block =
            SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                              Sequence<0, 2, 1, 3>,
                              Sequence<1,
                                       CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                       1,
                                       CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

        constexpr auto EMThreads =
            CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
        constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
        constexpr auto ENThreads =
            CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);

        static_for<0, num_access, 1>{}([&](auto access_id) {
            // compute per-row scatter offsets for this shuffle round
            StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;

            auto dstidx = sfc_cde_block.GetIndex(access_id);
            const index_t c_token_pos =
                block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
            static_for<0, EMRepeats, 1>{}([&](auto m0) {
                const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                IndexType token_offset    = fused_token & 0xffffff;
                if constexpr(IsInputGemm)
                {
                    token_offset = token_offset * problem.TopK + (fused_token >> 24);
                }
                scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
            });

            // ...
            // each thread writes its sub-tile of accumulators into the LDS shuffle tile
            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          // ...
                                          c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          c_shuffle_block_buf);

            // ...
            // LDS -> global, scattered by token
            cde_block_copy_lds_and_global.Run(
                // ...
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                // ...

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_lds_and_global_step =
                    sfc_cde_block.GetForwardStep(access_id);
                // ...
                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                });

                // ...
                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    // ...
                    cde_lds_and_global_step);
            }
        });
    }
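    // Illustrative sketch, not part of this header: on the way out, each row of the block tile is
    // scattered back to its token's row in the output; for the input (gate/up) GEMM the row also
    // selects the (token, top-k slot) pair. A standalone rendering of the offset computation
    // (function name hypothetical):
    //
    //   long ScatterRowOffsetSketch(unsigned fused_token, int top_k, long N, bool is_input_gemm)
    //   {
    //       long row = fused_token & 0xffffff;               // sorted token index
    //       if(is_input_gemm)
    //           row = row * top_k + (fused_token >> 24);     // one output row per (token, slot)
    //       return row * N;                                  // element offset of that row
    //   }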
    template <bool HasMainKBlockLoop,
              // ...
    __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
                                    const index_t* p_sorted_expert_ids,
                                    const index_t* p_max_token_id,
                                    const ADataType* p_a_grid,
                                    const AScaleDataType* p_a_scale_grid,
                                    const BDataType* p_b_grid,
                                    const BScaleDataType* p_b_scale_grid,
                                    // ...
                                    CDataType* p_c_grid,
                                    // ...
                                    AElementwiseOperation a_element_op,
                                    BElementwiseOperation b_element_op,
                                    CElementwiseOperation c_element_op)
    {
        // ...
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
        // ...
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        // ...

        const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
            return;
        const index_t expert_id =
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
            {
                const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
                // ...
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                const index_t mid =
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
                return {nid, mid};
            }
            else
            {
                return {blockIdx.x, blockIdx.y};
            }
        }();

        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
        const index_t token0 =
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);

        // gather offsets for the A (token) rows handled by this thread
        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads;

        if(token_pos >= max_token_id || token0 >= problem.NumTokens)
            return;
        // ...
            const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
        // ...
        const index_t expert_stride =
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
            problem.N * (IsInputGemm ? 2 : 1) *
            // ...

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);

        // global buffers
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        // ...
        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
        // A: gathered direct-to-LDS copy
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
            // ...
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            // ...
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorDim,
            // ...
            ABlockTransferSrcScalarPerVector,
            // ...
            1>(a_grid_desc_ak0_m_ak1,
               // ...
               a_block_desc_ak0_m_ak1,
               // ...

        // B: direct-to-LDS copy
        auto b_blockwise_copy =
            // ...
                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                BBlockTransferThreadClusterArrangeOrder,
                // ...
                decltype(b_grid_desc_bk0_n_bk1),
                decltype(b_block_desc_bk0_n_bk1),
                BBlockTransferSrcAccessOrder,
                BBlockTransferSrcVectorDim,
                // ...
                BBlockTransferSrcScalarPerVector>(
                b_grid_desc_bk0_n_bk1,
                // ...
                b_block_desc_bk0_n_bk1,
                // ...

        // ping-pong LDS buffers: A and B tiles live in both p_shared_0 and p_shared_1
        // ...
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
                                 a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
                                 a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
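        // Illustrative sketch, not part of this header: the "_2Lds" variant keeps two copies of
        // the A/B tiles (p_shared_0 / p_shared_1) so the global-to-LDS copy into one buffer can
        // overlap the MFMA work reading the other. A minimal model of that alternation (names
        // hypothetical; the two calls inside the loop are placeholders, not real APIs):
        //
        //   void PingPongLoopSketch(int num_k_block_main_loop)
        //   {
        //       int write_buf = 0;                          // buffer currently being filled
        //       for(int k = 0; k < num_k_block_main_loop; ++k)
        //       {
        //           // issue_global_to_lds_copy(write_buf); // prefetch next K tile into one buffer
        //           // compute_on(1 - write_buf);           // MFMAs consume the other buffer
        //           write_buf = 1 - write_buf;              // swap roles each iteration
        //       }
        //   }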
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        // ...
        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;
        // fp32 staging buffer for the activated results (c_thread_buf_fp32)
        // ...
            c_thread_buf.num_of_v_,
            c_thread_buf.s_per_v,
        // ...

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            // ...

        // per-wave scale reads
        const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        auto thread_offset_shuffled =
        // ...
        auto a_thread_offset_m = waveId_m;

        // ...
        const index_t token_scale_pos = block_m_id * MPerBlock;
        if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
            return;

        auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
            // ...
            decltype(a_scale_grid_desc_am_ak),
            decltype(BlockwiseGemmPipe::a_scale_thread_desc),
            // ...
            true>(a_scale_grid_desc_am_ak,
                  // ...

        auto b_thread_offset_n = waveId_n;

        auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
            // ...
            decltype(b_scale_grid_desc_bn_ak),
            decltype(BlockwiseGemmPipe::b_scale_thread_desc),
            // ...
            true>(b_scale_grid_desc_bn_ak,
                  // ...
        if constexpr(IsInputGemm)
        {
            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid_up + expert_id * expert_stride,
                b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

            // second (up-projection) B tile, ping-pong buffered like the first one
            // ...
                b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
            auto b_block_buf_up_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
                                     a_block_space_size_aligned * sizeof(ADataType) +
                                     b_block_space_size_aligned * sizeof(BDataType)),
                b_block_desc_bk0_n_bk1.GetElementSpaceSize());
            auto b_block_buf_up_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
                                     a_block_space_size_aligned * sizeof(ADataType) +
                                     b_block_space_size_aligned * sizeof(BDataType)),
                b_block_desc_bk0_n_bk1.GetElementSpaceSize());

            auto b_block_bufs_up = make_tuple(b_block_buf_up_ping, b_block_buf_up_pong);

            auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
                // ...
                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                BBlockTransferThreadClusterArrangeOrder,
                // ...
                decltype(b_grid_desc_bk0_n_bk1),
                decltype(b_block_desc_bk0_n_bk1),
                BBlockTransferSrcAccessOrder,
                BBlockTransferSrcVectorDim,
                // ...
                BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
                                                  // ...
                                                  b_block_desc_bk0_n_bk1,
                                                  // ...

            const BScaleDataType* p_b_scale_grid_up =
                p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());

            auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
                // ...
                decltype(b_scale_grid_desc_bn_ak),
                decltype(BlockwiseGemmPipe::b_scale_thread_desc),
                // ...
                >(
                b_scale_grid_desc_bn_ak,
                // ...

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                // ...
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                // ...
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,
                // ...
                b_blockwise_copy_up,
                // ...
                b_block_slice_copy_step,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                // ...
                b_scale_grid_buf_up,
                num_k_block_main_loop);
        }
        else
        {
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,
                // ...
                b_block_slice_copy_step,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                // ...
                num_k_block_main_loop);
        }
        // shuffle C and write out (same structure as in Run())
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                      // ...
        static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
                          CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
                      // ...

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        // ...
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();

        // ...
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
        constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
        constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);

        // ...
        static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
        static_assert(M5 == 4);

        // apply activation (and optionally the routed weight), as in Run()
        // ...
                            // m1, m4 are defined in elided code above
                            const index_t m_pos = block_m_id * MPerBlock +
                                                  m0 * M2 * M1 * M3 * M4 * M5 +
                                                  m1 * M2 * M3 * M4 * M5 +
                                                  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
                            if constexpr(MulRoutedWeight)
                            {
                                topk_weights =
                                    *c_style_pointer_cast<const vector_type<float, M5>*>(
                                        p_ds_grid[I2] + m_pos);
                            }
                            // ...
                                constexpr index_t c_offset =
                                    blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                                        make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
                                // ...
                                if constexpr(IsInputGemm) // gate/up fusion
                                {
                                    if constexpr(ActivationOperation ==
                                                 // ... Activation::silu_and_mul)
                                    {
                                        float gate = c_thread_buf[cidx];
                                        float up   = c_thread_buf_up[cidx];
                                        if constexpr(MulRoutedWeight)
                                        {
                                            gate = gate * topk_weights.AsType<float>()[m5];
                                            up   = up * topk_weights.AsType<float>()[m5];
                                        }
                                        // ... Silu
                                        c_thread_buf_fp32(cidx) = gate * up;
                                    }
                                    // ... else: gelu_and_mul
                                    {
                                        float gate = c_thread_buf[cidx];
                                        float up   = c_thread_buf_up[cidx];
                                        if constexpr(MulRoutedWeight)
                                        {
                                            gate = gate * topk_weights.AsType<float>()[m5];
                                            up   = up * topk_weights.AsType<float>()[m5];
                                        }
                                        // ... Gelu
                                        c_thread_buf_fp32(cidx) = gate * up;
                                    }
                                }
                                else
                                {
                                    // ...
                                    c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
                                    if constexpr(MulRoutedWeight)
                                    {
                                        c_thread_buf_fp32(cidx) =
                                            topk_weights.AsType<float>()[m5] *
                                            c_thread_buf_fp32[cidx];
                                    }
                                }
                            // ...
        // LDS tile for the shuffled C results (reuses p_shared_0)
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        // ...
        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared_0),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // ...
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            // ...

        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
        // ...
        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
            // ...
        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
        // ...
        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
            // ...

        // VGPR accumulators -> LDS shuffle tile
        auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
            // ...
            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            // ...
                     CShuffleNXdlPerWavePerShuffle / NXdlPack,
                     // ...
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
            // ...
            true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                  // ...
                             m_thread_data_on_block_idx[I1],
                             n_thread_data_on_block_idx[I1],
                             m_thread_data_on_block_idx[I2],
                             n_thread_data_on_block_idx[I2],
                             m_thread_data_on_block_idx[I3],
                             m_thread_data_on_block_idx[I4],
                             m_thread_data_on_block_idx[I5],
                             n_thread_data_on_block_idx[I3]),
                  // ...

        using EDataType = CDataType;
        // ...
        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
        // ...
        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
            },
            // ...
        const auto c_ds_desc_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            generate_tie(
                [&](auto i) -> const auto& {
                    return ds_grid_desc_mblock_mperblock_nblock_nperblock[i];
                },
                // ...
        const auto c_ds_buf_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_buf),
            generate_tie(
                [&](auto i) -> const auto& { return ds_grid_buf[i]; },
                // ...

        const auto idx_c_ds_block_begin =
        // ...

        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            c_grid_desc_mblock_mperblock_nblock_nperblock;

        using CDEBlockTransferCluster =
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
        const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
        constexpr index_t scatter_weight_idx  = 3;

        // LDS shuffle tile + D tensors -> E (C) in global memory, scattered per token
        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
            // ...
            decltype(c_ds_desc_refs),
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CElementwiseOperation,
            // ...
            Sequence<1,
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     1,
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
            CDEBlockTransferCluster,
            // ...
            CDEShuffleBlockTransferScalarPerVectors,
            // ...
            >{
                // ...
                idx_c_ds_block_begin,
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                // ...

        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        // space-filling curves, as in Run()
        constexpr auto sfc_c_vgpr =
            // ...
                              Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                              // ...
                                       CShuffleNXdlPerWavePerShuffle / NXdlPack,
                                       // ...
        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        constexpr auto sfc_cde_block =
            // ...
                                       CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                       // ...
                                       CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

        constexpr auto EMThreads =
            CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
        constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
        constexpr auto ENThreads =
            CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);

        static_for<0, num_access, 1>{}([&](auto access_id) {
            // compute per-row scatter offsets for this shuffle round
            // ...
            auto dstidx = sfc_cde_block.GetIndex(access_id);
            const index_t c_token_pos =
                block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
            static_for<0, EMRepeats, 1>{}([&](auto m0) {
                const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                IndexType token_offset    = fused_token & 0xffffff;
                if constexpr(IsInputGemm)
                {
                    token_offset = token_offset * problem.TopK + (fused_token >> 24);
                }
                scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
            });

            // ...
            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          // ...
                                          c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          c_shuffle_block_buf);

            // ...
            cde_block_copy_lds_and_global.Run(
                // ...
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                // ...

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_lds_and_global_step =
                    sfc_cde_block.GetForwardStep(access_id);
                // ...
                    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                // ...
                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    // ...
                    cde_lds_and_global_step);
            }
        });
        // ...
    }
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
GemmSpecialization
Definition: gemm_specialization.hpp:11
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__device__ index_t get_warp_local_1d_id()
Definition: get_id.hpp:56
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:34
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm_bns.hpp:48
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:928
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
constexpr auto BlockGemmMXPipeline_Selector()
Definition: blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp:36
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:276
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm.hpp:87
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:132
Activation
Definition: gridwise_moe_gemm.hpp:31
@ silu_and_mul
Definition: gridwise_moe_gemm.hpp:33
@ gelu_and_mul
Definition: gridwise_moe_gemm.hpp:32
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
__host__ constexpr __device__ auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:320
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition: tuple_helper.hpp:42
constexpr bool is_same_v
Definition: type.hpp:283
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition: sequence.hpp:925
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:52
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
Definition: gridwise_moe_mx_gemm.hpp:715
const index_t * p_max_token_id
Definition: gridwise_moe_mx_gemm.hpp:777
CDataType * p_c_grid
Definition: gridwise_moe_mx_gemm.hpp:783
const AElementwiseOperation a_element_op
Definition: gridwise_moe_mx_gemm.hpp:785
const index_t * p_sorted_expert_ids
Definition: gridwise_moe_mx_gemm.hpp:776
const CElementwiseOperation c_element_op
Definition: gridwise_moe_mx_gemm.hpp:787
const BDataType * p_b_grid
Definition: gridwise_moe_mx_gemm.hpp:780
const BScaleDataType * p_b_scale_grid
Definition: gridwise_moe_mx_gemm.hpp:781
DsGridPointer p_ds_grid
Definition: gridwise_moe_mx_gemm.hpp:782
__host__ Argument(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *p_a_grid_, const AScaleDataType *p_a_scale_grid_, const BDataType *p_b_grid_, const BScaleDataType *p_b_scale_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition: gridwise_moe_mx_gemm.hpp:716
const index_t * p_sorted_token_ids
Definition: gridwise_moe_mx_gemm.hpp:775
const BElementwiseOperation b_element_op
Definition: gridwise_moe_mx_gemm.hpp:786
const AScaleDataType * p_a_scale_grid
Definition: gridwise_moe_mx_gemm.hpp:779
const ADataType * p_a_grid
Definition: gridwise_moe_mx_gemm.hpp:778
struct Problem - Definition: gridwise_moe_mx_gemm.hpp:643. Members (a generic padding illustration follows this list):
__host__ Problem(index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_) - Definition: gridwise_moe_mx_gemm.hpp:644
__host__ void Print() const - Definition: gridwise_moe_mx_gemm.hpp:679
index_t NumTokens - Definition: gridwise_moe_mx_gemm.hpp:691
index_t TopK - Definition: gridwise_moe_mx_gemm.hpp:692
index_t M - Definition: gridwise_moe_mx_gemm.hpp:693
index_t N - Definition: gridwise_moe_mx_gemm.hpp:694
index_t K - Definition: gridwise_moe_mx_gemm.hpp:695
index_t StrideA - Definition: gridwise_moe_mx_gemm.hpp:696
index_t StrideScaleA - Definition: gridwise_moe_mx_gemm.hpp:697
index_t StrideB - Definition: gridwise_moe_mx_gemm.hpp:698
index_t StrideScaleB - Definition: gridwise_moe_mx_gemm.hpp:699
std::array< index_t, NumDTensor > StrideDs - Definition: gridwise_moe_mx_gemm.hpp:700
index_t StrideC - Definition: gridwise_moe_mx_gemm.hpp:701
index_t KBatch - Definition: gridwise_moe_mx_gemm.hpp:702
index_t MPadded - Definition: gridwise_moe_mx_gemm.hpp:703
index_t NPadded - Definition: gridwise_moe_mx_gemm.hpp:704
index_t KRead - Definition: gridwise_moe_mx_gemm.hpp:705
index_t KPadded - Definition: gridwise_moe_mx_gemm.hpp:706
index_t AK0 - Definition: gridwise_moe_mx_gemm.hpp:707
index_t BK0 - Definition: gridwise_moe_mx_gemm.hpp:708
index_t MBlock - Definition: gridwise_moe_mx_gemm.hpp:709
index_t NBlock - Definition: gridwise_moe_mx_gemm.hpp:710
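MPadded, NPadded, MBlock and NBlock are the block-tile-aligned extents and tile counts derived from M and N. A generic, self-contained illustration of that derivation (not the library's exact formulas; MPerBlock_ and NPerBlock_ are hypothetical stand-ins for the kernel's MPerBlock/NPerBlock template values):

#include <cstdint>
using index_t = int32_t;

constexpr index_t ceil_div(index_t x, index_t y) { return (x + y - 1) / y; }
constexpr index_t pad_to_multiple(index_t x, index_t tile) { return ceil_div(x, tile) * tile; }

int main()
{
    constexpr index_t MPerBlock_ = 128, NPerBlock_ = 128;
    const index_t M = 1000, N = 768;

    const index_t MPadded = pad_to_multiple(M, MPerBlock_); // 1024
    const index_t NPadded = pad_to_multiple(N, NPerBlock_); // 768 (already aligned)
    const index_t MBlock  = MPadded / MPerBlock_;           // 8 tiles along M
    const index_t NBlock  = NPadded / NPerBlock_;           // 6 tiles along N
    return (MBlock == 8 && NBlock == 6) ? 0 : 1;
}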
struct SplitKBatchOffset - Definition: gridwise_moe_mx_gemm.hpp:791. Members (a generic split-K illustration follows this list):
__device__ SplitKBatchOffset(Argument &karg, index_t k_id) - Definition: gridwise_moe_mx_gemm.hpp:792
index_t a_k_split_offset - Definition: gridwise_moe_mx_gemm.hpp:845
index_t b_k_split_offset - Definition: gridwise_moe_mx_gemm.hpp:846
index_t a_scale_k_split_offset - Definition: gridwise_moe_mx_gemm.hpp:847
index_t b_scale_k_split_offset - Definition: gridwise_moe_mx_gemm.hpp:848
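Each k-batch of a split-K launch works on its own slice of the K dimension, and SplitKBatchOffset turns the batch index k_id into element offsets into the A/B grids and their scale grids. A generic, self-contained illustration of that idea, not this file's implementation; the real offsets also depend on layout, element packing and scale-pack size, and KRead and ScaleBlockSize_ are hypothetical stand-ins:

#include <cstdint>
using index_t = int32_t;

struct SplitKOffsets
{
    index_t a_k_split_offset;       // element offset into A for this k-batch
    index_t a_scale_k_split_offset; // offset into the A scale grid
};

// k_id selects which KRead-wide slice of K this batch reads; with a stride-1
// (row-major) K dimension the data offset is k_id * KRead, and the scale
// offset advances by one entry per ScaleBlockSize_ elements covered.
SplitKOffsets make_offsets(index_t k_id, index_t KRead, index_t ScaleBlockSize_)
{
    return {k_id * KRead, k_id * (KRead / ScaleBlockSize_)};
}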
Gridwise MoE MX GEMM struct - Definition: gridwise_moe_mx_gemm.hpp:173. Members:
ADataType LDSTypeA - Definition: gridwise_moe_mx_gemm.hpp:174
BDataType LDSTypeB - Definition: gridwise_moe_mx_gemm.hpp:175
static constexpr auto I0 - Definition: gridwise_moe_mx_gemm.hpp:177
static constexpr auto I1 - Definition: gridwise_moe_mx_gemm.hpp:178
static constexpr auto I2 - Definition: gridwise_moe_mx_gemm.hpp:179
static constexpr auto I3 - Definition: gridwise_moe_mx_gemm.hpp:180
static constexpr auto I4 - Definition: gridwise_moe_mx_gemm.hpp:181
static constexpr auto I5 - Definition: gridwise_moe_mx_gemm.hpp:182
static constexpr auto I6 - Definition: gridwise_moe_mx_gemm.hpp:183
static constexpr auto I7 - Definition: gridwise_moe_mx_gemm.hpp:184
static constexpr auto I8 - Definition: gridwise_moe_mx_gemm.hpp:185
static constexpr auto I9 - Definition: gridwise_moe_mx_gemm.hpp:186
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock - Definition: gridwise_moe_mx_gemm.hpp:188
static constexpr auto AK0Number - Definition: gridwise_moe_mx_gemm.hpp:191
static constexpr auto BK0Number - Definition: gridwise_moe_mx_gemm.hpp:192
static constexpr auto AK1Number - Definition: gridwise_moe_mx_gemm.hpp:193
static constexpr auto BK1Number - Definition: gridwise_moe_mx_gemm.hpp:194
static constexpr auto lcm_AK1_BK1 - Definition: gridwise_moe_mx_gemm.hpp:196
static constexpr bool is_single_rate_mfma - Definition: gridwise_moe_mx_gemm.hpp:197
static constexpr auto is_scale_mfma - Definition: gridwise_moe_mx_gemm.hpp:198
static constexpr index_t NumDTensor - Definition: gridwise_moe_mx_gemm.hpp:200
static constexpr auto MXdlPack - Definition: gridwise_moe_mx_gemm.hpp:202
static constexpr auto NXdlPack - Definition: gridwise_moe_mx_gemm.hpp:203
static constexpr auto KXdlPack - Definition: gridwise_moe_mx_gemm.hpp:204
static constexpr index_t APackedSize - Definition: gridwise_moe_mx_gemm.hpp:212
static constexpr index_t BPackedSize - Definition: gridwise_moe_mx_gemm.hpp:213
static constexpr index_t KPack - Definition: gridwise_moe_mx_gemm.hpp:221
static constexpr index_t SortedTileSize - Definition: gridwise_moe_mx_gemm.hpp:225
static constexpr auto MakeDsGridPointer() - Definition: gridwise_moe_mx_gemm.hpp:227
decltype(MakeDsGridPointer()) DsGridPointer - Definition: gridwise_moe_mx_gemm.hpp:238
ThisThreadBlock< BlockSize > ThisThreadBlock - Definition: gridwise_moe_mx_gemm.hpp:240
static __host__ auto CalculateGridSize(index_t M, index_t N) - Definition: gridwise_moe_mx_gemm.hpp:242
static __host__ auto CalculateMPadded(index_t M) - Definition: gridwise_moe_mx_gemm.hpp:252
static __host__ auto CalculateNPadded(index_t N) - Definition: gridwise_moe_mx_gemm.hpp:257
static __host__ auto CalculateKPadded(index_t K) - Definition: gridwise_moe_mx_gemm.hpp:262
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1) - Definition: gridwise_moe_mx_gemm.hpp:267
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1) - Definition: gridwise_moe_mx_gemm.hpp:273
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1) - Definition: gridwise_moe_mx_gemm.hpp:279
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1) - Definition: gridwise_moe_mx_gemm.hpp:285
static __host__ auto CalculateMBlock(index_t M) - Definition: gridwise_moe_mx_gemm.hpp:292
static __host__ auto CalculateNBlock(index_t N) - Definition: gridwise_moe_mx_gemm.hpp:297
__host__ static constexpr __device__ auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &) - Definition: gridwise_moe_mx_gemm.hpp:307
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0) - Definition: gridwise_moe_mx_gemm.hpp:331
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) - Definition: gridwise_moe_mx_gemm.hpp:446
__host__ static constexpr __device__ auto MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1 &) - Definition: gridwise_moe_mx_gemm.hpp:555
__host__ static constexpr __device__ auto MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1 &) - Definition: gridwise_moe_mx_gemm.hpp:565
__host__ static __device__ auto MakeCGridDescriptor_M_N(IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC) - Definition: gridwise_moe_mx_gemm.hpp:574
__host__ static __device__ auto MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) - Definition: gridwise_moe_mx_gemm.hpp:598
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs) - Definition: gridwise_moe_mx_gemm.hpp:619
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock) - Definition: gridwise_moe_mx_gemm.hpp:631
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - Definition: gridwise_moe_mx_gemm.hpp:851
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - Definition: gridwise_moe_mx_gemm.hpp:972
static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - Definition: gridwise_moe_mx_gemm.hpp:1089
remove_cvref_t< decltype(BlockGemmMXPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ScaleBlockSize, ADataType, AScaleDataType, BDataType, BScaleDataType, ComputeTypeA, AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, IsInputGemm >())> BlockwiseGemmPipe - Definition: gridwise_moe_mx_gemm.hpp:1132
static constexpr __device__ index_t GetSharedMemoryNumberOfByte() - Definition: gridwise_moe_mx_gemm.hpp:1134
static constexpr __host__ bool CheckValidity(const Argument &karg) - Definition: gridwise_moe_mx_gemm.hpp:1171
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K) - Definition: gridwise_moe_mx_gemm.hpp:1349
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K) - Definition: gridwise_moe_mx_gemm.hpp:1356 (see the generic main-loop/tail sketch after this list)
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock) - Definition: gridwise_moe_mx_gemm.hpp:1364
static constexpr index_t scale_pack_size_a - Definition: gridwise_moe_mx_gemm.hpp:1383
static constexpr index_t scale_pack_size_b - Definition: gridwise_moe_mx_gemm.hpp:1384
static __device__ void Run_2Lds(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) - Definition: gridwise_moe_mx_gemm.hpp:2161
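CalculateHasMainKBlockLoop and CalculateKBlockLoopTailNum decide, on the host, which compile-time variant of the pipelined K loop to instantiate: whether a steady-state main loop exists at all and how many leftover K tiles the tail must handle. A generic, self-contained sketch of that split, not the library's exact rule; KPerBlock_ and PrefetchStages_ are hypothetical stand-ins for the real tuning parameters:

#include <cstdint>
using index_t = int32_t;

int main()
{
    constexpr index_t KPerBlock_ = 256, PrefetchStages_ = 2;
    const index_t K = 4096;

    const index_t k_tiles            = (K + KPerBlock_ - 1) / KPerBlock_; // 16 K tiles
    const bool has_main_k_block_loop = k_tiles > PrefetchStages_;         // steady-state loop runs
    const index_t tail_tiles         = k_tiles % PrefetchStages_;         // maps onto a TailNumber enum in practice
    return (has_main_k_block_loop && tail_tiles == 0) ? 0 : 1;
}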
Other referenced definitions:
Definition: xdlops_gemm.hpp:1126
static constexpr auto selected_mfma - Definition: xdlops_gemm.hpp:1647
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Definition: static_buffer.hpp:75
Definition: thread_group_tensor_slice_transfer_direct_load.hpp:55
Definition: thread_group_tensor_slice_transfer_gather_direct_load.hpp:57
Definition: thread_group_tensor_slice_transfer_v7r3_scatter.hpp:51
Definition: threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads. - Definition: threadwise_tensor_slice_transfer.hpp:234
Definition: tuple.hpp:117
Unsigned representation of a conventional biased Float32 exponent. - Definition: e8m0.hpp:25
Definition: data_type.hpp:41
Definition: integral_constant.hpp:20
Definition: data_type.hpp:186
Definition: functional2.hpp:33
Definition: device_base.hpp:51
Definition: unary_element_wise_operation.hpp:1007
Definition: unary_element_wise_operation.hpp:334
Definition: unary_element_wise_operation.hpp:1049
Definition: dtype_vector.hpp:10
#define CK_ENV(name) - Definition: env.hpp:129