// MoeFlatmmHostArgs: host-side argument pack for the MoE flat GEMM kernel (excerpt).
template <class ScaleM     = FlatmmScalePointer<-1>,
          class ScaleN     = FlatmmScalePointer<-1>,
          class ExpertBias = FlatmmScalePointer<-1>>
// ...
        const void* p_sorted_expert_weights_,
        // ...
        ExpertBias exp_bias_ = {})
        // ...
          p_sorted_expert_weights_,
// ...
        const void* p_sorted_expert_weights_,
        // ...
        ScaleN scale_n_      = {},
        ExpertBias exp_bias_ = {})
        : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>(a_ptr_,
// Fused-activation functors used by the epilogue (excerpt).
    template <typename T>
    // ...
        return gate * linear;
// ...
    Swiglu(float alpha_ = 1.702f, float limit_ = 7.0f)
    template <typename T>
    // ...
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                          std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                          std::is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        // ...
        constexpr T one = type_convert<T>(1);
        // ...
        if constexpr(std::is_same_v<T, float>)
            // ...
            return gate * __builtin_amdgcn_rcpf(one + ck_tile::exp(alpha * -gate)) * (linear + 1);
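
// The float branch above computes a SwiGLU-style gating: the gate is passed through a
// sigmoid scaled by alpha (the device code uses the fast reciprocal
// __builtin_amdgcn_rcpf) and multiplied by (linear + 1). A minimal standalone scalar
// sketch of that math, with illustrative names rather than the kernel's API:
#include <cmath>
#include <cstdio>

static float swiglu_ref(float gate, float linear, float alpha = 1.702f)
{
    const float sig = 1.0f / (1.0f + std::exp(-alpha * gate)); // 1 / (1 + exp(alpha * -gate))
    return gate * sig * (linear + 1.0f);
}

int main()
{
    std::printf("%f\n", swiglu_ref(0.5f, 2.0f)); // example gate/linear pair
    return 0;
}
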
// MoeFlatmmKernel: the MoE flat GEMM kernel itself (excerpt).
template <typename TilePartitioner_,
          typename FlatmmPipeline_,
          typename EpiloguePipeline_,
          // ...
          typename FusedActivation = moe::MoeSilu>
// ...
    static_assert(DsLayout::size() == DsDataType::size(),
                  "The size of DsLayout and DsDataType should be the same");
    // ...
    // OutputNPerBlock:
        IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
    // AQUANT_Pipeline:
        std::is_same_v<ADataType, fp8_t> ||
        std::is_same_v<ADataType, pk_fp4_t>;
    // ...
    // GetName():
        '_', "moe_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
    // ...
    // GridSize(M, N, KBatch):
        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
    // GridSize(kargs): occupancy-based grid sizing for the persistent kernel path (excerpt).
    template <class MoeFlatmmKernelArgs>
    static constexpr auto GridSize(const MoeFlatmmKernelArgs& kargs)
    // ...
        hipDeviceProp_t prop;
        // ...
        int dync_smem_size       = 0;
        int maxActiveBlocksPerCU = 0;
        // ...
        [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
        // ...
        e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &maxActiveBlocksPerCU,
            reinterpret_cast<void*>(kentry<1, MoeFlatmmKernel, MoeFlatmmKernelArgs>),
            // ...
        const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
        const int total_work_tile_cnt   = TilePartitioner::GridSize(kargs.M, kargs.N);
        // ...
        return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
        // ...
        return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
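
// The persistent path above sizes the grid from occupancy: resident blocks per CU
// times the CU count, capped by the number of work tiles. A minimal standalone HIP
// host sketch of that calculation (dummy_kernel, BLOCK_SIZE and total_tiles are
// illustrative stand-ins, not the kernel's real entry point or tile count):
#include <hip/hip_runtime.h>
#include <algorithm>
#include <cstdio>

__global__ void dummy_kernel() {}

int main()
{
    hipDeviceProp_t prop{};
    (void)hipGetDeviceProperties(&prop, 0);

    constexpr int BLOCK_SIZE = 256;
    int max_blocks_per_cu    = 0;
    (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_blocks_per_cu, reinterpret_cast<void*>(dummy_kernel), BLOCK_SIZE, /*dynSmem=*/0);

    const int persistent_blocks = prop.multiProcessorCount * max_blocks_per_cu;
    const int total_tiles       = 1024; // stand-in for TilePartitioner::GridSize(M, N)
    std::printf("grid.x = %d\n", std::min(persistent_blocks, total_tiles));
    return 0;
}
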
    // GetSmemPingSize() / GetSmemPongSize():
        return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    // ...
        return FlatmmPipeline::GetSmemSize();
    // SplitKBatchOffset: per-split K offsets and length for split-K (excerpt).
    template <class KernelArgs>
    __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
    // ...
        constexpr auto K1   = BlockGemmShape::WarpTile::at(number<2>{});
        const index_t K_t   = kargs.k_batch * K1;
        const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
        // ...
        if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            // ...
        else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            // ...
        if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
            // ...
        splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
        // ...
        if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            // ...
        else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
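
// The constructor above rounds K up to the K-per-unroll granularity K1, gives each
// of the k_batch splits KRead elements of K, and lets the last split absorb the
// remainder. A minimal standalone sketch of that arithmetic (illustrative names;
// the per-split offset shown here is an assumption of the usual layout):
#include <cstdio>

int main()
{
    const int K = 1000, k_batch = 3, K1 = 64; // K1 ~ BlockGemmShape::WarpTile::at(number<2>{})
    const int K_t   = k_batch * K1;
    const int KRead = (K + K_t - 1) / K_t * K1; // per-split K, rounded up to a multiple of K1

    for(int k_id = 0; k_id < k_batch; ++k_id)
    {
        const int k_offset   = k_id * KRead; // assumed split offset along K
        const int splitted_k = (k_id < k_batch - 1) ? KRead : K - KRead * (k_batch - 1);
        std::printf("split %d: offset %d, length %d\n", k_id, k_offset, splitted_k);
    }
    return 0;
}
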
    // IsSupportedArgument(kargs): host-side argument validation (excerpt).
    template <typename KernelArgs>
    // ...
        if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
                     // ...
            if(kargs.k_batch != 1)
                std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
        // ...
            if(kargs.k_batch != 1)
                std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
        // ...
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
                std::cerr << "Can't support K that is not a multiple of KPerBlock" // ...
            // ...
            if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
                std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
        // ...
            if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
                std::cerr << "Can't support M that is not a multiple of MPerBlock" // ...
            // ...
            if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
                std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
        // ...
        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            // ...
            if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
                std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
        // ...
            if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
                std::cerr << "Can't support K that is not a multiple of KPerBlock" // ...
            // ...
            if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
                std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
        // ...
        bool DTesnorIsValid = {true};
        // ...
        if(std::is_same_v<DiLayout, ELayout> == false)
            DTesnorIsValid = false;
        // ...
        if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
            if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
                CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
                              "NPerBlock without padding!");
                DTesnorIsValid = false;
            // ...
            if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
                CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
                DTesnorIsValid = false;
        // ...
            if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
                CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
                              "MPerBlock without padding!");
                // ...
                DTesnorIsValid = false;
            // ...
            if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
                CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
                DTesnorIsValid = false;
        // ...
        if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
            if(kargs.stride_C % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
                std::cerr << "Can't support N that is not a multiple of NPerBlock" // ...
            // ...
            if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
                std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
        // ...
            if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
                std::cerr << "Can't support M that is not a multiple of MPerBlock" // ...
            // ...
            if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
                std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
        // ...
        return DTesnorIsValid;
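
// The checks above repeat one pattern per dimension and tensor: the extent must be
// a multiple of the per-block tile size (unless padding is enabled for that
// dimension) and a multiple of the vectorized access width. A minimal standalone
// sketch of that predicate (illustrative names, not the kernel's API):
#include <cstdio>

static bool dim_is_supported(int extent, int per_block, int vector_size, bool pad_enabled)
{
    const bool block_ok  = (extent % per_block == 0) || pad_enabled;
    const bool vector_ok = (extent % vector_size == 0);
    return block_ok && vector_ok;
}

int main()
{
    // e.g. K = 4096 against KPerBlock = 128 and an 8-wide vector load, no padding:
    std::printf("%s\n", dim_is_supported(4096, 128, 8, false) ? "supported" : "unsupported");
    return 0;
}
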
    // ...
            ? memory_operation_enum::set
    // ...

    // MakeGemmTensorViews(...): global-memory views for A, flat B, C/E and the scale tensors (excerpt).
        [[maybe_unused]] const AccDataType* exp_weight_ptr,
        [[maybe_unused]] const int expert_id,
        const KernelArgs& kargs,
    // ...
        const auto& a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
                // ...
                return make_naive_tensor_view<address_space_enum::global>(
                    // ...
                    number<FlatmmPipeline::GetVectorSizeA()>{},
                    // ...
                return make_naive_tensor_view<address_space_enum::global>(
                    // ...
                    IsInputGemm ? kargs.NumTokens : kargs.NumTokens * kargs.TopK),
                    // ...
                    number<FlatmmPipeline::GetVectorSizeA()>{},
        // ...
        const auto& b_flat_tensor_view = [&]() {
            if constexpr(!FlatmmPipeline::BPreShufflePermute)
                // ...
                    kargs.K * BlockGemmShape::WarpTile::at(I1);
                index_t kFlatN = kargs.N * kargs.K / kFlatK;
                // ...
                    memory_operation_enum::set,
                    FlatmmPipeline::BMemNTType>(
                    // ...
                    number<FlatmmPipeline::GetVectorSizeB()>{},
            // ...
                index_t kFlatK  = FlatmmPipeline::flatKPerWarp;
                index_t kFlatN0 = (kargs.N >> 4);
                index_t kFlatK0 = (kargs.K >> 7);
                // ...
                    memory_operation_enum::set,
                    FlatmmPipeline::BMemNTType>(
                    // ...
                    number<FlatmmPipeline::GetVectorSizeB()>{},
        // ...
        const auto& c_tensor_view = [&]() {
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
                // ...
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    // ...
                    number<EpiloguePipeline::GetVectorSizeC()>{},
                    // ...
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
        // ...
        const auto& scale_a_tensor_view = [&]() {
            auto scale_m_desc = kargs.scale_m;
            // ...
            constexpr int AGranularityK = decltype(scale_m_desc)::GranularityK == 0
                                              // ...
                                              : decltype(scale_m_desc)::GranularityK;
            // ...
            constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
            constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
            // ...
                make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl));
            // ...
            return make_tensor_view<address_space_enum::global>(
                reinterpret_cast<const int32_t*>(scale_m_desc.ptr), scale_a_desc);
            // ...
            constexpr int AGranularityK = 32;
            constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
            constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
            // ...
            return make_naive_tensor_view<address_space_enum::global>(
                reinterpret_cast<const int32_t*>(scale_m_desc.ptr),
                make_tuple(scale_m_packs * MThreadPerXdl, scale_k_packs * KThreadPerXdl),
        // ...
        const auto scale_b_flat_view = [&]() {
            auto scale_n = kargs.scale_n;
            constexpr int BGranularityK =
                decltype(scale_n)::GranularityK == 0 ? 1 : decltype(scale_n)::GranularityK;
            // ...
                BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK;
            constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(I1);
            constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I1);
            // ...
                make_tuple(scale_n_packs, scale_k_packs, KThreadPerXdl, NThreadPerXdl));
            // ...
            return make_tensor_view<address_space_enum::global>(
                reinterpret_cast<const int32_t*>(scale_n.ptr) +
                    expert_id * kargs.N * scale_k / 4,
            // ...
                BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK;
            const auto scale_k_offset =
                // ...
            index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
            index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
            // ...
            return make_naive_tensor_view<address_space_enum::global>(
                scale_n.ptr + expert_id * kargs.N * scale_k + scale_k_offset,
    // MakeGemmPadViews(views): pad the raw views to tile boundaries (excerpt).
    template <typename TensorView>
    // ...
        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        // ...
        const auto& c_pad_view = [&]() {
            const auto& c_tensor_view = views.at(I2);
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        // ...
        return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3), views.at(I4));
    // MakeGemmTileWindows(views, coord_m, coord_n): per-block tile windows (excerpt).
    template <typename PadView>
    // ...
        [[maybe_unused]] const index_t coord_m,
    // ...
        const auto& a_pad_view      = views.at(number<0>{});
        const auto& b_flat_pad_view = views.at(number<1>{});
        const auto& c_pad_view      = views.at(number<2>{});
        // ...
        const auto& a_block_window = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        // ...
        const auto& b_flat_block_window =
            // ...
            {static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
                              (isNonInterleaveGateUp ? 1 : 2)),
        // ...
        const int output_N_offset = IsGateUp ? coord_n / 2 : coord_n;
        // ...
        constexpr int GranularityK = 32;
        // ...
            number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
        // ...
        constexpr int XDLPerLoadScaleB =
            // ...
        auto b_scale_block_window = [&]() {
            // ...
                number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
            // ...
                        XDLPerLoadScaleB / GranularityK>{}),
                {coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
        // ...
            a_scale_block_window,
            b_scale_block_window);
    // operator()(kargs): persistent-kernel entry that loops over work tiles (excerpt).
    template <class MoeFlatmmKernelArgs>
    // ...
        int partition_idx       = blockIdx.x;
        int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
        // ...
            const auto [block_offset_m, block_offset_n] =
                // ...
            this->operator()(kargs, block_offset_m, block_offset_n);
            partition_idx += gridDim.x;

    // operator()(kargs, iM, iN): process one (iM, iN) work tile (excerpt).
    template <class MoeFlatmmKernelArgs>
    // ...
        const index_t coord_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t coord_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
        // ...
        constexpr auto a_dram_dist = FlatmmPipeline::GetADramTileDistribution();
        const auto a_coord         = a_dram_dist.calculate_index();
        // ...
        // Sorted token ids pack the token index in the low 24 bits and the top-k slot
        // in the high bits; the lambda below decodes one row id into a gather row.
        constexpr index_t token_id_offset = 24;
        constexpr index_t token_id_mask   = (1 << token_id_offset) - 1;
        // ...
        auto row_to_token_idx = [&](auto row_idx) {
            // ...
            index_t gather_token_id = fused_token & token_id_mask;
            // ...
            gather_token_id = gather_token_id * kargs.TopK + (fused_token >> token_id_offset);
            // ...
            return gather_token_id;
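
// A minimal standalone sketch of the decode in the lambda above (illustrative
// names): with the token id in the low 24 bits and the top-k slot in the upper
// bits, the gathered row becomes token_id * TopK + slot in the branch shown.
#include <cstdint>
#include <cstdio>

int main()
{
    const std::int32_t token_id_offset = 24;
    const std::int32_t token_id_mask   = (1 << token_id_offset) - 1;

    const std::int32_t TopK        = 2;
    const std::int32_t fused_token = (1 << token_id_offset) | 57; // slot 1 of token 57

    std::int32_t gather_token_id = fused_token & token_id_mask;
    gather_token_id = gather_token_id * TopK + (fused_token >> token_id_offset);
    std::printf("gather row = %d\n", gather_token_id); // 57 * 2 + 1 = 115
    return 0;
}
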
        if(coord_m >= max_token_id)
            // ...
        // Per-row gather offsets into A, derived from the sorted token ids.
            // ...
                coord_m + m0 * (TilePartitioner::MPerBlock / DramMRepeat) + a_coord[I0];
            index_t gather_token_id = row_to_token_idx(row_idx);
            a_offsets[m0] = std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
            // ...
        // Build tensor views, pad views and tile windows, then run the GEMM pipeline.
            a_ptr, b_flat_ptr, e_ptr, exp_weight_ptr, expert_id, kargs, splitk_batch_offset);
        // ...
        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
        // ...
        const auto& a_block_window       = gemm_tile_windows.at(I0);
        const auto& b_block_window       = gemm_tile_windows.at(I1);
        const auto& a_scale_block_window = gemm_tile_windows.at(I3);
        const auto& b_scale_block_window = gemm_tile_windows.at(I4);
        // ...
        auto a_gather_block_tile =
            // ...
            a_block_window.get_window_lengths(),
            a_block_window.get_window_origin(),
            // ...
        auto c_block_tile = [&] {
            // ...
            a_gather_block_tile,
            // ...
            a_scale_block_window,
            b_scale_block_window,
            // ...
            a_gather_block_tile,
            // ...
            b_scale_block_window,
        // ...
        auto& c_block_window = gemm_tile_windows.at(number<2>{});
        // CShuffle epilogue: stage accumulators through LDS, apply scales / bias /
        // activation, then scatter-store rows to the output (excerpt).
        using EpiProblem = typename EpiloguePipeline::Problem;
        using ODataType  = typename EpiloguePipeline::ODataType;
        using CWarpDstr  = typename EpiloguePipeline::CWarpDstr;
        // ...
        constexpr index_t NumMXdlPerWavePerShuffle = EpiloguePipeline::NumMXdlPerWavePerShuffle;
        constexpr index_t NumNXdlPerWavePerShuffle = EpiloguePipeline::NumNXdlPerWavePerShuffle;
        constexpr index_t MPerIterationShuffle     = EpiloguePipeline::MPerIterationShuffle;
        constexpr index_t NPerIterationShuffle     = EpiloguePipeline::NPerIterationShuffle;
        // ...
        constexpr index_t MRepeat = EpiloguePipeline::MRepeat;
        constexpr index_t NRepeat = EpiloguePipeline::NRepeat;
        // ...
        [[maybe_unused]] constexpr index_t EpiVectorSizeC = EpiloguePipeline::GetVectorSizeC();
        [[maybe_unused]] constexpr index_t BlockedXDLN_PerWarp =
            EpiloguePipeline::BlockedXDLN_PerWarp;
        // ...
        static_assert(!IsGateUp || NumNXdlPerWavePerShuffle % 2 == 0);
        // ...
        constexpr index_t OutputNumNXdlPerWavePerShuffle =
            IsGateUp ? NumNXdlPerWavePerShuffle / 2 : NumNXdlPerWavePerShuffle;
        constexpr index_t LDS_NPerIterationShuffle =
            IsGateUp ? NPerIterationShuffle / 2 : NPerIterationShuffle;
        // ...
        auto o_lds_block = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<ODataType*>(smem_ptr_ping), lds_block_desc);
        // ...
        constexpr int ScaleGranularityM = decltype(kargs.scale_m)::GranularityMN;
        constexpr int ScaleGranularityN = decltype(kargs.scale_n)::GranularityMN;
        // ...
        constexpr index_t scale_stride_m = ScaleGranularityM == 0 ? 0
            // ...
        constexpr index_t scale_stride_n = ScaleGranularityN == 0 ? 0
            // ...
        auto output_acc_tile_distr =
            // ...
                typename CWarpDstr::DstrEncode{}));
        // ...
        const auto scale_m_coord =
            output_acc_tile_distr.calculate_index();
        // ...
        constexpr index_t ScaleMRepeat = MRepeat * kM0 * kM2;
        // ...
        const auto row_idx =
            coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0];
        // ...
            row_to_token_idx(row_idx);
        // ...
        constexpr int DynamicTileOffsetFlag = 0;
        // ...
        constexpr bool EnableBias = decltype(kargs.exp_bias)::GranularityMN != -1;
        // ...
        auto permute_tensor_view = [&](auto naive_view, auto is_needed_to_permute_N_PACK) {
            if constexpr(!is_needed_to_permute_N_PACK)
        // Tile windows over the row scales, the column scales (gate and up halves),
        // the expert bias and the expert weights, mapped onto the accumulator distribution.
        auto scale_m_window =
            // ...
            output_acc_tile_distr,
            // ...
        // scale_n window (gate half):
            make_naive_tensor_view<address_space_enum::global>(
                kargs.scale_n.ptr + expert_id * kargs.N,
                // ...
                number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
                // ...
                       : TilePartitioner::NPerBlock > {}),
            {0, IsGateUp ? coord_n / 2 : coord_n},
            output_acc_tile_distr);
        // scale_n window (up half):
            make_naive_tensor_view<address_space_enum::global>(
                kargs.scale_n.ptr + expert_id * kargs.N + kargs.N / 2,
                // ...
                number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
                // ...
                number<TilePartitioner::NPerBlock / 2>{}),
            // ...
            output_acc_tile_distr);
        // ...
        auto exp_bias_view = make_naive_tensor_view<address_space_enum::global>(
            kargs.exp_bias.ptr + expert_id * kargs.N,
            // ...
            number<FlatmmPipeline::GetVectorSizeB()>{},
            // ...
                   : TilePartitioner::NPerBlock > {}),
            {0, IsGateUp ? coord_n / 2 : coord_n},
            output_acc_tile_distr);
        // ...
        auto exp_bias_up_window =
            // ...
            kargs.exp_bias.ptr + expert_id * kargs.N + kargs.N / 2,
            // ...
            number<FlatmmPipeline::GetVectorSizeB()>{},
            // ...
            number<TilePartitioner::NPerBlock / 2>{}),
            // ...
            output_acc_tile_distr);
        // ...
        auto exp_weight_window =
            // ...
            number<FlatmmPipeline::GetVectorSizeA()>{},
            // ...
            output_acc_tile_distr);
        // ...
        using ScaleMBuffer    = decltype(load_tile(scale_m_window));
        using ScaleNBuffer    = decltype(load_tile(scale_n_window));
        using ExpBiasBuffer   = decltype(load_tile(exp_bias_window));
        using ExpWeightBuffer = decltype(load_tile(exp_weight_window));
        // ...
        ScaleMBuffer    scale_m_buffer;
        ScaleNBuffer    scale_n_buffer, scale_n_up_buffer;
        ExpBiasBuffer   exp_bias_buffer, exp_bias_up_buffer;
        ExpWeightBuffer exp_weight_buffer;
        // ...
        scale_m_window.load(scale_m_buffer);
        scale_n_buffer = load_tile(scale_n_window);
        // ...
        scale_n_up_buffer = load_tile(scale_n_up_window);
        // ...
        if constexpr(EnableBias)
            // ...
            exp_bias_buffer = load_tile(exp_bias_window);
            // ...
            exp_bias_up_buffer = load_tile(exp_bias_up_window);
        // ...
        exp_weight_buffer = load_tile(exp_weight_window);
        constexpr index_t num_access = SFC::get_num_of_access();
        // ...
        static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
                      "Currently, the CShuffle EpiloguePipeline only supports the Row Major "
                      // ...
        // ...
            MPerIterationShuffle,
            LDS_NPerIterationShuffle,
            // ...
            EpiProblem::kNumWaveGroups>;
        // ...
        constexpr auto dram_tile_distribution =
            TileEncodingPattern::make_2d_static_tile_distribution();
        // ...
        constexpr auto LdsTileDistr = [&] {
            // ...
                typename CWarpDstr::DstrEncode{}));
            // ...
                EpiloguePipeline::MakeLdsDistributionEncode());
        // ...
        using LDSTileTensor =
            decltype(make_static_distributed_tensor<AccDataType>(LdsTileDistr));
        LDSTileTensor lds_tile[2]; // ping-pong LDS staging tiles
        // ...
        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        // ...
        constexpr int ActVectorSize = c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle *
                                      OutputNumNXdlPerWavePerShuffle;
        // ...
        // Slice one epilogue sub-tile out of an accumulator-shaped tensor.
        auto epi_tile_idx_slice =
            [&](const auto& acc_tile_like_tensor, auto epi_m_idx, auto epi_n_idx) {
                return acc_tile_like_tensor.get_y_sliced_thread_data(
                    // ...
                              epi_n_idx * OutputNumNXdlPerWavePerShuffle>{},
                              c_warp_y_index_zeros),
                    // ...
        // ...
        // De-interleave one epilogue sub-tile into its gate and up halves
        // (gate-up path: even / odd N-XDL slices).
        auto gate_up_epi_tile_idx_interleave_slice = [&](auto& dest_gate_tensor,
                                                         auto& dest_up_tensor,
                                                         const auto& acc_tile_like_tensor,
                                                         // ...
            dest_gate_tensor.set_y_sliced_thread_data(
                // ...
                acc_tile_like_tensor.get_y_sliced_thread_data(
                    // ...
                    sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
                             epi_n_idx * NumNXdlPerWavePerShuffle + 2 * n_xdl>{},
                             c_warp_y_index_zeros),
                    // ...
                    c_warp_y_lengths)));
            dest_up_tensor.set_y_sliced_thread_data(
                // ...
                acc_tile_like_tensor.get_y_sliced_thread_data(
                    // ...
                    sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
                             epi_n_idx * NumNXdlPerWavePerShuffle + 2 * n_xdl + 1>{},
                             c_warp_y_index_zeros),
                    // ...
                    c_warp_y_lengths)));
        // Per-epilogue-tile processing: scale, add bias, apply the fused activation,
        // and stage the result into the LDS ping-pong tile.
        auto process_epi_tile = [&](auto lds_stage, auto epi_m, auto epi_n) {
            // ...
                LDSTileTensor gate_tensor, up_tensor;
                // ...
                gate_up_epi_tile_idx_interleave_slice(
                    gate_tensor, up_tensor, c_block_tile, epi_m, epi_n);
                auto epi_scale_m    = epi_tile_idx_slice(scale_m_buffer, epi_m, epi_n);
                auto epi_scale_n    = epi_tile_idx_slice(scale_n_buffer, epi_m, epi_n);
                auto epi_scale_n_up = epi_tile_idx_slice(scale_n_up_buffer, epi_m, epi_n);
                // ...
                auto epi_exp_bias    = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n);
                auto epi_exp_bias_up = epi_tile_idx_slice(exp_bias_up_buffer, epi_m, epi_n);
                // ...
                gate_tensor.get_thread_buffer()[idx] *=
                    epi_scale_m[idx] * epi_scale_n[idx];
                up_tensor.get_thread_buffer()[idx] *=
                    epi_scale_m[idx] * epi_scale_n_up[idx];
                // ...
                if constexpr(EnableBias)
                    // ...
                    gate_tensor.get_thread_buffer()[idx] += epi_exp_bias[idx];
                    up_tensor.get_thread_buffer()[idx] += epi_exp_bias_up[idx];
                // ...
                lds_tile[lds_stage].get_thread_buffer().at(idx) =
                    ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
                                   up_tensor.get_thread_buffer().at(idx));
            // ...
                lds_tile[lds_stage].get_thread_buffer() =
                    epi_tile_idx_slice(c_block_tile, epi_m, epi_n);
                auto epi_scale_m    = epi_tile_idx_slice(scale_m_buffer, epi_m, epi_n);
                auto epi_scale_n    = epi_tile_idx_slice(scale_n_buffer, epi_m, epi_n);
                auto epi_exp_weight = epi_tile_idx_slice(exp_weight_buffer, epi_m, epi_n);
                auto epi_exp_bias   = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n);
                // ...
                lds_tile[lds_stage].get_thread_buffer()[idx] *=
                    epi_scale_m[idx] * epi_scale_n[idx];
                // ...
                if constexpr(EnableBias)
                    lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx];
                // ...
                lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx];
                // ...
                if constexpr(kind ==
                    // ...
                    lds_tile[lds_stage].get_thread_buffer()[idx] =
                        ActivationOp{}(lds_tile[lds_stage].get_thread_buffer()[idx]);
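
// A minimal scalar sketch of the per-element epilogue math in the lambda above
// (illustrative standalone names, SwiGLU standing in for ActivationOp): the
// gate-up path scales gate and up, optionally adds the expert bias, then applies
// the two-input activation; the plain path scales the accumulator, optionally adds
// bias, and multiplies by the routed expert weight.
#include <cmath>
#include <cstdio>

static float swiglu(float gate, float up, float alpha = 1.702f)
{
    return gate / (1.0f + std::exp(-alpha * gate)) * (up + 1.0f);
}

int main()
{
    // Gate-up path for one output element:
    const float scale_m = 0.5f, scale_n = 0.25f, scale_n_up = 0.125f;
    const float gate = 0.7f * scale_m * scale_n + /*bias_gate=*/0.1f;
    const float up   = 1.3f * scale_m * scale_n_up + /*bias_up=*/-0.2f;
    std::printf("gate-up result = %f\n", swiglu(gate, up));

    // Plain path: scale, add bias, then weight by the expert's routing weight.
    const float acc = (0.9f * scale_m * scale_n + 0.1f) * /*exp_weight=*/0.6f;
    std::printf("weighted result = %f\n", acc);
    return 0;
}
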
        constexpr int NumMEpiTile = MRepeat / NumMXdlPerWavePerShuffle;
        constexpr int MPerThread  = TileEncodingPattern::Y2;
        // ...
        // Per-row scatter offsets and validity masks for the output store,
        // decoded from the sorted token ids.
        auto c_coord = dram_tile_distribution.calculate_index();
        // ...
        auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0;
        // ...
        index_t scatter_token_id    = fused_token & token_id_mask;
        c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
        // ...
            scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
        c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
        // ...
        // Space-filling-curve walk over the epilogue sub-tiles with ping-pong LDS staging.
        constexpr int read_stage  = iAccess % 2;
        constexpr int write_stage = read_stage ^ 1;
        // ...
        constexpr auto mIter =
            number<idx_y_start.at(number<0>{}) / MPerIterationShuffle>{};
        // ...
        const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile[read_stage]);
        // ...
        store_tile(in_lds_window, c_warptile_in_tensor_casted);
        // ...
        if constexpr(iAccess < num_access - 1)
            // ...
            constexpr auto mIter_next =
                // ...
            constexpr auto nIter_next =
                // ...
        // ...
        auto c_scatter_tile_window =
            // ...
                c_block_window.get_window_lengths(),
                c_block_window.get_window_origin(),
                dram_tile_distribution,
                c_scatter_offsets[mIter],
                c_scatter_valids[mIter]);
        // ...
        // Atomic update vs. plain store, depending on the destination memory operation.
        decltype(c_block_window.get_bottom_tensor_view())::DstInMemOp ==
            // ...
            c_scatter_tile_window.update(c_out_tensor);
            // ...
            c_scatter_tile_window.store(c_out_tensor);
        // ...
        if constexpr(iAccess != num_access - 1)
            // ...
            constexpr auto step = SFC::get_forward_step(iAccess);