template <class ScaleM     = FlatmmScalePointer<-1>,
          class ScaleN     = FlatmmScalePointer<-1>,
          class ExpertBias = FlatmmScalePointer<-1>>
const void* p_sorted_expert_weights_,
// ...
ExpertBias exp_bias_ = {})
// ...
p_sorted_expert_weights_,
// ...
const void* p_sorted_expert_weights_,
// ...
ScaleN scale_n_      = {},
ExpertBias exp_bias_ = {})
    : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>(a_ptr_,
template <typename T>
// ...
return gate * linear;
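// A standalone scalar sketch of the Swiglu activation defined just below: its float
// fast path computes gate * 1/(1 + exp(alpha * -gate)) * (linear + 1). The helper here
// is an illustration only (assumed name, plain C++, the hardware reciprocal
// __builtin_amdgcn_rcpf replaced by a division); the `limit` member used for clipping
// is not visible in this excerpt and is therefore omitted.
#include <cmath>

inline float swiglu_reference(float gate, float linear, float alpha = 1.702f)
{
    const float sigmoid = 1.0f / (1.0f + std::exp(-alpha * gate)); // 1 / (1 + exp(alpha * -gate))
    return gate * sigmoid * (linear + 1.0f);
}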
Swiglu(float alpha_ = 1.702f, float limit_ = 7.0f)
// ...
template <typename T>
// ...
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                  std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                  std::is_same_v<T, int32_t>,
              "Data type is not supported by this operation!");

constexpr T one = type_convert<T>(1);
// ...
if constexpr(std::is_same_v<T, float>)
    return gate * __builtin_amdgcn_rcpf(one +
ck_tile::exp(
alpha * -gate)) * (linear + 1);
template <typename TilePartitioner_,
          typename FlatmmPipeline_,
          typename EpiloguePipeline_,
          // ...
          typename FusedActivation = moe::MoeSilu>
static_assert(DsLayout::size() == DsDataType::size(),
              "The size of DsLayout and DsDataType should be the same");
// ...
IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
// ...
'_', "moe_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
// ...
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
template <class MoeFlatmmKernelArgs>
// ...
hipDeviceProp_t prop;
// ...
int dync_smem_size       = 0;
int maxActiveBlocksPerCU = 0;

[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
// ...
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
    &maxActiveBlocksPerCU,
    reinterpret_cast<void*>(kentry<1, MoeFlatmmKernel, MoeFlatmmKernelArgs>),
// ...
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
const int total_work_tile_cnt   = TilePartitioner::GridSize(kargs.M, kargs.N);
// ...
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
// ...
return dim3(TilePartitioner::GridSize(kargs.
M, kargs.
N), 1, kargs.
k_batch);
return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
// ...
return FlatmmPipeline::GetSmemSize();
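// Illustrative numbers for the persistent GridSize overload above (assumed values, not
// taken from a real device): the launch is capped at one full wave of resident blocks,
// CUs * occupancy, and never exceeds the number of output work tiles.
//
//   multiProcessorCount   = 80    (hipGetDeviceProperties)
//   maxActiveBlocksPerCU  = 2     (hipOccupancyMaxActiveBlocksPerMultiprocessor)
//   persistent_block_size = 80 * 2 = 160
//   total_work_tile_cnt   = TilePartitioner::GridSize(M, N) = 96
//   grid                  = dim3(min(160, 96), 1, k_batch) = dim3(96, 1, k_batch)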
template <class KernelArgs>
// ...
constexpr auto K1   = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t   = kargs.k_batch * K1;
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
// ...
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
// ...
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
// ...
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
// ...
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
// ...
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
// ...
    splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
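// Worked split-K example for the offsets above (assumed sizes, plain C++): every
// k-slice except the last reads KRead = ceil(K / (k_batch * K1)) * K1 elements of K,
// and the last slice takes whatever remains.
#include <cstdint>

struct SplitKSizes
{
    int32_t k_read;     // K elements per leading slice
    int32_t last_slice; // K elements in the final slice
};

inline SplitKSizes split_k_sizes(int32_t K, int32_t k_batch, int32_t K1)
{
    const int32_t K_t    = k_batch * K1;
    const int32_t k_read = (K + K_t - 1) / K_t * K1;  // rounded up to a K1 multiple
    return {k_read, K - k_read * (k_batch - 1)};      // e.g. K=1000, k_batch=4, K1=32 -> 256 / 232
}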
template <typename KernelArgs>
// ...
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
// ...
    if(kargs.k_batch != 1)
        std::cerr << "Conditions not met for KBatch > 1!" << std::endl;
// ...
if(kargs.k_batch != 1)
    std::cerr << "Persistent mode doesn't support KBatch > 1!" << std::endl;
// ...
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
// ...
if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
    std::cerr << "Can't support K that is not a multiple of KPerBlock"
// ...
if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
    std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
// ...
if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
    std::cerr << "Can't support M that is not a multiple of MPerBlock"
// ...
if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
    std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
// ...
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
// ...
if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
    std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
// ...
if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
    std::cerr << "Can't support K that is not a multiple of KPerBlock"
// ...
if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
    std::cerr <<
"K is not a multiple of vector load size for B tensor!" << std::endl;
bool DTensorIsValid = {true};
// ...
if(std::is_same_v<DiLayout, ELayout> == false)
    DTensorIsValid = false;
// ...
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
// ...
if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
{
    CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
                  "NPerBlock without padding!");
    DTensorIsValid = false;
}
if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
{
    CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
    DTensorIsValid = false;
}
// ...
if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
{
    CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
                  "MPerBlock without padding!");
    DTensorIsValid = false;
}
if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
{
    CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
    DTensorIsValid = false;
}
// ...
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
// ...
if(kargs.stride_C % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
    std::cerr << "Can't support N that is not a multiple of NPerBlock"
// ...
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
    std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
// ...
if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
    std::cerr << "Can't support M that is not a multiple of MPerBlock"
// ...
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
    std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
// ...
return DTensorIsValid;
template <memory_operation_enum DstInMemOp = IsInputGemm ? memory_operation_enum::set
// ...
                               [[maybe_unused]] const AccDataType* exp_weight_ptr,
// ...
                               const KernelArgs& kargs,
// ...
const auto& a_tensor_view = [&]() {
    if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
// ...
        return make_naive_tensor_view<address_space_enum::global>(
            // ...
            number<FlatmmPipeline::GetVectorSizeA()>{},
// ...
        return make_naive_tensor_view<address_space_enum::global>(
            // ...
            IsInputGemm ? kargs.NumTokens : kargs.NumTokens * kargs.TopK),
            // ...
            number<FlatmmPipeline::GetVectorSizeA()>{},
// ...
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
index_t kFlatN = kargs.N * kargs.K / kFlatK;

const auto& b_flat_tensor_view = [&]() {
    return make_naive_tensor_view<address_space_enum::global>(
        // ...
        number<FlatmmPipeline::GetVectorSizeB()>{},
// ...
const auto& c_tensor_view = [&]() {
    if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
// ...
        return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
            // ...
            number<EpiloguePipeline::GetVectorSizeC()>{},
// ...
        return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
// ...
auto scale_n = kargs.scale_n;
constexpr int GranularityK = decltype(scale_n)::GranularityK;
// ...
index_t scale_k    = GranularityK == 0 ? 1 : (kargs.K + GranularityK - 1) / GranularityK;
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
// ...
using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
// ...
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
    reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
// ...
return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view, scale_b_flat_view);
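// Worked example for the flattened-B dimensions above (assumed sizes): with K = 4096,
// N = 8192 and BlockGemmShape::WarpTile::at(I1) = 16,
//   kFlatK = 4096 * 16           = 65536
//   kFlatN = 8192 * 4096 / 65536 = 512   (i.e. N / WarpTile::at(I1))
// so the pre-shuffled weight is viewed as kFlatN rows of kFlatK contiguous elements.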
template <typename TensorView>
// ...
const auto& a_pad_view = [&]() {
    const auto& a_tensor_view = views.at(I0);
    if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
// ...
const auto& c_pad_view = [&]() {
    const auto& c_tensor_view = views.at(I2);
    if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
// ...
return make_tuple(a_pad_view, views.at(
I1), c_pad_view, views.at(
I3));
template <typename PadView>
// ...
                              [[maybe_unused]] const index_t coord_m,
// ...
const auto& a_pad_view      = views.at(number<0>{});
const auto& b_flat_pad_view = views.at(number<1>{});
const auto& c_pad_view      = views.at(number<2>{});

const auto& a_block_window = [&]() {
    if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
// ...
const auto& b_flat_block_window =
// ...
    {static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
                      (isNonInterleaveGateUp ? 1 : 2)),
// ...
const int output_N_offset = IsGateUp ? coord_n / 2 : coord_n;
// ...
constexpr int GranularityK     = 32;
constexpr int XDLPerLoadScaleB =
// ...
auto scale_block_window =
// ...
                 XDLPerLoadScaleB / GranularityK>{}),
    {coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
// ...
return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window);
template <class MoeFlatmmKernelArgs>
// ...
int partition_idx       = blockIdx.x;
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
// ...
const auto [block_offset_m, block_offset_n] =
// ...
this->operator()(kargs, block_offset_m, block_offset_n);
partition_idx += gridDim.x;
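// Minimal sketch of the persistent work distribution implemented above (HIP device
// code with an assumed helper name; the real mapping from tile index to (m, n) offsets
// comes from TilePartitioner): each resident block walks the list of output tiles with
// a grid-sized stride.
__device__ void persistent_loop_sketch(int total_work_tile_cnt)
{
    for(int partition_idx = blockIdx.x; partition_idx < total_work_tile_cnt;
        partition_idx += gridDim.x)
    {
        // map partition_idx -> (block_offset_m, block_offset_n), then process one tile
    }
}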
template <class MoeFlatmmKernelArgs>
// ...
const index_t coord_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t coord_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
// ...
constexpr auto a_dram_dist = FlatmmPipeline::GetADramTileDistribution();
const auto a_coord         = a_dram_dist.calculate_index();
// ...
constexpr index_t token_id_offset = 24;
constexpr index_t token_id_mask   = (1 << token_id_offset) - 1;

auto row_to_token_idx = [&](auto row_idx) {
// ...
    index_t gather_token_id = fused_token & token_id_mask;
// ...
    gather_token_id = gather_token_id * kargs.TopK + (fused_token >> token_id_offset);
// ...
    return gather_token_id;
// ...
if(coord_m >= max_token_id)
// ...
    coord_m + m0 * (TilePartitioner::MPerBlock / DramMRepeat) + a_coord[I0];
index_t gather_token_id = row_to_token_idx(row_idx);
a_offsets[m0] = std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
// ...
    a_ptr, b_flat_ptr, e_ptr, exp_weight_ptr, expert_id, kargs, splitk_batch_offset);
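// Standalone sketch of the fused token-id decoding used by row_to_token_idx above
// (packing inferred from token_id_offset/token_id_mask; assumed helper name, plain
// C++): each sorted entry stores the token index in its low 24 bits and the top-k slot
// in the high bits, and the gathered activation row is addressed token-major,
// slot-minor.
#include <cstdint>

inline int32_t decode_gather_row(int32_t fused_token, int32_t top_k)
{
    constexpr int32_t token_id_offset = 24;
    constexpr int32_t token_id_mask   = (1 << token_id_offset) - 1;
    const int32_t token_id  = fused_token & token_id_mask;    // low 24 bits
    const int32_t topk_slot = fused_token >> token_id_offset; // high bits
    return token_id * top_k + topk_slot; // e.g. token 26, slot 2, top_k 8 -> row 210
}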
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// ...
const auto& a_block_window     = gemm_tile_windows.at(I0);
const auto& b_block_window     = gemm_tile_windows.at(I1);
const auto& scale_block_window = gemm_tile_windows.at(I3);
// ...
auto a_gather_block_tile =
// ...
    a_block_window.get_window_lengths(),
    a_block_window.get_window_origin(),
// ...
auto c_block_tile = [&] {
// ...
auto& c_block_window = gemm_tile_windows.at(
number<2>{});
using EpiProblem = typename EpiloguePipeline::Problem;
using ODataType  = typename EpiloguePipeline::ODataType;
using CWarpDstr  = typename EpiloguePipeline::CWarpDstr;

constexpr index_t NumMXdlPerWavePerShuffle = EpiloguePipeline::NumMXdlPerWavePerShuffle;
constexpr index_t NumNXdlPerWavePerShuffle = EpiloguePipeline::NumNXdlPerWavePerShuffle;
constexpr index_t MPerIterationShuffle     = EpiloguePipeline::MPerIterationShuffle;
constexpr index_t NPerIterationShuffle     = EpiloguePipeline::NPerIterationShuffle;

constexpr index_t MRepeat = EpiloguePipeline::MRepeat;
constexpr index_t NRepeat = EpiloguePipeline::NRepeat;

[[maybe_unused]] constexpr index_t EpiVectorSizeC = EpiloguePipeline::GetVectorSizeC();
[[maybe_unused]] constexpr index_t BlockedXDLN_PerWarp =
    EpiloguePipeline::BlockedXDLN_PerWarp;

static_assert(!IsGateUp || NumNXdlPerWavePerShuffle % 2 == 0);

constexpr index_t OutputNumNXdlPerWavePerShuffle =
    IsGateUp ? NumNXdlPerWavePerShuffle / 2 : NumNXdlPerWavePerShuffle;
constexpr index_t LDS_NPerIterationShuffle =
    IsGateUp ? NPerIterationShuffle / 2 : NPerIterationShuffle;
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
    reinterpret_cast<ODataType*>(smem_ptr_ping), lds_block_desc);

constexpr int ScaleGranularityM = decltype(kargs.scale_m)::GranularityMN;
constexpr int ScaleGranularityN = decltype(kargs.scale_n)::GranularityMN;

constexpr index_t scale_stride_m = ScaleGranularityM == 0 ? 0
// ...
constexpr index_t scale_stride_n = ScaleGranularityN == 0 ? 0
// ...
auto output_acc_tile_distr =
// ...
        typename CWarpDstr::DstrEncode{}));

const auto scale_m_coord = output_acc_tile_distr.calculate_index();
// ...
constexpr index_t ScaleMRepeat = MRepeat * kM0 * kM2;
// ...
    coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0];
// ...
row_to_token_idx(row_idx);
// ...
constexpr int DynamicTileOffsetFlag = 0;

constexpr
bool EnableBias = decltype(kargs.
exp_bias)::GranularityMN != -1;
auto permute_tensor_view = [&](auto naive_view, auto is_needed_to_permute_N_PACK) {
    if constexpr(!is_needed_to_permute_N_PACK)
// ...
auto scale_m_window =
// ...
    output_acc_tile_distr,
// ...
    make_naive_tensor_view<address_space_enum::global>(
        kargs.scale_n.ptr + expert_id * kargs.N,
        // ...
        number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
// ...
                 : TilePartitioner::NPerBlock > {}),
    {0, IsGateUp ? coord_n / 2 : coord_n},
    output_acc_tile_distr);
// ...
    make_naive_tensor_view<address_space_enum::global>(
        kargs.scale_n.ptr + expert_id * kargs.N + kargs.N / 2,
        // ...
        number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
// ...
        number<TilePartitioner::NPerBlock / 2>{}),
// ...
    output_acc_tile_distr);

auto exp_bias_view = make_naive_tensor_view<address_space_enum::global>(
    kargs.exp_bias.ptr + expert_id * kargs.N,
    // ...
    number<FlatmmPipeline::GetVectorSizeB()>{},
// ...
             : TilePartitioner::NPerBlock > {}),
    {0, IsGateUp ? coord_n / 2 : coord_n},
    output_acc_tile_distr);

auto exp_bias_up_window =
// ...
    kargs.exp_bias.ptr + expert_id * kargs.N + kargs.N / 2,
    // ...
    number<FlatmmPipeline::GetVectorSizeB()>{},
// ...
    number<TilePartitioner::NPerBlock / 2>{}),
// ...
    output_acc_tile_distr);

auto exp_weight_window =
// ...
    number<FlatmmPipeline::GetVectorSizeA()>{},
// ...
    output_acc_tile_distr);
using ScaleMBuffer    = decltype(load_tile(scale_m_window));
using ScaleNBuffer    = decltype(load_tile(scale_n_window));
using ExpBiasBuffer   = decltype(load_tile(exp_bias_window));
using ExpWeightBuffer = decltype(load_tile(exp_weight_window));

ScaleMBuffer scale_m_buffer;
ScaleNBuffer scale_n_buffer, scale_n_up_buffer;
ExpBiasBuffer exp_bias_buffer, exp_bias_up_buffer;
ExpWeightBuffer exp_weight_buffer;
// ...
scale_m_window.load(scale_m_buffer);
scale_n_buffer = load_tile(scale_n_window);
// ...
scale_n_up_buffer = load_tile(scale_n_up_window);
// ...
if constexpr(EnableBias)
// ...
    exp_bias_buffer = load_tile(exp_bias_window);
// ...
    exp_bias_up_buffer = load_tile(exp_bias_up_window);
// ...
exp_weight_buffer =
load_tile(exp_weight_window);
constexpr index_t num_access = SFC::get_num_of_access();

static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
              "Currently, the CShuffle EpiloguePipeline only supports the Row Major "
// ...
    MPerIterationShuffle,
    LDS_NPerIterationShuffle,
// ...
    EpiProblem::kNumWaveGroups>;

constexpr auto dram_tile_distribution =
    TileEncodingPattern::make_2d_static_tile_distribution();

constexpr auto LdsTileDistr = [&] {
// ...
        typename CWarpDstr::DstrEncode{}));
// ...
    EpiloguePipeline::MakeLdsDistributionEncode());

using LDSTileTensor = decltype(make_static_distributed_tensor<AccDataType>(LdsTileDistr));
LDSTileTensor lds_tile[2];

constexpr auto c_warp_y_lengths =
    to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

constexpr int ActVectorSize = c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle *
                              OutputNumNXdlPerWavePerShuffle;
auto epi_tile_idx_slice =
    [&](const auto& acc_tile_like_tensor, auto epi_m_idx, auto epi_n_idx) {
        return acc_tile_like_tensor.get_y_sliced_thread_data(
            // ...
                     epi_n_idx * OutputNumNXdlPerWavePerShuffle>{},
                 c_warp_y_index_zeros),
// ...
auto gate_up_epi_tile_idx_interleave_slice = [&](auto& dest_gate_tensor,
                                                 auto& dest_up_tensor,
                                                 const auto& acc_tile_like_tensor,
// ...
    dest_gate_tensor.set_y_sliced_thread_data(
        // ...
        acc_tile_like_tensor.get_y_sliced_thread_data(
            sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
                     epi_n_idx * NumNXdlPerWavePerShuffle + 2 * n_xdl>{},
            c_warp_y_index_zeros),
            // ...
            c_warp_y_lengths)));
    dest_up_tensor.set_y_sliced_thread_data(
        // ...
        acc_tile_like_tensor.get_y_sliced_thread_data(
            sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
                     epi_n_idx * NumNXdlPerWavePerShuffle + 2 * n_xdl + 1>{},
            c_warp_y_index_zeros),
            // ...
            c_warp_y_lengths)));
auto process_epi_tile = [&](auto lds_stage, auto epi_m, auto epi_n) {
// ...
        LDSTileTensor gate_tensor, up_tensor;

        gate_up_epi_tile_idx_interleave_slice(
            gate_tensor, up_tensor, c_block_tile, epi_m, epi_n);
        auto epi_scale_m    = epi_tile_idx_slice(scale_m_buffer, epi_m, epi_n);
        auto epi_scale_n    = epi_tile_idx_slice(scale_n_buffer, epi_m, epi_n);
        auto epi_scale_n_up = epi_tile_idx_slice(scale_n_up_buffer, epi_m, epi_n);
// ...
        auto epi_exp_bias    = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n);
        auto epi_exp_bias_up = epi_tile_idx_slice(exp_bias_up_buffer, epi_m, epi_n);
// ...
            gate_tensor.get_thread_buffer()[idx] *= epi_scale_m[idx] * epi_scale_n[idx];
            up_tensor.get_thread_buffer()[idx] *= epi_scale_m[idx] * epi_scale_n_up[idx];
// ...
            if constexpr(EnableBias)
// ...
                gate_tensor.get_thread_buffer()[idx] += epi_exp_bias[idx];
                up_tensor.get_thread_buffer()[idx] += epi_exp_bias_up[idx];
// ...
            lds_tile[lds_stage].get_thread_buffer().at(idx) =
                ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
                               up_tensor.get_thread_buffer().at(idx));
// ...
        lds_tile[lds_stage].get_thread_buffer() = epi_tile_idx_slice(c_block_tile, epi_m, epi_n);
        auto epi_scale_m    = epi_tile_idx_slice(scale_m_buffer, epi_m, epi_n);
        auto epi_scale_n    = epi_tile_idx_slice(scale_n_buffer, epi_m, epi_n);
        auto epi_exp_weight = epi_tile_idx_slice(exp_weight_buffer, epi_m, epi_n);
        auto epi_exp_bias   = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n);
// ...
            lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_scale_m[idx] * epi_scale_n[idx];
            if constexpr(EnableBias)
                lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx];
// ...
            lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx];
// ...
            lds_tile[lds_stage].get_thread_buffer()[idx] =
                ActivationOp{}(lds_tile[lds_stage].get_thread_buffer()[idx]);
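// Scalar sketch of the per-element epilogue arithmetic applied in process_epi_tile
// above (standalone helpers with assumed names, not the ck_tile tile API). In the
// gate/up path the two interleaved halves are scaled, optionally biased, and fused
// through the activation; otherwise the accumulator is scaled, biased, weighted by the
// expert routing weight, and activated.
template <typename Act>
inline float epilogue_gate_up_sketch(float gate, float up, float scale_m, float scale_n_gate,
                                     float scale_n_up, float bias_gate, float bias_up, Act act)
{
    gate = gate * scale_m * scale_n_gate + bias_gate;
    up   = up * scale_m * scale_n_up + bias_up;
    return act(gate, up); // e.g. Swiglu{}(gate, up)
}

template <typename Act>
inline float epilogue_down_sketch(float acc, float scale_m, float scale_n, float bias,
                                  float expert_weight, Act act)
{
    return act((acc * scale_m * scale_n + bias) * expert_weight);
}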
constexpr int NumMEpiTile = MRepeat / NumMXdlPerWavePerShuffle;
constexpr int MPerThread  = TileEncodingPattern::Y2;
// ...
auto c_coord = dram_tile_distribution.calculate_index();
// ...
    auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0;
// ...
    index_t scatter_token_id = fused_token & token_id_mask;
// ...
        scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
    c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
// ...
constexpr int read_stage  = iAccess % 2;
constexpr int write_stage = read_stage ^ 1;
// ...
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / MPerIterationShuffle>{};
// ...
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile[read_stage]);
// ...
store_tile(in_lds_window, c_warptile_in_tensor_casted);
// ...
if constexpr(iAccess < num_access - 1)
// ...
    constexpr auto mIter_next =
// ...
    constexpr auto nIter_next =
// ...
auto c_scatter_tile_window =
// ...
    c_block_window.get_window_lengths(),
    c_block_window.get_window_origin(),
    dram_tile_distribution,
    c_scatter_offsets[mIter]);
// ...
    c_scatter_tile_window.update(c_out_tensor);
// ...
    c_scatter_tile_window.store(c_out_tensor);
// ...
if constexpr(iAccess != num_access - 1)
// ...
    constexpr
auto step = SFC::get_forward_step(iAccess);