moe_flatmm_kernel.hpp

1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
11 #include "ck_tile/host.hpp"
12 
13 // #define disable_tile_gs
14 
15 namespace ck_tile {
16 
17 template <class ScaleM = FlatmmScalePointer<-1>,
18  class ScaleN = FlatmmScalePointer<-1>,
19  class ExpertBias = FlatmmScalePointer<-1>>
20 struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
21 {
22  ck_tile::index_t NumTokens;
23  ck_tile::index_t NumExperts;
24  ck_tile::index_t TopK;
25  const ck_tile::index_t* p_sorted_token_ids;
26  const ck_tile::index_t* p_sorted_expert_ids;
27  const ck_tile::index_t* p_max_token_id;
28  const void* p_sorted_expert_weights;
29  const ck_tile::index_t n_padded_zeros;
30  const ck_tile::index_t k_padded_zeros;
31  ExpertBias exp_bias;
32 
33  CK_TILE_HOST MoeFlatmmHostArgs() noexcept = default;
34 
35  CK_TILE_HOST MoeFlatmmHostArgs(const ck_tile::index_t* p_sorted_token_ids_,
36  const void* p_sorted_expert_weights_,
37  const ck_tile::index_t* p_sorted_expert_ids_,
38  const ck_tile::index_t* p_max_token_id_,
39  const void* a_ptr_,
40  const void* b_ptr_,
41  void* c_ptr_,
42  ck_tile::index_t NumTokens_,
43  ck_tile::index_t NumExperts_,
44  ck_tile::index_t TopK_,
45  ck_tile::index_t k_batch_,
46  ck_tile::index_t M_,
47  ck_tile::index_t N_,
48  ck_tile::index_t K_,
49  ck_tile::index_t stride_A_,
50  ck_tile::index_t stride_B_,
51  ck_tile::index_t stride_C_,
52  ScaleM scale_m_ = {},
53  ScaleN scale_n_ = {},
54  ExpertBias exp_bias_ = {})
55  : MoeFlatmmHostArgs(p_sorted_token_ids_,
56  p_sorted_expert_weights_,
57  p_sorted_expert_ids_,
58  p_max_token_id_,
59  a_ptr_,
60  b_ptr_,
61  c_ptr_,
62  NumTokens_,
63  NumExperts_,
64  TopK_,
65  k_batch_,
66  M_,
67  N_,
68  K_,
69  stride_A_,
70  stride_B_,
71  stride_C_,
72  0, // n_padded_zeros_
73  0, // k_padded_zeros_
74  scale_m_,
75  scale_n_,
76  exp_bias_)
77  {
78  }
79 
80  CK_TILE_HOST MoeFlatmmHostArgs(const ck_tile::index_t* p_sorted_token_ids_,
81  const void* p_sorted_expert_weights_,
82  const ck_tile::index_t* p_sorted_expert_ids_,
83  const ck_tile::index_t* p_max_token_id_,
84  const void* a_ptr_,
85  const void* b_ptr_,
86  void* c_ptr_,
87  ck_tile::index_t NumTokens_,
88  ck_tile::index_t NumExperts_,
89  ck_tile::index_t TopK_,
90  ck_tile::index_t k_batch_,
91  ck_tile::index_t M_,
92  ck_tile::index_t N_,
93  ck_tile::index_t K_,
94  ck_tile::index_t stride_A_,
95  ck_tile::index_t stride_B_,
96  ck_tile::index_t stride_C_,
97  ck_tile::index_t n_padded_zeros_ = 0,
98  ck_tile::index_t k_padded_zeros_ = 0,
99  ScaleM scale_m_ = {},
100  ScaleN scale_n_ = {},
101  ExpertBias exp_bias_ = {})
102  : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>(a_ptr_,
103  b_ptr_,
104  {}, // d_ptr_array
105  c_ptr_,
106  k_batch_,
107  M_,
108  N_,
109  K_,
110  stride_A_,
111  stride_B_,
112  {}, // d_stride_array
113  stride_C_,
114  scale_m_,
115  scale_n_),
116  NumTokens(NumTokens_),
117  NumExperts(NumExperts_),
118  TopK(TopK_),
119  p_sorted_token_ids(p_sorted_token_ids_),
120  p_sorted_expert_ids(p_sorted_expert_ids_),
121  p_max_token_id(p_max_token_id_),
122  p_sorted_expert_weights(p_sorted_expert_weights_),
123  n_padded_zeros(n_padded_zeros_),
124  k_padded_zeros(k_padded_zeros_),
125  exp_bias(exp_bias_)
126  {
127  }
128 };
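// Usage sketch (illustrative only): one way a caller might populate MoeFlatmmHostArgs with the
// 17-argument constructor above. The buffer names below are hypothetical placeholders; the real
// pointers come from the MoE routing/sorting stage and from the expert weight storage.
//
//   ck_tile::MoeFlatmmHostArgs<> host_args(p_sorted_token_ids,      // index_t[], routing output
//                                          p_sorted_expert_weights, // float[], per-token routing weights
//                                          p_sorted_expert_ids,     // index_t[], expert id per M tile
//                                          p_max_token_id,          // index_t[1], padded row count
//                                          p_activations,           // A: tokens x K
//                                          p_expert_weights,        // B: flat (shuffled) expert weights
//                                          p_output,                // C/E: output buffer
//                                          NumTokens, NumExperts, TopK,
//                                          /*k_batch*/ 1,
//                                          M, N, K,
//                                          stride_A, stride_B, stride_C);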
129 
130 enum class MoeFlatmmKind
131 {
133  kFFN_gemm1_gate_up,
134  kFFN_gemm2,
135 };
136 
137 namespace moe {
138 
139 struct MoeSilu
140 {
141  template <typename T>
142  CK_TILE_HOST_DEVICE T operator()(T gate, T linear = 1) const
143  {
144  ck_tile::element_wise::Silu{}(gate, gate);
145  return gate * linear;
146  };
147 };
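// In other words, MoeSilu(gate, linear) = silu(gate) * linear = gate * sigmoid(gate) * linear;
// for example MoeSilu(1.0f, 2.0f) ≈ 1.0 * 0.731 * 2.0 ≈ 1.462 (gate-only callers use linear = 1).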
148 
149 struct Swiglu
150 {
151  const float alpha;
152  const float limit;
153 
154  CK_TILE_HOST_DEVICE
155  Swiglu(float alpha_ = 1.702f, float limit_ = 7.0f) // gpt-oss values used as defaults
156  : alpha(alpha_), limit(limit_)
157  {
158  }
159 
160  template <typename T>
161  CK_TILE_HOST_DEVICE T operator()(T gate, T linear) const
162  {
163  static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
164  std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
165  std::is_same_v<T, int32_t>,
166  "Data type is not supported by this operation!");
167 
168  constexpr T one = type_convert<T>(1);
169 
170  gate = gate < limit ? gate : limit;
171  linear = linear < limit ? (linear > -limit ? linear : -limit) : limit;
172 
173  if constexpr(std::is_same_v<T, float>)
174  {
175  return gate * __builtin_amdgcn_rcpf(one + ck_tile::exp(alpha * -gate)) * (linear + 1);
176  }
177  else
178  {
179  return gate * (one / (one + ck_tile::exp(alpha * -gate))) * (linear + 1);
180  }
181  }
182 };
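// In other words, with gate clamped from above at `limit` and linear clamped to [-limit, limit]:
//   Swiglu(gate, linear) = gate * sigmoid(alpha * gate) * (linear + 1).
// Worked example with the gpt-oss defaults alpha = 1.702, limit = 7:
//   Swiglu(1.0f, 0.0f) ≈ 1.0 * sigmoid(1.702) * 1 ≈ 1.0 * 0.846 ≈ 0.846.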
183 
184 } // namespace moe
185 
186 template <typename TilePartitioner_,
187  typename FlatmmPipeline_,
188  typename EpiloguePipeline_,
189  MoeFlatmmKind kind,
190  typename FusedActivation = moe::MoeSilu>
191 struct MoeFlatmmKernel
192 {
193  using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
194  using FlatmmPipeline   = remove_cvref_t<FlatmmPipeline_>;
196  using BlockGemmShape   = remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>;
197  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
198  using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
199  using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
200  using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
201  using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
202  using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
203  static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
204  static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
205 
206  using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
207  using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
208  // The type below is actually the accumulation data type - the output of the block GEMM.
209  using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
210 
211  using AccDataType = float;
212  using ActivationOp = FusedActivation;
213 
214  static constexpr index_t NumDTensor = DsDataType::size();
215 
216  static constexpr auto I0 = number<0>();
217  static constexpr auto I1 = number<1>();
218  static constexpr auto I2 = number<2>();
219  static constexpr auto I3 = number<3>();
220 
221  static_assert(DsLayout::size() == DsDataType::size(),
222  "The size of DsLayout and DsDataType should be the same");
223 
224  static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2;
225  static constexpr bool IsGateUp = kind == MoeFlatmmKind::kFFN_gemm1_gate_up;
226 
227  // static constexpr index_t kBlockSize = EpiloguePipeline::kBlockSize;
228  static constexpr index_t kMPerBlock = EpiloguePipeline::kMPerBlock;
229  static constexpr index_t kNPerBlock = EpiloguePipeline::kNPerBlock;
230  static constexpr index_t MWave = EpiloguePipeline::MWave;
231  static constexpr index_t NWave = EpiloguePipeline::NWave;
232  static constexpr index_t MPerXdl = EpiloguePipeline::MPerXdl;
233  static constexpr index_t NPerXdl = EpiloguePipeline::NPerXdl;
234  static constexpr index_t KPerXdl = EpiloguePipeline::KPerXdl;
235  static constexpr index_t isCTransposed = EpiloguePipeline::isCTransposed;
236  static constexpr index_t kMPerIteration = MPerXdl * MWave;
237  static constexpr index_t kNPerIteration = NPerXdl * NWave;
239 
240  static constexpr int OutputNPerBlock =
241  IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
242 
243  // The MXFP4 pipeline only has a scale for B, and its GranularityK is 32
244  static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
245  static constexpr int MXFP4N_Pack = 2;
246  static constexpr int MXFP4K_Pack = 2;
247 
248  static constexpr int N_Pack = MXFP4_Pipeline ? MXFP4N_Pack : 1;
249  static constexpr int K_Pack = MXFP4_Pipeline ? MXFP4K_Pack : 1;
250 
252 
253  template <class ScaleM = FlatmmScalePointer<-1>,
254  class ScaleN = FlatmmScalePointer<-1>,
255  class ExpertBias = FlatmmScalePointer<-1>>
257  {
258  const ck_tile::index_t* p_sorted_token_ids;
259  const ck_tile::index_t* p_sorted_expert_ids;
260  const ck_tile::index_t* p_max_token_id;
261  const void* p_sorted_expert_weights;
262  const void* a_ptr;
263  const void* b_ptr;
264  void* e_ptr;
265  ck_tile::index_t NumTokens;
266  ck_tile::index_t TopK;
267  ck_tile::index_t M;
268  ck_tile::index_t N;
269  ck_tile::index_t K;
270  ck_tile::index_t stride_A;
271  ck_tile::index_t stride_B;
272  ck_tile::index_t stride_C;
273  ck_tile::index_t k_batch;
274  ck_tile::index_t n_padded_zeros;
275  ck_tile::index_t k_padded_zeros;
276  ScaleM scale_m;
277  ScaleN scale_n;
278  ExpertBias exp_bias;
279  };
280 
281  template <class ScaleM = FlatmmScalePointer<-1>,
282  class ScaleN = FlatmmScalePointer<-1>,
283  class ExpertBias = FlatmmScalePointer<-1>>
284  CK_TILE_HOST static constexpr auto
285  MakeKernelArgs(const MoeFlatmmHostArgs<ScaleM, ScaleN, ExpertBias>& hostArgs)
286  {
288  hostArgs.p_sorted_expert_ids,
289  hostArgs.p_max_token_id,
290  hostArgs.p_sorted_expert_weights,
291  hostArgs.a_ptr,
292  hostArgs.b_ptr,
293  hostArgs.e_ptr,
294  hostArgs.NumTokens,
295  hostArgs.TopK,
296  hostArgs.M,
297  hostArgs.N,
298  hostArgs.K,
299  hostArgs.stride_A,
300  hostArgs.stride_B,
301  hostArgs.stride_C,
302  hostArgs.k_batch,
303  hostArgs.n_padded_zeros,
304  hostArgs.k_padded_zeros,
305  hostArgs.scale_m,
306  hostArgs.scale_n,
307  hostArgs.exp_bias};
308  }
309 
310  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
311  {
312  return concat(
313  '_', "moe_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
314  }
315 
316  static constexpr auto BlockSize() -> dim3 { return dim3(kBlockSize); }
317 
318  static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
319  {
320  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
321  }
322  template <class MoeFlatmmKernelArgs>
323  static constexpr auto GridSize(const MoeFlatmmKernelArgs& kargs)
324  {
325  if constexpr(UsePersistentKernel)
326  {
327  hipDeviceProp_t prop;
328  int deviceId = 0; // default device
329 
330  constexpr int block_size = MoeFlatmmKernel::BlockSize().x;
331  int dync_smem_size = 0;
332  int maxActiveBlocksPerCU = 0;
333 
334  [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
335 
336  e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
337  &maxActiveBlocksPerCU,
338  reinterpret_cast<void*>(kentry<1, MoeFlatmmKernel, MoeFlatmmKernelArgs>),
339  block_size,
340  dync_smem_size);
341 
342  const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
343  const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
344 
345  // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
346  // << ", persistent_block_size: " << persistent_block_size
347  // << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
348 
349  assert(kargs.k_batch == 1);
350  return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
351  }
352  else
353  {
354  return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
355  }
356  }
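// Illustrative numbers for the persistent path (values are hypothetical): on a device with 80 CUs
// and an occupancy of 2 blocks per CU, persistent_block_size = 160. If the problem decomposes into
// 1000 work tiles, the launch uses min(160, 1000) = 160 blocks and each block strides through tiles
// with partition_idx += gridDim.x inside operator(); with only 96 tiles, just 96 blocks are launched
// so no block sits idle.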
357 
358  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
359  {
360  return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
361  }
362  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize()
363  {
364  return FlatmmPipeline::GetSmemSize();
365  }
366 
367  struct SplitKBatchOffset
368  {
369  template <class KernelArgs>
370  __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
371  {
372  constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
373  const index_t K_t = kargs.k_batch * K1;
374  const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
375 
376  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
377  {
378  a_k_split_offset = k_id * KRead;
379  }
380  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
381  {
382  a_k_split_offset = k_id * KRead * kargs.stride_A;
383  }
384 
385  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
386  {
387  b_k_split_offset = k_id * KRead * kargs.stride_B;
388  }
389  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
390  {
391  b_k_split_offset = k_id * KRead;
392  }
393 
394  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
395  {
396  splitted_k = KRead;
397  }
398  else
399  {
400  splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
401  }
402  }
403 
404  index_t a_k_split_offset;
405  index_t b_k_split_offset;
406  index_t splitted_k;
407  };
408 
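// Worked example (hypothetical sizes): with K = 1000, k_batch = 4 and a warp-tile K of 32,
// K_t = 128 and KRead = ceil(1000 / 128) * 32 = 256, so k_id 0..2 each read splitted_k = 256 and
// the last split reads the remainder 1000 - 3 * 256 = 232. The per-split pointer offset is KRead
// elements along K, multiplied by the leading stride when K is not the contiguous dimension.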
409  template <typename KernelArgs>
410  CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
411  {
412  if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
414  {
415  if(kargs.k_batch != 1)
416  {
417  std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
418  return false;
419  }
420  }
421  if constexpr(UsePersistentKernel)
422  {
423  if(kargs.k_batch != 1)
424  {
425  std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
426  return false;
427  }
428  }
429 
430  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
431  {
432  if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
433  {
434  std::cerr << "Can't support K that is not a multiple of KPerBlock"
435  " without padding!"
436  << std::endl;
437  return false;
438  }
439  if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
440  {
441  std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
442  return false;
443  }
444  }
445  else
446  {
447  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
448  {
449  std::cerr << "Can't support M that is not a multiple of MPerBlock"
450  " without padding!"
451  << std::endl;
452  return false;
453  }
454  if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
455  {
456  std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
457  return false;
458  }
459  }
460 
461  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
462  {
463  // if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
464  // {
465  // std::cerr << "Can't support N that is not a multiple of NPerBlock"
466  // " without padding!"
467  // << std::endl;
468  // return false;
469  // }
470  if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
471  {
472  std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
473  return false;
474  }
475  }
476  else
477  {
478  if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
479  {
480  std::cerr << "Can't support K that is not a multiple of KPerBlock"
481  " without padding!"
482  << std::endl;
483  return false;
484  }
485  if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
486  {
487  std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
488  return false;
489  }
490  }
491 
492  bool DTensorIsValid = true;
493  static_for<0, NumDTensor, 1>{}([&](auto index) {
494  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
495  if(std::is_same_v<DiLayout, ELayout> == false)
496  {
497  DTensorIsValid = false;
498  }
499  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
500  {
501  if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
502  {
503  CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
504  "NPerBlock without padding!");
505  DTensorIsValid = false;
506  }
507  if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
508  {
509  CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
510  DTensorIsValid = false;
511  }
512  }
513  else
514  {
515  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
516  {
517  CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
518  "MPerBlock without padding!");
519 
520  DTensorIsValid = false;
521  }
522  if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
523  {
524  CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
525  DTensorIsValid = false;
526  }
527  }
528  });
529 
530  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
531  {
532  if(kargs.stride_C % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
533  {
534  std::cerr << "Can't support N that is not a multiple of NPerBlock"
535  " without padding!"
536  << std::endl;
537  return false;
538  }
539  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
540  {
541  std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
542  return false;
543  }
544  }
545  else
546  {
547  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
548  {
549  std::cerr << "Can't support M that is not a multiple of MPerBlock"
550  " without padding!"
551  << std::endl;
552  return false;
553  }
554  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
555  {
556  std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
557  return false;
558  }
559  }
560  return DTensorIsValid;
561  }
562 
563  template <memory_operation_enum DstInMemOp = IsInputGemm ? memory_operation_enum::set
564  : memory_operation_enum::atomic_add,
565  typename KernelArgs>
566  CK_TILE_DEVICE static auto
567  MakeGemmTensorViews(const ADataType* a_ptr,
568  const BDataType* b_flat_ptr,
569  EDataType* e_ptr,
570  [[maybe_unused]] const AccDataType* exp_weight_ptr,
571  const int expert_id,
572  const KernelArgs& kargs,
573  const SplitKBatchOffset& splitk_batch_offset)
574  {
575  const auto& a_tensor_view = [&]() {
576  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
577  {
578  return make_naive_tensor_view<address_space_enum::global>(
579  a_ptr,
580  make_tuple(IsInputGemm ? kargs.NumTokens : kargs.NumTokens * kargs.TopK,
581  splitk_batch_offset.splitted_k),
582  make_tuple(kargs.stride_A, 1),
583  number<FlatmmPipeline::GetVectorSizeA()>{},
584  number<1>{});
585  }
586  else
587  {
588  return make_naive_tensor_view<address_space_enum::global>(
589  a_ptr,
590  make_tuple(splitk_batch_offset.splitted_k,
591  IsInputGemm ? kargs.NumTokens : kargs.NumTokens * kargs.TopK),
592  make_tuple(kargs.stride_A, 1),
593  number<FlatmmPipeline::GetVectorSizeA()>{},
594  number<1>{});
595  }
596  }();
597 
598  index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK)
599  index_t kFlatN = kargs.N * kargs.K / kFlatK;
600 
601  const auto& b_flat_tensor_view = [&]() {
602  return make_naive_tensor_view<address_space_enum::global>(
603  b_flat_ptr,
604  make_tuple(kFlatN - kargs.n_padded_zeros / NPerXdl, kFlatK),
605  make_tuple(kFlatK, 1),
606  number<FlatmmPipeline::GetVectorSizeB()>{},
607  number<1>{});
608  }();
609 
610  // TODO: enable vector write for C in ColMajor
611  const auto& c_tensor_view = [&]() {
612  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
613  {
614  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
615  e_ptr,
616  make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens,
617  IsGateUp ? kargs.N / 2 : kargs.N),
618  make_tuple(kargs.stride_C, 1),
619  number<EpiloguePipeline::GetVectorSizeC()>{},
620  number<1>{});
621  }
622  else
623  {
624  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
625  e_ptr,
626  make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens,
627  IsGateUp ? kargs.N / 2 : kargs.N),
628  make_tuple(1, kargs.stride_C),
629  number<1>{},
630  number<1>{});
631  }
632  }();
633 
634  auto scale_n = kargs.scale_n;
635  constexpr int GranularityK = decltype(scale_n)::GranularityK;
636 
637  index_t scale_k = GranularityK == 0 ? 1 : (kargs.K + GranularityK - 1) / GranularityK;
638  index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
639  index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
640 
641  using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
642 
643  const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
644  reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
645  make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
646  make_tuple(FlatScaleK, 1),
647  number<8>{},
648  number<1>{});
649 
650  return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view, scale_b_flat_view);
651  }
652 
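// Shape sketch for the flat B view above (illustrative, assuming a warp-tile N of 16):
// kFlatK = K * 16, so each flat row holds the full K depth of one 16-wide group of output columns,
// and kFlatN = N * K / kFlatK = N / 16 is the number of such groups. For example N = 4096,
// K = 1024 gives a 256 x 16384 flat view per expert.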
653  template <typename TensorView>
654  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
655  {
656  const auto& a_pad_view = [&]() {
657  const auto& a_tensor_view = views.at(I0);
658  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
659  {
660  return pad_tensor_view(a_tensor_view,
664  }
665  else
666  {
667  return pad_tensor_view(a_tensor_view,
671  }
672  }();
673 
674  // TODO: enable vector write for C in ColMajor
675  const auto& c_pad_view = [&]() {
676  const auto& c_tensor_view = views.at(I2);
677  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
678  {
679  return pad_tensor_view(
680  c_tensor_view,
683  }
684  else
685  {
686  return pad_tensor_view(
687  c_tensor_view,
690  }
691  }();
692 
693  return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3));
694  }
695 
696  template <typename PadView>
697  CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
698  [[maybe_unused]] const index_t coord_m,
699  const index_t coord_n)
700  {
701  const auto& a_pad_view = views.at(number<0>{});
702  const auto& b_flat_pad_view = views.at(number<1>{});
703  const auto& c_pad_view = views.at(number<2>{});
704 
705  const auto& a_block_window = [&]() {
706  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
707  {
708  return make_tile_window(a_pad_view,
711  {coord_m, 0}); // NOTE!
712  }
713  else
714  {
715  return make_tile_window(a_pad_view,
718  {0, 0}); // NOTE!
719  }
720  }();
721 
722  constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline;
723 
724  const auto& b_flat_block_window =
725  make_tile_window(b_flat_pad_view,
728  {static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
729  (isNonInterleaveGateUp ? 1 : 2)),
730  0});
731 
732  const int output_N_offset = IsGateUp ? coord_n / 2 : coord_n;
733 
734  auto c_block_window = make_tile_window(
735  c_pad_view,
737  {0, // offset_m is included when constructing the C-scatter-window offsets
738  output_N_offset});
739 
740  constexpr int GranularityK = 32; // fixed config for the MXFP4 pipeline
741  constexpr int XDLPerLoadScaleB =
742  MXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4
743 
744  auto scale_block_window =
745  make_tile_window(views.at(I3),
747  number<FlatmmPipeline::flatKPerWarp * N_Pack * K_Pack *
748  XDLPerLoadScaleB / GranularityK>{}),
749  {coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
750 
751  return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window);
752  }
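// Note on the N coordinates above: the flat B window advances in units of the warp-tile N
// (halved again for the non-MXFP4 interleaved gate-up layout), while the output window uses
// coord_n / 2 for gate-up kernels because the gate and up columns are fused into a single
// activation column by the epilogue.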
753 
754  template <class MoeFlatmmKernelArgs>
755  CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs kargs) const
756  {
757  int partition_idx = blockIdx.x;
758  int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
759  do
760  {
761  const auto [block_offset_m, block_offset_n] =
762  TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
763 
764  this->operator()(kargs, block_offset_m, block_offset_n);
765  partition_idx += gridDim.x;
766  } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
767  }
768 
769  template <class MoeFlatmmKernelArgs>
770  CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs kargs, index_t iM, index_t iN) const
771  {
772 
773  // const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
774  const index_t coord_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
775  const index_t coord_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
776  const index_t max_token_id = kargs.p_max_token_id[0];
777  // allocate LDS
778  __shared__ char smem_ptr_ping[GetSmemPingSize()];
779  __shared__ char smem_ptr_pong[GetSmemPongSize()];
780 
781  const index_t expert_id = kargs.p_sorted_expert_ids[iM];
782 
783  constexpr auto a_dram_dist = FlatmmPipeline::GetADramTileDistribution();
784  const auto a_coord = a_dram_dist.calculate_index(); // 2d thread offset, [i_row, i_col]
785 
786  constexpr ck_tile::index_t DramMRepeat =
787  decltype(a_dram_dist)::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}];
789 
790  constexpr index_t token_id_offset = 24;
791  constexpr index_t token_id_mask = (1 << token_id_offset) - 1;
792 
793  auto row_to_token_idx = [&](auto row_idx) {
794  const index_t fused_token =
795  kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]
796  index_t gather_token_id = fused_token & token_id_mask;
797  if constexpr(!IsInputGemm)
798  {
799  gather_token_id = gather_token_id * kargs.TopK + (fused_token >> token_id_offset);
800  }
801  return gather_token_id;
802  };
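// Worked example of the packing above: fused_token 0x0200002A encodes topk-idx 2 and token 42.
// For the first (input) GEMM the gather row is simply token 42 of the activation matrix; for the
// second GEMM the row is 42 * TopK + 2, i.e. the per-(token, expert) row written by the first
// GEMM's scatter epilogue.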
803 
804  if(coord_m >= max_token_id)
805  return;
806 
807  static_for<0, DramMRepeat, 1>{}([&](auto m0) {
808  const auto row_idx =
809  coord_m + m0 * (TilePartitioner::MPerBlock / DramMRepeat) + a_coord[I0];
810  index_t gather_token_id = row_to_token_idx(row_idx);
811  a_offsets[m0] = std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
812  ? gather_token_id * kargs.stride_A
813  : gather_token_id;
814  });
815 
816  const SplitKBatchOffset splitk_batch_offset(kargs);
817  const long_index_t expert_stride =
818  __builtin_amdgcn_readfirstlane(long_index_t(kargs.N) * kargs.K);
819 
820  const ADataType* a_ptr =
821  static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
822  const BDataType* b_flat_ptr =
823  static_cast<const BDataType*>(kargs.b_ptr) +
824  (splitk_batch_offset.b_k_split_offset + expert_stride * expert_id) / WeightPackedSize;
825  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
826 
827  const AccDataType* exp_weight_ptr =
828  static_cast<const AccDataType*>(kargs.p_sorted_expert_weights);
829 
830  const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(
831  a_ptr, b_flat_ptr, e_ptr, exp_weight_ptr, expert_id, kargs, splitk_batch_offset);
832  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
833 
834  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, coord_m, coord_n);
835 
836  const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
837 
838  // Run GEMM cooperatively by whole workgroup.
839  const auto& a_block_window = gemm_tile_windows.at(I0);
840  const auto& b_block_window = gemm_tile_windows.at(I1);
841  const auto& scale_block_window = gemm_tile_windows.at(I3);
842 
843  auto a_gather_block_tile =
844  ck_tile::make_tile_scatter_gather(a_block_window.get_bottom_tensor_view(),
845  a_block_window.get_window_lengths(),
846  a_block_window.get_window_origin(),
847  a_dram_dist,
848  a_offsets); // row-gather offsets for the A DRAM tile window
849 
850  auto c_block_tile = [&] {
851  if constexpr(MXFP4_Pipeline)
852  {
853  // The MXFP4 pipeline uses a gate-up interleave-16 layout for the weight,
854  // so no extra processing is needed
855  return FlatmmPipeline{}(a_gather_block_tile,
856  b_block_window,
857  scale_block_window, // weight scale with granularityK = 32
858  num_loop,
859  kargs.k_padded_zeros,
860  smem_ptr_ping,
861  smem_ptr_pong);
862  }
863  else
864  {
865  return FlatmmPipeline{}(a_gather_block_tile,
866  b_block_window,
868  num_loop,
869  smem_ptr_ping,
870  smem_ptr_pong);
871  }
872  }();
873 
874  auto& c_block_window = gemm_tile_windows.at(number<2>{});
875 
876  // Run EpiloguePipeline
877  {
878  using EpiProblem = typename EpiloguePipeline::Problem;
879  using ODataType = typename EpiloguePipeline::ODataType;
880  using CWarpDstr = typename EpiloguePipeline::CWarpDstr;
881 
882  constexpr index_t NumMXdlPerWavePerShuffle = EpiloguePipeline::NumMXdlPerWavePerShuffle;
883  constexpr index_t NumNXdlPerWavePerShuffle = EpiloguePipeline::NumNXdlPerWavePerShuffle;
884  constexpr index_t MPerIterationShuffle = EpiloguePipeline::MPerIterationShuffle;
885  constexpr index_t NPerIterationShuffle = EpiloguePipeline::NPerIterationShuffle;
886 
887  constexpr index_t MRepeat = EpiloguePipeline::MRepeat;
888  constexpr index_t NRepeat = EpiloguePipeline::NRepeat;
889  constexpr index_t OutputNRepeat = IsGateUp ? NRepeat / 2 : NRepeat;
890 
891  [[maybe_unused]] constexpr index_t EpiVectorSizeC = EpiloguePipeline::GetVectorSizeC();
892  [[maybe_unused]] constexpr index_t BlockedXDLN_PerWarp =
893  EpiloguePipeline::BlockedXDLN_PerWarp;
894 
895  static_assert(!IsGateUp || NumNXdlPerWavePerShuffle % 2 == 0);
896 
897  constexpr index_t OutputNumNXdlPerWavePerShuffle =
898  IsGateUp ? NumNXdlPerWavePerShuffle / 2 : NumNXdlPerWavePerShuffle;
899  constexpr index_t LDS_NPerIterationShuffle =
900  IsGateUp ? NPerIterationShuffle / 2 : NPerIterationShuffle;
901 
902  constexpr auto lds_block_desc = make_naive_tensor_descriptor(
905 
906  // EpiloguePipeline::template MakeLdsBlockDescriptor<EpiProblem>();
907  auto o_lds_block = make_tensor_view<address_space_enum::lds>(
908  reinterpret_cast<ODataType*>(smem_ptr_ping), lds_block_desc);
909 
910  constexpr int ScaleGranularityM = decltype(kargs.scale_m)::GranularityMN;
911  constexpr int ScaleGranularityN = decltype(kargs.scale_n)::GranularityMN;
912 
913  constexpr index_t scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
914  : 1; // per-token scale
915  constexpr index_t scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
916  : 1; // per-channel scale
917 
918  auto output_acc_tile_distr =
921  sequence<>,
926  sequence<0, 0>>{},
927  typename CWarpDstr::DstrEncode{}));
928 
929  const auto scale_m_coord =
930  output_acc_tile_distr.calculate_index(); // 2d thread offset, [i_row, i_col]
931 
932  constexpr index_t kM2 = 4; // Val-dim
933  constexpr index_t kM1 = get_warp_size() / NPerXdl; // Thr-dim
934  constexpr index_t kM0 = MPerXdl / kM1 / kM2; // Var-dim
935 
936  constexpr index_t ScaleMRepeat = MRepeat * kM0 * kM2;
938 
939  if constexpr(!MXFP4_Pipeline)
940  static_for<0, MRepeat, 1>{}([&](auto mIter) {
941  static_for<0, kM0, 1>{}([&](auto m0) {
942  static_for<0, kM2, 1>{}([&](auto m2) {
943  const auto row_idx =
944  coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0];
945  scale_m_offsets[mIter * number<kM0 * kM2>{} + m0 * number<kM2>{} + m2] =
946  row_to_token_idx(row_idx);
947  });
948  });
949  });
950 
951  constexpr int DynamicTileOffsetFlag = 0;
952 
953  constexpr bool EnableBias = decltype(kargs.exp_bias)::GranularityMN != -1;
954 
955  auto permute_tensor_view = [&](auto naive_view, auto is_needed_to_permute_N_PACK) {
956  if constexpr(!is_needed_to_permute_N_PACK)
957  {
958  return naive_view;
959  }
960  else
961  {
962  auto view1 = transform_tensor_view(
963  naive_view,
964  make_tuple(
967  number<NRepeat / N_Pack>{},
968  number<NWave>{},
969  number<N_Pack>{},
970  number<NPerXdl>{}))),
973  return transform_tensor_view(
974  view1,
978  number<NRepeat / N_Pack>{},
979  number<N_Pack>{},
980  number<NWave>{},
981  number<NPerXdl>{}))),
984  }
985  };
986 
987  auto scale_m_window =
988  make_tile_scatter_gather(make_naive_tensor_view<address_space_enum::global>(
989  kargs.scale_m.ptr,
990  make_tuple(kargs.M, 1),
991  make_tuple(scale_stride_m, 0),
992  number<1>{}, // gather load can't vectorize
993  number<1>{}),
996  {0, 0}, // offset m is included in gather offsets
997  output_acc_tile_distr,
998  scale_m_offsets);
999 
1000  auto scale_n_window = make_tile_window(
1001  make_naive_tensor_view<address_space_enum::global>(
1002  kargs.scale_n.ptr + expert_id * kargs.N,
1003  make_tuple(1, kargs.N),
1004  make_tuple(0, scale_stride_n),
1005  number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
1006  number<1>{}), // The MXFP4 pipeline doesn't use scale_n, so there is no need to
1007  // permute it as N_Pack
1009  number < IsGateUp ? TilePartitioner::NPerBlock / 2
1010  : TilePartitioner::NPerBlock > {}),
1011  {0, IsGateUp ? coord_n / 2 : coord_n},
1012  output_acc_tile_distr);
1013 
1014  auto scale_n_up_window = make_tile_window(
1015  make_naive_tensor_view<address_space_enum::global>(
1016  kargs.scale_n.ptr + expert_id * kargs.N + kargs.N / 2,
1017  make_tuple(1, kargs.N),
1018  make_tuple(0, scale_stride_n),
1019  number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
1020  number<1>{}),
1022  number<TilePartitioner::NPerBlock / 2>{}),
1023  {0, coord_n / 2},
1024  output_acc_tile_distr);
1025 
1026  auto exp_bias_view = make_naive_tensor_view<address_space_enum::global>(
1027  kargs.exp_bias.ptr + expert_id * kargs.N,
1028  make_tuple(1, kargs.N),
1029  make_tuple(0, scale_stride_n),
1030  number<FlatmmPipeline::GetVectorSizeB()>{},
1031  number<1>{});
1032 
1033  auto exp_bias_window = make_tile_window(
1034  permute_tensor_view(exp_bias_view, number<(MXFP4_Pipeline && !IsInputGemm)>{}),
1036  number < IsGateUp ? TilePartitioner::NPerBlock / 2
1037  : TilePartitioner::NPerBlock > {}),
1038  {0, IsGateUp ? coord_n / 2 : coord_n},
1039  output_acc_tile_distr);
1040 
1041  auto exp_bias_up_window =
1042  make_tile_window(make_naive_tensor_view<address_space_enum::global>(
1043  kargs.exp_bias.ptr + expert_id * kargs.N + kargs.N / 2,
1044  make_tuple(1, kargs.N),
1045  make_tuple(0, scale_stride_n),
1046  number<FlatmmPipeline::GetVectorSizeB()>{},
1047  number<1>{}),
1049  number<TilePartitioner::NPerBlock / 2>{}),
1050  {0, coord_n / 2},
1051  output_acc_tile_distr);
1052 
1053  auto exp_weight_window =
1054  make_tile_window(make_naive_tensor_view<address_space_enum::global>(
1055  static_cast<const float*>(kargs.p_sorted_expert_weights),
1056  make_tuple(kargs.M, 1),
1057  make_tuple(1, 0),
1058  number<FlatmmPipeline::GetVectorSizeA()>{},
1059  number<1>{}),
1062  {coord_m, 0},
1063  output_acc_tile_distr);
1064 
1065  using ScaleMBuffer = decltype(load_tile(scale_m_window));
1066  using ScaleNBuffer = decltype(load_tile(scale_n_window));
1067  using ExpBiasBuffer = decltype(load_tile(exp_bias_window));
1068  using ExpWeightBuffer = decltype(load_tile(exp_weight_window));
1069 
1070  ScaleMBuffer scale_m_buffer;
1071  ScaleNBuffer scale_n_buffer, scale_n_up_buffer;
1072 
1073  ExpBiasBuffer exp_bias_buffer, exp_bias_up_buffer;
1074  ExpWeightBuffer exp_weight_buffer;
1075 
1076  if constexpr(!MXFP4_Pipeline)
1077  {
1078  scale_m_window.load(scale_m_buffer);
1079  scale_n_buffer = load_tile(scale_n_window);
1080  if constexpr(IsGateUp)
1081  scale_n_up_buffer = load_tile(scale_n_up_window);
1082  }
1083 
1084  if constexpr(EnableBias)
1085  {
1086  exp_bias_buffer = load_tile(exp_bias_window);
1087  if constexpr(IsGateUp)
1088  exp_bias_up_buffer = load_tile(exp_bias_up_window);
1089  }
1090  if constexpr(!IsInputGemm)
1091  exp_weight_buffer = load_tile(exp_weight_window);
1092 
1093  auto in_lds_window = make_tile_window(
1094  o_lds_block,
1096  {0, 0});
1097 
1098  auto out_lds_window = make_tile_window(
1099  o_lds_block,
1101  {0, 0});
1102 
1106 
1107  constexpr index_t num_access = SFC::get_num_of_access();
1108 
1109  static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
1110  "Currently, the CShuffle EpiloguePipeline only supports the Row Major "
1111  "Output layout");
1112 
1113  using TileEncodingPattern = tile_distribution_encoding_pattern_2d<
1114  kBlockSize,
1115  MPerIterationShuffle,
1116  LDS_NPerIterationShuffle,
1117  kind == MoeFlatmmKind::kFFN_gemm2 ? 2 : EpiloguePipeline::GetVectorSizeC(),
1119  EpiProblem::kNumWaveGroups>;
1120 
1121  constexpr auto dram_tile_distribution =
1122  TileEncodingPattern::make_2d_static_tile_distribution();
1123 
1124  constexpr auto LdsTileDistr = [&] {
1125  if constexpr(IsGateUp)
1129  sequence<>,
1131  // merge two contiguous N
1136  sequence<0, 0>>{},
1137  typename CWarpDstr::DstrEncode{}));
1138  else
1140  EpiloguePipeline::MakeLdsDistributionEncode());
1141  }();
1142 
1143  using LDSTileTensor =
1144  decltype(make_static_distributed_tensor<AccDataType>(LdsTileDistr));
1145  LDSTileTensor lds_tile[2];
1146 
1147  constexpr auto c_warp_y_lengths =
1148  to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
1149  constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
1150  constexpr int ActVectorSize = c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle *
1151  OutputNumNXdlPerWavePerShuffle;
1152 
1153  auto epi_tile_idx_slice =
1154  [&](const auto& acc_tile_like_tensor, auto epi_m_idx, auto epi_n_idx) {
1155  return acc_tile_like_tensor.get_y_sliced_thread_data(
1156  merge_sequences(sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
1157  epi_n_idx * OutputNumNXdlPerWavePerShuffle>{},
1158  c_warp_y_index_zeros),
1161  c_warp_y_lengths));
1162  };
1163 
1164  auto gate_up_epi_tile_idx_interleave_slice = [&](auto& dest_gate_tensor,
1165  auto& dest_up_tensor,
1166  const auto& acc_tile_like_tensor,
1167  auto epi_m_idx,
1168  auto epi_n_idx) {
1170  dest_gate_tensor.set_y_sliced_thread_data(
1171  merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
1173  acc_tile_like_tensor.get_y_sliced_thread_data(
1175  sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
1176  epi_n_idx * NumNXdlPerWavePerShuffle + 2 * n_xdl>{},
1177  c_warp_y_index_zeros),
1179  c_warp_y_lengths)));
1180  dest_up_tensor.set_y_sliced_thread_data(
1181  merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
1183  acc_tile_like_tensor.get_y_sliced_thread_data(
1185  sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
1186  epi_n_idx * NumNXdlPerWavePerShuffle + 2 * n_xdl + 1>{},
1187  c_warp_y_index_zeros),
1189  c_warp_y_lengths)));
1190  });
1191  };
1192 
1193  auto process_epi_tile = [&](auto lds_stage, auto epi_m, auto epi_n) {
1194  if constexpr(IsGateUp)
1195  {
1196  LDSTileTensor gate_tensor, up_tensor;
1197 
1198  gate_up_epi_tile_idx_interleave_slice(
1199  gate_tensor, up_tensor, c_block_tile, epi_m, epi_n);
1200  auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, epi_m, epi_n);
1201  auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, epi_m, epi_n);
1202  auto epi_scale_n_up = epi_tile_idx_slice(scale_n_up_buffer, epi_m, epi_n);
1203 
1204  auto epi_exp_bias = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n);
1205  auto epi_exp_bias_up = epi_tile_idx_slice(exp_bias_up_buffer, epi_m, epi_n);
1206 
1207  static_for<0, ActVectorSize, 1>{}([&](auto idx) {
1208  if constexpr(!MXFP4_Pipeline)
1209  {
1210  gate_tensor.get_thread_buffer()[idx] *=
1211  epi_scale_m[idx] * epi_scale_n[idx];
1212  up_tensor.get_thread_buffer()[idx] *=
1213  epi_scale_m[idx] * epi_scale_n_up[idx];
1214  }
1215  if constexpr(EnableBias)
1216  {
1217  gate_tensor.get_thread_buffer()[idx] += epi_exp_bias[idx];
1218  up_tensor.get_thread_buffer()[idx] += epi_exp_bias_up[idx];
1219  }
1220  lds_tile[lds_stage].get_thread_buffer().at(idx) =
1221  ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
1222  up_tensor.get_thread_buffer().at(idx));
1223  });
1224  }
1225  else
1226  {
1227  lds_tile[lds_stage].get_thread_buffer() =
1228  epi_tile_idx_slice(c_block_tile, epi_m, epi_n);
1229  auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, epi_m, epi_n);
1230  auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, epi_m, epi_n);
1231  auto epi_exp_weight = epi_tile_idx_slice(exp_weight_buffer, epi_m, epi_n);
1232  auto epi_exp_bias = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n);
1233 
1234  static_for<0, ActVectorSize, 1>{}([&](auto idx) {
1235  if constexpr(!MXFP4_Pipeline)
1236  lds_tile[lds_stage].get_thread_buffer()[idx] *=
1237  epi_scale_m[idx] * epi_scale_n[idx];
1238  if constexpr(EnableBias)
1239  lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx];
1240  if constexpr(!IsInputGemm)
1241  lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx];
1242  else // for mlp1 gate-only
1243  lds_tile[lds_stage].get_thread_buffer()[idx] =
1244  ActivationOp{}(lds_tile[lds_stage].get_thread_buffer()[idx]);
1245  });
1246  }
1247  };
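// In short, process_epi_tile materializes one LDS stage of the epilogue:
//   gate-up GEMM : out = Activation(acc_gate * scale_m * scale_n + bias_gate,
//                                   acc_up   * scale_m * scale_n_up + bias_up)
//   second GEMM  : out = (acc * scale_m * scale_n + bias) * expert_weight
//   gate-only    : out = Activation(acc * scale_m * scale_n + bias)
// where the scale terms are skipped for the MXFP4 pipeline, whose B scale is already applied
// inside the flatmm pipeline via scale_block_window.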
1248 
1249  constexpr int NumMEpiTile = MRepeat / NumMXdlPerWavePerShuffle;
1250  constexpr int MPerThread = TileEncodingPattern::Y2;
1252  c_scatter_offsets;
1253  auto c_coord = dram_tile_distribution.calculate_index();
1254  static_for<0, NumMEpiTile, 1>{}([&](auto mIter) {
1255  static_for<0, MPerThread, 1>{}([&](auto m0) {
1256  auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0;
1257  auto fused_token =
1258  kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]
1259 
1260  index_t scatter_token_id = fused_token & token_id_mask;
1261  if constexpr(IsInputGemm)
1262  scatter_token_id =
1263  scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
1264  c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
1265  });
1266  });
1267 
1268  //===----------------------------------------------------------------------===//
1269  // Pingpong process start
1270  //===----------------------------------------------------------------------===//
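// Double-buffering scheme: lds_tile[read_stage] (filled on the previous access) is cast and
// written to LDS, the activation/scale math for the next access is then computed into
// lds_tile[write_stage], and finally the shuffled tile is read back from LDS and scatter-stored
// to global memory using the per-row token offsets computed above.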
1271  process_epi_tile(number<0>{}, number<0>{}, number<0>{});
1272 
1273  static_for<0, num_access, 1>{}([&](auto iAccess) {
1274  constexpr int read_stage = iAccess % 2;
1275  constexpr int write_stage = read_stage ^ 1;
1276 
1277  block_sync_lds();
1278  constexpr auto idx_y_start = SFC::get_index(number<iAccess.value>{});
1279  constexpr auto mIter = number<idx_y_start.at(number<0>{}) / MPerIterationShuffle>{};
1280 
1281  const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile[read_stage]);
1282 
1283  store_tile(in_lds_window, c_warptile_in_tensor_casted);
1284 
1285  if constexpr(iAccess < num_access - 1)
1286  {
1287  constexpr auto idx_y_start_next = SFC::get_index(number<iAccess.value + 1>{});
1288  constexpr auto mIter_next =
1289  number<idx_y_start_next.at(number<0>{}) / MPerIterationShuffle>{};
1290  constexpr auto nIter_next =
1291  number<idx_y_start_next.at(number<1>{}) / NPerIterationShuffle>{};
1292 
1293  process_epi_tile(number<write_stage>{}, mIter_next, nIter_next);
1294  }
1295 
1296  block_sync_lds();
1297 
1298  auto c_out_tensor =
1299  load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
1300  auto c_scatter_tile_window =
1301  make_tile_scatter_gather(c_block_window.get_bottom_tensor_view(),
1302  c_block_window.get_window_lengths(),
1303  c_block_window.get_window_origin(),
1304  dram_tile_distribution,
1305  c_scatter_offsets[mIter]);
1306 
1307  if constexpr(!IsInputGemm ||
1308  EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add)
1309  c_scatter_tile_window.update(c_out_tensor);
1310  else
1311  c_scatter_tile_window.store(c_out_tensor);
1312 
1313  if constexpr(iAccess != num_access - 1)
1314  {
1315  constexpr auto step = SFC::get_forward_step(iAccess);
1316  // row_offset of out windows has been included in scatter offset
1317  move_tile_window(c_block_window,
1318  {0, step.at(number<1>{}) / number < IsGateUp ? 2 : 1 > {}});
1319  }
1320  });
1321  }
1322  }
1323 };
1324 
1325 } // namespace ck_tile
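// End-to-end usage sketch (illustrative, comment-only): a host-side launch might look roughly like
// the code below, where `Partitioner`, `Pipeline` and `Epilogue` stand for concrete policy types
// chosen elsewhere and the host-argument values are hypothetical.
//
//   using Kernel = ck_tile::MoeFlatmmKernel<Partitioner, Pipeline, Epilogue,
//                                           ck_tile::MoeFlatmmKind::kFFN_gemm2>;
//
//   ck_tile::MoeFlatmmHostArgs<> host_args(/* sorted-token / expert buffers and GEMM sizes */);
//   auto kargs = Kernel::MakeKernelArgs(host_args);
//
//   if(Kernel::IsSupportedArgument(kargs))
//   {
//       const dim3 grids  = Kernel::GridSize(kargs);
//       const dim3 blocks = Kernel::BlockSize();
//       // launch through the framework's kernel entry point (e.g. kentry, as referenced in
//       // GridSize above), passing kargs together with grids and blocks.
//   }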