template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
#if CK_USE_LAUNCH_BOUNDS

    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())

        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_sorted_token_ids,
            karg.p_sorted_expert_ids,
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
#if CK_USE_LAUNCH_BOUNDS

    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())

        __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
        __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_sorted_token_ids,
            karg.p_sorted_expert_ids,
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
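// A minimal, self-contained sketch (not from this header) of the LDS pattern both
// kernels above rely on: LDS is reserved as raw bytes, sized by the gridwise GEMM,
// and reinterpreted as typed buffers inside Run/Run_2Lds. The 1024-byte size and
// float element type are placeholders for GridwiseGemm::GetSharedMemoryNumberOfByte()
// and ADataType.
__global__ void lds_allocation_sketch()
{
    __shared__ char p_shared[1024];                     // raw byte allocation
    float* a_tile = reinterpret_cast<float*>(p_shared); // typed view of the same bytes
    if(threadIdx.x < 1024 / sizeof(float))
        a_tile[threadIdx.x] = 0.0f;
}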
template <typename ALayout,
          typename AScaleDataType,
          typename BScaleDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          index_t ActivationOperation = 0,
          bool NSwizzle              = false,
          bool IsInputGemm           = true,
          bool MulRoutedWeight       = true,
          typename ComputeTypeA      = ADataType,
          typename ComputeTypeB      = BDataType>
    static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock =
        CDEShuffleBlockTransferScalarPerVectors{}[I0];
240 "A scale pack data type too large!");
242 "B scale pack data type too large!");
250 return static_cast<const DDataType*
>(
nullptr);
        const index_t gridx = NSwizzle ? nblock * mblock : nblock;
        const index_t gridy = NSwizzle ? 1 : mblock;
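// A stand-alone sketch of the grid-shape logic above: with NSwizzle the M- and
// N-block counts are flattened into gridDim.x (gridy == 1) so block ids can be
// re-swizzled per expert; without it, N maps to x and M to y. <utility> assumed.
inline std::pair<index_t, index_t> grid_xy_sketch(index_t mblock, index_t nblock, bool nswizzle)
{
    return nswizzle ? std::make_pair(mblock * nblock, index_t{1})
                    : std::make_pair(nblock, mblock);
}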
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);

        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);

        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;

        auto K_t = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
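// Worked example of the ceil-divide padding above (CalculateAK0Padded): K is
// rounded up to a multiple of K_Batch * KPerBlock, then expressed as AK1-wide
// chunks per block. The concrete values are illustrative only.
static_assert((1000 + 2 * 128 - 1) / (2 * 128) * (128 / 8) == 64,
              "K=1000, KBatch=2, KPerBlock=128, AK1=8 -> AK0 of 64 per batch");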
    template <index_t MNXdlPerWave,
              typename TileDesc_K0_MN_K1>

        IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)

        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)

            const auto a_grid_desc_m_k =

            return a_grid_desc_ak0_m_ak1;

        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)

                a_grid_desc_mraw_kraw,

            return a_grid_desc_ak0_m_ak1;

        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)

                a_grid_desc_mraw_kraw,

            return a_grid_desc_ak0_m_ak1;

            a_grid_desc_mraw_kraw,
            a_grid_desc_ak0_m_ak1,
            a_grid_desc_permuted,

        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);
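// Sanity check for the wave decomposition above, with typical (illustrative)
// numbers: BlockSize=256, MPerBlock=NPerBlock=128, M/NXdlPerWave=2, M/NPerXdl=32
// give MWave = NWave = 2 and a 64-lane wave.
static_assert(256 / ((128 / (2 * 32)) * (128 / (2 * 32))) == 64,
              "WaveSize = BlockSize / (MWave * NWave)");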
        const auto b_grid_desc_nraw_kraw = [&]() {

                      GemmSpec != GemmSpecialization::Default),
                      "pk_i4_t does not support padding");
                      GemmSpec != GemmSpecialization::Default),
                      "f4x2_pk_t does not support padding");

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)

            const auto b_grid_desc_n_k =

            return b_grid_desc_bk0_n_bk1;

        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)

                b_grid_desc_nraw_kraw,

            return b_grid_desc_bk0_n_bk1;

        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)

                b_grid_desc_nraw_kraw,

            return b_grid_desc_bk0_n_bk1;

            b_grid_desc_nraw_kraw,
            b_grid_desc_bk0_n_bk1,
            b_grid_desc_permuted,
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto

        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl, true>(
            ABlockDesc_AK0_M_AK1{});

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto

        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl, false>(
            BBlockDesc_BK0_N_BK1{});

    template <typename ELayout>
        IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)

        const auto c_grid_desc_mraw_nraw = [&]() {
    template <typename DLayout>
    __host__ __device__ static auto

        const auto c_grid_desc_mraw_nraw = [&]() {

            return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);

    template <typename DsGridDesc>
        const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)

                ds_grid_desc_m_n[i], MBlock, NBlock);

        std::array<index_t, NumDTensor> StrideDs_,
            std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
                      << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
                      << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
                      << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
                      << ", " << "NBlock: " << NBlock << "}" << std::endl;
                 const index_t* p_sorted_expert_ids_,
                 const index_t* p_max_token_id_,
                 const ADataType* p_a_grid_,
                 const AScaleDataType* p_a_scale_grid_,
                 const BDataType* p_b_grid_,
                 const BScaleDataType* p_b_scale_grid_,
                 std::array<const void*, NumDTensor> p_ds_grid_,
                 CDataType* p_c_grid_,
                 std::array<index_t, NumDTensor> StrideDs_,
                 AElementwiseOperation a_element_op_,
                 BElementwiseOperation b_element_op_,
                 CElementwiseOperation c_element_op_)

                p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)

            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)

            if(k_id < karg.KBatch - 1)

        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);
            constexpr auto a_lds_block_desc =

            return a_lds_block_desc_permuted;

            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = WaveSize / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(ADataType));

            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));

                Number<kfold * M0 / mpair>{},

                a_lds_block_desc_permuted,

                a_lds_block_desc_unmerged,

                Number<KThreadWrite / kfold / KThreadReadPerm>{},

            return a_lds_block_desc_ak0_m_ak1;
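// Worked example of the fold factors above: with AK1 = 8 elements, M0 = 4
// write-threads and a 2-byte ADataType, one AK1*M0 row is 64 bytes, so two rows
// (kfold = 2) are interleaved to fill a 128-byte LDS span. Values illustrative only.
static_assert(((8 * 4 * 2 > 128) ? 1 : 128 / (8 * 4 * 2)) == 2,
              "kfold folds rows until a 128-byte LDS span is covered");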
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;

            ABlockTransferSrcScalarPerVector,
            BBlockTransferSrcScalarPerVector,

            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =

        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max(a_block_space_size_aligned * sizeof(ADataType),
                         c_block_size * sizeof(CShuffleDataType));
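// Stand-alone restatement of the LDS budget computed above: the A staging tile
// and the C-shuffle tile reuse the same allocation, so the requirement is their
// maximum (std::max from <algorithm>; names are stand-ins).
inline std::size_t lds_bytes_sketch(std::size_t a_elems, std::size_t a_elem_bytes,
                                    std::size_t c_elems, std::size_t c_elem_bytes)
{
    return std::max(a_elems * a_elem_bytes, c_elems * c_elem_bytes);
}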
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
                      "KPerBlock should be multiple of ScaleBlockSize");

        if(!(karg.M % MPerBlock == 0))

                std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__

        if(!(karg.N % NPerBlock == 0))

                std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
            auto K_t = karg.KBatch * KPerBlock;
            if(!(karg.K % K_t == 0))

                    std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                              << karg.K << " " << __FILE__ << ":" << __LINE__
                              << ", in function: " << __func__ << std::endl;

            auto K_t = karg.KBatch * KReadVec;
            if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)

            if(karg.K % ABlockTransferSrcScalarPerVector != 0)

                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)

                    std::cout << "Arg M (" << karg.M
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;

            if(karg.N % BBlockTransferSrcScalarPerVector != 0)

                    std::cout << "Arg N (" << karg.N
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;

            if(karg.K % BBlockTransferSrcScalarPerVector != 0)

                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__

                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__

        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);

        if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)

        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);

        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
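// The check above rejects problems whose K loop cannot cover the pipeline's
// prefetch depth. A scalar restatement under that reading (PrefetchStages is a
// pipeline-specific constant; names are stand-ins):
inline bool k_loop_covers_prefetch_sketch(index_t AK0, index_t KPerBlock,
                                          index_t AK1Value, index_t prefetch_stages)
{
    const index_t num_k_loop = AK0 / (KPerBlock / AK1Value);
    return num_k_loop > prefetch_stages;
}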
    template <typename CGridDesc>
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)

        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    template <bool HasMainKBlockLoop,
    __device__ static void Run(const index_t* p_sorted_token_ids,
                               const index_t* p_sorted_expert_ids,
                               const index_t* p_max_token_id,
                               const ADataType* p_a_grid,
                               const AScaleDataType* p_a_scale_grid,
                               const BDataType* p_b_grid,
                               const BScaleDataType* p_b_scale_grid,
                               CDataType* p_c_grid,
                               const Problem& problem,
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CElementwiseOperation c_element_op)
            IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,

        const auto b_grid_desc_bpreshuffled =
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
            IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,

        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                c_grid_desc_m_n, problem.MBlock, problem.NBlock);

        const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)

            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);

        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)

                const index_t ecnt_prefix    = p_max_token_id[1 + expert_id];
                const index_t prefix_block   = ecnt_prefix * problem.NBlock;
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);

                return {blockIdx.x, blockIdx.y};

        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
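// Host-side sketch of the N-swizzle mapping in the lambda above: within one
// expert's stripe of blocks, ids advance through groups of 8 along N before
// stepping in M, keeping neighbouring blocks on neighbouring N tiles.
inline std::pair<index_t, index_t> swizzle_sketch(index_t bid_new, index_t ecnt)
{
    const index_t expert_swizzle = ecnt > 0 ? ecnt : 1; // guard experts that own no blocks
    const index_t nid = bid_new % 8 + bid_new / (8 * expert_swizzle) * 8;
    const index_t mid = bid_new / 8 % expert_swizzle;   // ecnt_prefix is added in the kernel
    return {nid, mid};
}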
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);

        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;

        if(token_pos >= max_token_id || token0 >= problem.NumTokens)

        StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
        static_for<0, AMRepeats, 1>{}([&](auto m0) {
            const index_t fused_token = p_sorted_token_ids[token_pos + m0];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K / APackedSize;
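// The gather offsets above decode a fused 32-bit id: the low 24 bits hold the
// token index, the high 8 bits the top-k slot. A scalar sketch of that decode
// (APackedSize scaling omitted; <cstdint> assumed):
inline index_t decode_fused_token_sketch(uint32_t fused_token, index_t topk, bool is_input_gemm)
{
    index_t token_offset = fused_token & 0xffffff;
    if(!is_input_gemm) // the second GEMM reads the TopK-expanded activation rows
        token_offset = token_offset * topk + (fused_token >> 24);
    return token_offset;
}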
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
        const index_t expert_scale_stride =
            __builtin_amdgcn_readfirstlane(problem.N * (IsInputGemm ? 2 : 1) *

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid + expert_id * expert_stride / BPackedSize,
            b_grid_desc_bpreshuffled.GetElementSpaceSize());

        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + expert_id * expert_scale_stride,
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
            AElementwiseOperation,
            Sequence<AK0Number, MPerBlock, AK1Number>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorDim,
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            AThreadTransferSrcResetCoordinateAfterRun,
            BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                                a_block_desc_ak0_m_ak1,

        auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto b_blockwise_copy =
            ThreadwiseTensorSliceTransfer_v2<BDataType,
                                             decltype(b_grid_desc_bpreshuffled),
                                             decltype(b_block_desc_bk0_n_bk1),
                                                      Number<BK1Value>{}>,
                                             Sequence<1, 2, 0, 3>,
                                             BBlockTransferSrcScalarPerVector,
                                             BThreadTransferSrcResetCoordinateAfterRun,
                                             b_grid_desc_bpreshuffled,

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize);
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);

        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;

            c_thread_buf.num_of_v_,
            c_thread_buf.s_per_v,

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /

        const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;

        auto thread_offset_shuffled =

        auto a_thread_offset_m = waveId_m;

        auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
            decltype(a_scale_grid_desc_am_ak),
            decltype(BlockwiseGemmPipe::a_scale_thread_desc),
            true>(a_scale_grid_desc_am_ak,

        auto b_thread_offset_n = waveId_n;

        auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
            decltype(b_scale_grid_desc_bn_ak),
            decltype(BlockwiseGemmPipe::b_scale_thread_desc),
            true>(b_scale_grid_desc_bn_ak,
        if constexpr(IsInputGemm)

            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid_up + expert_id * expert_stride / BPackedSize,
                b_grid_desc_bpreshuffled.GetElementSpaceSize());
            auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
                decltype(b_grid_desc_bpreshuffled),
                decltype(b_block_desc_bk0_n_bk1),
                Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
                Sequence<1, 2, 0, 3>,
                BBlockTransferSrcScalarPerVector,
                BThreadTransferSrcResetCoordinateAfterRun,
                true>(b_grid_desc_bpreshuffled,

            const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride,
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());
            auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
                decltype(b_scale_grid_desc_bn_ak),
                decltype(BlockwiseGemmPipe::b_scale_thread_desc),
                      b_scale_grid_desc_bn_ak,

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                b_blockwise_copy_up,
                b_block_slice_copy_step,
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                b_scale_grid_buf_up,
                num_k_block_main_loop);
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                b_block_slice_copy_step,
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                num_k_block_main_loop);
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,

        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

        static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
        static_assert(M4 == 4);
        vector_type<float, 4> topk_weights;
        static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
            static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
                static_for<0, M2, 1>{}([&](auto m2) {
                    const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
                                          m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
                    if constexpr(MulRoutedWeight)
                    {
                        topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
                            p_ds_grid[I2] + m_pos);
                    }
                    static_for<0, M4, 1>{}([&](auto m4) {
                            blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                        constexpr auto cidx = Number<c_offset>{};

                        if constexpr(IsInputGemm)

                            float gate = c_thread_buf[cidx];
                            float up   = c_thread_buf_up[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                gate = gate * topk_weights.AsType<float>()[m4];
                                up   = up * topk_weights.AsType<float>()[m4];
                            }
                            tensor_operation::element_wise::Silu{}(gate, gate);
                            c_thread_buf_fp32(cidx) = gate * up;

                            float gate = c_thread_buf[cidx];
                            float up   = c_thread_buf_up[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                gate = gate * topk_weights.AsType<float>()[m4];
                                up   = up * topk_weights.AsType<float>()[m4];
                            }
                            tensor_operation::element_wise::Gelu{}(gate, gate);
                            c_thread_buf_fp32(cidx) = gate * up;

                            c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
                            if constexpr(MulRoutedWeight)
                            {
                                c_thread_buf_fp32(cidx) =
                                    topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
                            }
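// Scalar restatement of the fused epilogue above (SiLU variant; the Gelu branch
// is analogous): gate and up are the two halves of the doubled-N GEMM output,
// optionally scaled by the routing weight before the activation is applied.
// expf from <cmath> stands in for the element-wise Silu functor.
inline float silu_and_mul_sketch(float gate, float up, float w, bool mul_routed_weight)
{
    if(mul_routed_weight)
    {
        gate *= w;
        up *= w;
    }
    const float silu = gate / (1.0f + expf(-gate)); // SiLU(x) = x * sigmoid(x)
    return silu * up;
}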
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                Number<CShuffleMXdlPerWavePerShuffle>{},
                Number<CShuffleNXdlPerWavePerShuffle>{},
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(

        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(

        auto c_thread_copy_vgpr_to_lds =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                               decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                               Sequence<CShuffleMXdlPerWavePerShuffle,
                                                        CShuffleNXdlPerWavePerShuffle,
                                               Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                               m_thread_data_on_block_idx[I1],
                               n_thread_data_on_block_idx[I1],
                               m_thread_data_on_block_idx[I2],
                               m_thread_data_on_block_idx[I3],
                               m_thread_data_on_block_idx[I4],
                               n_thread_data_on_block_idx[I2]),
        using EDataType = CDataType;

            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);

        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
                ds_grid_desc_m_n, problem.MBlock, problem.NBlock);

                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
            Number<NumDTensor>{});

            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                Number<NumDTensor>{}));

            tie(c_shuffle_block_buf),
                { return ds_grid_buf[i]; },
                Number<NumDTensor>{}));

        const auto idx_c_ds_block_begin =
                Number<NumDTensor>{}));

        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            c_grid_desc_mblock_mperblock_nblock_nperblock;

        using CDEBlockTransferCluster =
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
        const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
        constexpr index_t scatter_weight_idx  = 1;
        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
            decltype(c_ds_desc_refs),
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CElementwiseOperation,
            Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>,
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
            CDEBlockTransferCluster,
            Sequence<0, 1, 2, 3>,
            Sequence<0, 1, 2, 3>,
            Sequence<0, 1, 2, 3>,
            CDEShuffleBlockTransferScalarPerVectors,
               idx_c_ds_block_begin,
               tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        constexpr auto sfc_c_vgpr =
            SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                              Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                              Sequence<CShuffleMXdlPerWavePerShuffle,
                                       CShuffleNXdlPerWavePerShuffle,

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        constexpr auto sfc_cde_block =
            SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                              Sequence<0, 2, 1, 3>,
                                       CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                       CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

        constexpr auto EMThreads =
            CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
        constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
        constexpr auto ENThreads =
            CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
        static_for<0, num_access, 1>{}([&](auto access_id) {
            StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;

            auto dstidx = sfc_cde_block.GetIndex(access_id);
                block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
            static_for<0, EMRepeats, 1>{}([&](auto m0) {
                const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                IndexType token_offset    = fused_token & 0xffffff;
                if constexpr(IsInputGemm)
                {
                    token_offset = token_offset * problem.TopK + (fused_token >> 24);
                }
                scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;

            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          c_shuffle_block_buf);

            cde_block_copy_lds_and_global.Run(
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),

            if constexpr(access_id < num_access - 1)

                constexpr auto cde_lds_and_global_step =
                    sfc_cde_block.GetForwardStep(access_id);

                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_lds_and_global_step);

                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    cde_lds_and_global_step);
    template <bool HasMainKBlockLoop,
                                    const index_t* p_sorted_expert_ids,
                                    const index_t* p_max_token_id,
                                    const ADataType* p_a_grid,
                                    const AScaleDataType* p_a_scale_grid,
                                    const BDataType* p_b_grid,
                                    const BScaleDataType* p_b_scale_grid,
                                    CDataType* p_c_grid,
                                    AElementwiseOperation a_element_op,
                                    BElementwiseOperation b_element_op,
                                    CElementwiseOperation c_element_op)
        const auto b_grid_desc_bpreshuffled =
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(

        const auto Padded_Scale_M =

        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =

        const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)

            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)

                const index_t ecnt_prefix    = p_max_token_id[1 + expert_id];
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);

                return {blockIdx.x, blockIdx.y};

        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;

            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);

        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads;

        if(token_pos >= max_token_id || token0 >= problem.NumTokens)

            const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;

            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
            problem.N * (IsInputGemm ? 2 : 1) *

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());

        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());

            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorDim,
            ABlockTransferSrcScalarPerVector,
            1>(a_grid_desc_ak0_m_ak1,
               a_block_desc_ak0_m_ak1,

        auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);

        auto b_blockwise_copy =
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            b_grid_desc_bpreshuffled,

        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);

        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;

            c_thread_buf.num_of_v_,
            c_thread_buf.s_per_v,

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /

        const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        auto thread_offset_shuffled =

        auto a_thread_offset_m = waveId_m;

        const index_t token_scale_pos = block_m_id * MPerBlock;
        if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)

            decltype(a_scale_grid_desc_am_ak),
            decltype(BlockwiseGemmPipe::a_scale_thread_desc),
            true>(a_scale_grid_desc_am_ak,

        auto b_thread_offset_n = waveId_n;

            decltype(b_scale_grid_desc_bn_ak),
            decltype(BlockwiseGemmPipe::b_scale_thread_desc),
            true>(b_scale_grid_desc_bn_ak,

        if constexpr(IsInputGemm)

            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid_up + expert_id * expert_stride,
                b_grid_desc_bpreshuffled.GetElementSpaceSize());
            auto b_blockwise_copy_up =
                decltype(b_grid_desc_bpreshuffled),
                decltype(b_block_desc_bk0_n_bk1),
                BBlockTransferSrcScalarPerVector,
                BThreadTransferSrcResetCoordinateAfterRun,
                b_grid_desc_bpreshuffled,

            const BScaleDataType* p_b_scale_grid_up =
                p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());

                decltype(b_scale_grid_desc_bn_ak),
                decltype(BlockwiseGemmPipe::b_scale_thread_desc),
                      b_scale_grid_desc_bn_ak,

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                b_blockwise_copy_up,
                b_block_slice_copy_step,
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                b_scale_grid_buf_up,
                num_k_block_main_loop);

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                b_block_slice_copy_step,
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                num_k_block_main_loop);
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,

        static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
                          CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);

        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
        constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
        constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);

        static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
        static_assert(M5 == 4);
                        const index_t m_pos = block_m_id * MPerBlock +
                                              m0 * M2 * M1 * M3 * M4 * M5 +
                                              m1 * M2 * M3 * M4 * M5 +
                                              imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
                        if constexpr(MulRoutedWeight)
                                *c_style_pointer_cast<const vector_type<float, M5>*>(
                                    p_ds_grid[I2] + m_pos);

                                blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                                    make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));

                            if constexpr(IsInputGemm)

                                if constexpr(ActivationOperation ==

                                    float gate = c_thread_buf[cidx];
                                    float up   = c_thread_buf_up[cidx];
                                    if constexpr(MulRoutedWeight)
                                    {
                                        gate = gate * topk_weights.AsType<float>()[m5];
                                        up   = up * topk_weights.AsType<float>()[m5];
                                    }
                                    c_thread_buf_fp32(cidx) = gate * up;

                                    float gate = c_thread_buf[cidx];
                                    float up   = c_thread_buf_up[cidx];
                                    if constexpr(MulRoutedWeight)
                                    {
                                        gate = gate * topk_weights.AsType<float>()[m5];
                                        up   = up * topk_weights.AsType<float>()[m5];
                                    }
                                    c_thread_buf_fp32(cidx) = gate * up;

                                    c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
                                    if constexpr(MulRoutedWeight)
                                    {
                                        c_thread_buf_fp32(cidx) =
                                            topk_weights.AsType<float>()[m5] *
                                            c_thread_buf_fp32[cidx];
                                    }
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared_0),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,

        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(

        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(

            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            CShuffleNXdlPerWavePerShuffle / NXdlPack,
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
            true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                  m_thread_data_on_block_idx[I1],
                  n_thread_data_on_block_idx[I1],
                  m_thread_data_on_block_idx[I2],
                  n_thread_data_on_block_idx[I2],
                  m_thread_data_on_block_idx[I3],
                  m_thread_data_on_block_idx[I4],
                  m_thread_data_on_block_idx[I5],
                  n_thread_data_on_block_idx[I3]),
        using EDataType = CDataType;

        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =

                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());

            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },

            tie(c_shuffle_block_buf),
                { return ds_grid_buf[i]; },

        const auto idx_c_ds_block_begin =

        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            c_grid_desc_mblock_mperblock_nblock_nperblock;

        using CDEBlockTransferCluster =
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
        const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
        constexpr index_t scatter_weight_idx  = 3;

            decltype(c_ds_desc_refs),
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CElementwiseOperation,
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
            CDEBlockTransferCluster,
            CDEShuffleBlockTransferScalarPerVectors,
               idx_c_ds_block_begin,
               tie(e_grid_desc_mblock_mperblock_nblock_nperblock),

        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        constexpr auto sfc_c_vgpr =
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
            CShuffleNXdlPerWavePerShuffle / NXdlPack,

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        constexpr auto sfc_cde_block =
            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
            CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
        constexpr auto EMThreads =
            CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
        constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
        constexpr auto ENThreads =
            CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);

            auto dstidx = sfc_cde_block.GetIndex(access_id);
                block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
                const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                IndexType token_offset    = fused_token & 0xffffff;
                if constexpr(IsInputGemm)
                {
                    token_offset = token_offset * problem.TopK + (fused_token >> 24);
                }
                scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;

            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          c_shuffle_block_buf);

            cde_block_copy_lds_and_global.Run(
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),

            if constexpr(access_id < num_access - 1)

                constexpr auto cde_lds_and_global_step =
                    sfc_cde_block.GetForwardStep(access_id);

                cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                    c_ds_desc_refs, i + I1, cde_lds_and_global_step);

                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    cde_lds_and_global_step);