Composable Kernel: moe_flatmm_kernel.hpp Source File

include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
11 #include "ck_tile/host.hpp"
12 
13 // #define disable_tile_gs
14 
15 namespace ck_tile {
16 
17 template <class ScaleM = FlatmmScalePointer<-1>,
18  class ScaleN = FlatmmScalePointer<-1>,
19  class ExpertBias = FlatmmScalePointer<-1>>
20 struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
 21 {
 22  ck_tile::index_t NumTokens;
 23  ck_tile::index_t NumExperts;
 24  ck_tile::index_t TopK;
 25  const ck_tile::index_t* p_sorted_token_ids;
 26  const ck_tile::index_t* p_sorted_expert_ids;
 27  const ck_tile::index_t* p_max_token_id;
 28  const void* p_sorted_expert_weights;
 29  const ck_tile::index_t n_padded_zeros;
 30  const ck_tile::index_t k_padded_zeros;
 31  ExpertBias exp_bias;
32 
33  CK_TILE_HOST MoeFlatmmHostArgs() noexcept = default;
34 
35  CK_TILE_HOST MoeFlatmmHostArgs(const ck_tile::index_t* p_sorted_token_ids_,
36  const void* p_sorted_expert_weights_,
37  const ck_tile::index_t* p_sorted_expert_ids_,
38  const ck_tile::index_t* p_max_token_id_,
39  const void* a_ptr_,
40  const void* b_ptr_,
41  void* c_ptr_,
42  ck_tile::index_t NumTokens_,
43  ck_tile::index_t NumExperts_,
44  ck_tile::index_t TopK_,
45  ck_tile::index_t k_batch_,
46  ck_tile::index_t M_,
47  ck_tile::index_t N_,
48  ck_tile::index_t K_,
49  ck_tile::index_t stride_A_,
50  ck_tile::index_t stride_B_,
51  ck_tile::index_t stride_C_,
52  ScaleM scale_m_ = {},
53  ScaleN scale_n_ = {},
54  ExpertBias exp_bias_ = {})
55  : MoeFlatmmHostArgs(p_sorted_token_ids_,
56  p_sorted_expert_weights_,
57  p_sorted_expert_ids_,
58  p_max_token_id_,
59  a_ptr_,
60  b_ptr_,
61  c_ptr_,
62  NumTokens_,
63  NumExperts_,
64  TopK_,
65  k_batch_,
66  M_,
67  N_,
68  K_,
69  stride_A_,
70  stride_B_,
71  stride_C_,
72  0, // n_padded_zeros_
73  0, // k_padded_zeros_
74  scale_m_,
75  scale_n_,
76  exp_bias_)
77  {
78  }
79 
80  CK_TILE_HOST MoeFlatmmHostArgs(const ck_tile::index_t* p_sorted_token_ids_,
81  const void* p_sorted_expert_weights_,
82  const ck_tile::index_t* p_sorted_expert_ids_,
83  const ck_tile::index_t* p_max_token_id_,
84  const void* a_ptr_,
85  const void* b_ptr_,
86  void* c_ptr_,
87  ck_tile::index_t NumTokens_,
88  ck_tile::index_t NumExperts_,
89  ck_tile::index_t TopK_,
90  ck_tile::index_t k_batch_,
94  ck_tile::index_t stride_A_,
95  ck_tile::index_t stride_B_,
96  ck_tile::index_t stride_C_,
97  ck_tile::index_t n_padded_zeros_ = 0,
98  ck_tile::index_t k_padded_zeros_ = 0,
99  ScaleM scale_m_ = {},
100  ScaleN scale_n_ = {},
101  ExpertBias exp_bias_ = {})
102  : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>(a_ptr_,
103  b_ptr_,
104  {}, // d_ptr_array
105  c_ptr_,
106  k_batch_,
107  M_,
108  N_,
109  K_,
110  stride_A_,
111  stride_B_,
112  {}, // d_stride_array
113  stride_C_,
114  scale_m_,
115  scale_n_),
116  NumTokens(NumTokens_),
117  NumExperts(NumExperts_),
118  TopK(TopK_),
119  p_sorted_token_ids(p_sorted_token_ids_),
120  p_sorted_expert_ids(p_sorted_expert_ids_),
121  p_max_token_id(p_max_token_id_),
122  p_sorted_expert_weights(p_sorted_expert_weights_),
123  n_padded_zeros(n_padded_zeros_),
124  k_padded_zeros(k_padded_zeros_),
125  exp_bias(exp_bias_)
126  {
127  }
128 };
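A minimal host-side sketch of filling these arguments for the first MoE GEMM. Everything below (the buffer names, the sizes, and the include path) is an illustrative assumption rather than a value fixed by this header; only the constructor's parameter order comes from the code above.

#include "ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp"

// Illustrative usage sketch, not part of this header. A real caller passes device
// buffers produced by the MoE token-sorting step.
void make_example_host_args()
{
    const ck_tile::index_t* p_sorted_token_ids  = nullptr; // sorted (token, top-k slot) ids
    const void*             p_sorted_expert_wts = nullptr; // routing weight per sorted row
    const ck_tile::index_t* p_sorted_expert_ids = nullptr; // expert id per M-block
    const ck_tile::index_t* p_max_token_id      = nullptr; // number of valid sorted rows
    const void* a_ptr = nullptr;                           // input activations
    const void* b_ptr = nullptr;                           // pre-shuffled ("flat") expert weights
    void*       c_ptr = nullptr;                           // GEMM output

    ck_tile::MoeFlatmmHostArgs<> host_args(p_sorted_token_ids,
                                           p_sorted_expert_wts,
                                           p_sorted_expert_ids,
                                           p_max_token_id,
                                           a_ptr,
                                           b_ptr,
                                           c_ptr,
                                           /*NumTokens*/ 4096,
                                           /*NumExperts*/ 8,
                                           /*TopK*/ 2,
                                           /*k_batch*/ 1,
                                           /*M*/ 8192, // sorted rows, e.g. NumTokens * TopK plus padding
                                           /*N*/ 7168,
                                           /*K*/ 2048,
                                           /*stride_A*/ 2048,
                                           /*stride_B*/ 2048,
                                           /*stride_C*/ 7168);
    (void)host_args;
}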
129 
130 enum class MoeFlatmmKind
131 {
132  kFFN_gemm1_gate_only,
133  kFFN_gemm1_gate_up,
134  kFFN_gemm2,
135  kFFN_gemm1_split_k,
136 };
137 
138 namespace moe {
139 
140 struct MoeSilu
141 {
142  template <typename T>
143  CK_TILE_HOST_DEVICE T operator()(T gate, T linear = 1) const
144  {
145  ck_tile::element_wise::Silu{}(gate, gate);
146  return gate * linear;
147  };
148 };
149 
150 struct Swiglu
151 {
152  const float alpha;
153  const float limit;
154 
156  CK_TILE_HOST_DEVICE Swiglu(float alpha_ = 1.702f, float limit_ = 7.0f) // use the gpt-oss values as defaults
157  : alpha(alpha_), limit(limit_)
158  {
159  }
160 
161  template <typename T>
162  CK_TILE_HOST_DEVICE T operator()(T gate, T linear) const
163  {
164  static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
165  std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
166  std::is_same_v<T, int32_t>,
167  "Data type is not supported by this operation!");
168 
169  constexpr T one = type_convert<T>(1);
170 
171  gate = gate < limit ? gate : limit;
172  linear = linear < limit ? (linear > -limit ? linear : -limit) : limit;
173 
174  if constexpr(std::is_same_v<T, float>)
175  {
176  return gate * __builtin_amdgcn_rcpf(one + ck_tile::exp(alpha * -gate)) * (linear + 1);
177  }
178  else
179  {
180  return gate * (one / (one + ck_tile::exp(alpha * -gate))) * (linear + 1);
181  }
182  }
183 };
184 
185 } // namespace moe
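For reference, plain-float restatements of the two activations above (a sketch using std::exp in place of ck_tile::exp and ignoring the device-side fast reciprocal): MoeSilu computes silu(gate) * linear, and Swiglu clamps its inputs and computes gate * sigmoid(alpha * gate) * (linear + 1).

#include <algorithm>
#include <cmath>

// Reference math only, not part of this header.
inline float ref_moe_silu(float gate, float linear = 1.0f)
{
    return gate / (1.0f + std::exp(-gate)) * linear; // silu(gate) * linear
}

inline float ref_swiglu(float gate, float linear, float alpha = 1.702f, float limit = 7.0f)
{
    gate   = std::min(gate, limit);             // gate is clamped from above only
    linear = std::clamp(linear, -limit, limit); // linear is clamped to [-limit, limit]
    return gate * (1.0f / (1.0f + std::exp(-alpha * gate))) * (linear + 1.0f);
}
// e.g. ref_swiglu(2.0f, 0.5f) ~= 2.0f * 0.968f * 1.5f ~= 2.90f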
186 
187 template <typename TilePartitioner_,
188  typename FlatmmPipeline_,
189  typename EpiloguePipeline_,
190  MoeFlatmmKind kind,
191  typename FusedActivation = moe::MoeSilu>
192 struct MoeFlatmmKernel
193 {
194  using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
195  using FlatmmPipeline   = remove_cvref_t<FlatmmPipeline_>;
197  using BlockGemmShape   = remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>;
198  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
199  using ALayout          = remove_cvref_t<typename FlatmmPipeline::ALayout>;
200  using BLayout          = remove_cvref_t<typename FlatmmPipeline::BLayout>;
201  using ELayout          = remove_cvref_t<typename FlatmmPipeline::CLayout>;
202  using DsLayout         = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
203  using DsDataType       = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
204  static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
205  static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
206 
207  using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
208  using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
209  // Below type is actually accumulation data type - the output of block GEMM.
210  using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
211 
212  using AccDataType = float;
213  using ActivationOp = FusedActivation;
214 
215  static constexpr index_t NumDTensor = DsDataType::size();
216 
217  static constexpr auto I0 = number<0>();
218  static constexpr auto I1 = number<1>();
219  static constexpr auto I2 = number<2>();
220  static constexpr auto I3 = number<3>();
221  static constexpr auto I4 = number<4>();
222 
223  static_assert(DsLayout::size() == DsDataType::size(),
224  "The size of DsLayout and DsDataType should be the same");
225 
226  static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2;
227  static constexpr bool IsGateUp = kind == MoeFlatmmKind::kFFN_gemm1_gate_up;
228  static constexpr bool IsGemm1SplitK = kind == MoeFlatmmKind::kFFN_gemm1_split_k;
229  static constexpr bool IsBShuffled = true;
230 
231  // static constexpr index_t kBlockSize = EpiloguePipeline::kBlockSize;
232  static constexpr index_t kMPerBlock = EpiloguePipeline::kMPerBlock;
233  static constexpr index_t kNPerBlock = EpiloguePipeline::kNPerBlock;
234  static constexpr index_t MWave = EpiloguePipeline::MWave;
235  static constexpr index_t NWave = EpiloguePipeline::NWave;
236  static constexpr index_t MPerXdl = EpiloguePipeline::MPerXdl;
237  static constexpr index_t NPerXdl = EpiloguePipeline::NPerXdl;
238  static constexpr index_t KPerXdl = EpiloguePipeline::KPerXdl;
239  static constexpr index_t isCTransposed = EpiloguePipeline::isCTransposed;
240  static constexpr index_t kMPerIteration = MPerXdl * MWave;
241  static constexpr index_t kNPerIteration = NPerXdl * NWave;
243 
244  static constexpr int OutputNPerBlock =
245  IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
246 
247  // The MXF4 pipeline only provides a B scale, and its granularityK is 32
248  static constexpr bool AQUANT_Pipeline = std::is_same_v<ADataType, bf8_t> ||
249  std::is_same_v<ADataType, fp8_t> ||
250  std::is_same_v<ADataType, pk_fp4_t>;
251  static constexpr bool BMXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
252 
253  static constexpr bool MXF8F6F4MFMA =
254 #ifdef __gfx950__
256 #else
257  false;
258 #endif
259  static constexpr int MXFP4M_Pack = 2;
260  static constexpr int MXFP4N_Pack = 2;
261  static constexpr int MXFP4K_Pack = 2;
262 
263  static constexpr int M_Pack = AQUANT_Pipeline ? MXFP4M_Pack : 1;
264  static constexpr int N_Pack = BMXFP4_Pipeline ? MXFP4N_Pack : 1;
265  static constexpr int K_Pack = BMXFP4_Pipeline ? MXFP4K_Pack : 1;
266 
268 
269  template <class ScaleM = FlatmmScalePointer<-1>,
270  class ScaleN = FlatmmScalePointer<-1>,
271  class ExpertBias = FlatmmScalePointer<-1>>
272  struct MoeFlatmmKernelArgs
273  {
274  const ck_tile::index_t* p_sorted_token_ids;
275  const ck_tile::index_t* p_sorted_expert_ids;
276  const ck_tile::index_t* p_max_token_id;
277  const void* p_sorted_expert_weights;
278  const void* a_ptr;
279  const void* b_ptr;
280  void* e_ptr;
281  ck_tile::index_t NumTokens;
282  ck_tile::index_t TopK;
283  ck_tile::index_t M;
284  ck_tile::index_t N;
285  ck_tile::index_t K;
286  ck_tile::index_t stride_A;
287  ck_tile::index_t stride_B;
288  ck_tile::index_t stride_C;
289  ck_tile::index_t k_batch;
290  ck_tile::index_t n_padded_zeros;
291  ck_tile::index_t k_padded_zeros;
292  ScaleM scale_m;
293  ScaleN scale_n;
294  ExpertBias exp_bias;
295  };
296 
297  template <class ScaleM = FlatmmScalePointer<-1>,
298  class ScaleN = FlatmmScalePointer<-1>,
299  class ExpertBias = FlatmmScalePointer<-1>>
300  CK_TILE_HOST static constexpr auto
301  MakeKernelArgs(const MoeFlatmmHostArgs<ScaleM, ScaleN, ExpertBias>& hostArgs)
302  {
303  return MoeFlatmmKernelArgs<ScaleM, ScaleN, ExpertBias>{hostArgs.p_sorted_token_ids,
304  hostArgs.p_sorted_expert_ids,
305  hostArgs.p_max_token_id,
306  hostArgs.p_sorted_expert_weights,
307  hostArgs.a_ptr,
308  hostArgs.b_ptr,
309  hostArgs.e_ptr,
310  hostArgs.NumTokens,
311  hostArgs.TopK,
312  hostArgs.M,
313  hostArgs.N,
314  hostArgs.K,
315  hostArgs.stride_A,
316  hostArgs.stride_B,
317  hostArgs.stride_C,
318  hostArgs.k_batch,
319  hostArgs.n_padded_zeros,
320  hostArgs.k_padded_zeros,
321  hostArgs.scale_m,
322  hostArgs.scale_n,
323  hostArgs.exp_bias};
324  }
325 
326  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
327  {
328  return concat(
329  '_', "moe_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
330  }
331 
332  static constexpr auto BlockSize() -> dim3 { return dim3(kBlockSize); }
333 
334  static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
335  {
336  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
337  }
338  template <class MoeFlatmmKernelArgs>
339  static constexpr auto GridSize(const MoeFlatmmKernelArgs& kargs)
340  {
341  if constexpr(UsePersistentKernel)
342  {
343  hipDeviceProp_t prop;
344  int deviceId = 0; // default device
345 
346  constexpr int block_size = MoeFlatmmKernel::BlockSize().x;
347  int dync_smem_size = 0;
348  int maxActiveBlocksPerCU = 0;
349 
350  [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
351 
352  e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
353  &maxActiveBlocksPerCU,
354  reinterpret_cast<void*>(kentry<1, MoeFlatmmKernel, MoeFlatmmKernelArgs>),
355  block_size,
356  dync_smem_size);
357 
358  const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
359  const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
360 
361  // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
362  // << ", persistent_block_size: " << persistent_block_size
363  // << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
364 
365  assert(kargs.k_batch == 1);
366  return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
367  }
368  else
369  {
370  return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
371  }
372  }
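A worked example of the persistent-mode grid sizing above. The CU count, occupancy, and tile sizes are illustrative assumptions, not values fixed by this header.

// Suppose hipGetDeviceProperties reports 80 CUs and
// hipOccupancyMaxActiveBlocksPerMultiprocessor reports 2 blocks per CU:
//   persistent_block_size = 80 * 2 = 160
// With M = 8192, N = 7168 and MPerBlock = NPerBlock = 128:
//   total_work_tile_cnt = (8192 / 128) * (7168 / 128) = 64 * 56 = 3584
//   grid = dim3(min(160, 3584), 1, 1) = dim3(160, 1, 1)
// and each block walks the tiles with stride gridDim.x (see operator() below).
// Non-persistent mode simply launches one block per tile: dim3(3584, 1, k_batch).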
373 
374  static constexpr CK_TILE_HOST_DEVICE index_t GetSmemPingSize()
375  {
376  return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
377  }
378  static constexpr CK_TILE_HOST_DEVICE index_t GetSmemPongSize()
379  {
380  return FlatmmPipeline::GetSmemSize();
381  }
382 
383  struct SplitKBatchOffset
384  {
385  template <class KernelArgs>
386  __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
387  {
388  constexpr auto K1 = BlockGemmShape::WarpTile::at(number<2>{});
389  const index_t K_t = kargs.k_batch * K1;
390  const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
391 
392  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
393  {
394  a_k_split_offset = k_id * KRead;
395  }
396  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
397  {
398  a_k_split_offset = k_id * KRead * kargs.stride_A;
399  }
400 
401  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
402  {
403  splitted_k = KRead;
404  }
405  else
406  {
407  splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
408  }
409 
410  if constexpr(IsBShuffled)
411  {
413  }
414  else
415  {
416  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
417  {
418  b_k_split_offset = k_id * KRead * kargs.stride_B;
419  }
420  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
421  {
422  b_k_split_offset = k_id * KRead;
423  }
424  }
425  }
426 
427  index_t a_k_split_offset;
428  index_t b_k_split_offset;
429  index_t splitted_k;
430  };
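A worked example of the split-K slicing performed by SplitKBatchOffset. The warp-tile K extent (K1) and the problem size are assumed for illustration.

// Assume K = 1000, k_batch = 3, K1 = BlockGemmShape::WarpTile K = 32:
//   K_t   = k_batch * K1       = 96
//   KRead = ceil(K / K_t) * K1 = 11 * 32 = 352
//   k_id 0: a_k_split_offset = 0,   splitted_k = 352   (RowMajor A: offsets in elements along K)
//   k_id 1: a_k_split_offset = 352, splitted_k = 352
//   k_id 2: a_k_split_offset = 704, splitted_k = 1000 - 2 * 352 = 296
// The k_batch partial products are then combined through the epilogue's destination
// memory operation.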
431 
432  template <typename KernelArgs>
433  CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
434  {
435  if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
437  {
438  if(kargs.k_batch != 1)
439  {
440  std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
441  return false;
442  }
443  }
444  if constexpr(UsePersistentKernel)
445  {
446  if(kargs.k_batch != 1)
447  {
448  std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
449  return false;
450  }
451  }
452 
453  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
454  {
455  if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
456  {
457  std::cerr << "Can't support K that is not a multiple of KPerBlock"
458  " without padding!"
459  << std::endl;
460  return false;
461  }
462  if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
463  {
464  std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
465  return false;
466  }
467  }
468  else
469  {
470  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
471  {
472  std::cerr << "Can't support M that is not a multiple of MPerBlock"
473  " without padding!"
474  << std::endl;
475  return false;
476  }
477  if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
478  {
479  std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
480  return false;
481  }
482  }
483 
484  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
485  {
486  // if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
487  // {
488  // std::cerr << "Can't support N that is not a multiple of NPerBlock"
489  // " without padding!"
490  // << std::endl;
491  // return false;
492  // }
493  if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
494  {
495  std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
496  return false;
497  }
498  }
499  else
500  {
501  if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
502  {
503  std::cerr << "Can't support K that is not a multiple of KPerBlock"
504  " without padding!"
505  << std::endl;
506  return false;
507  }
508  if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
509  {
510  std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
511  return false;
512  }
513  }
514 
515  bool DTesnorIsValid = {true};
516  static_for<0, NumDTensor, 1>{}([&](auto index) {
517  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
518  if(std::is_same_v<DiLayout, ELayout> == false)
519  {
520  DTesnorIsValid = false;
521  }
522  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
523  {
524  if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
525  {
526  CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
527  "NPerBlock without padding!");
528  DTesnorIsValid = false;
529  }
530  if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
531  {
532  CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
533  DTesnorIsValid = false;
534  }
535  }
536  else
537  {
538  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
539  {
540  CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
541  "MPerBlock without padding!");
542 
543  DTesnorIsValid = false;
544  }
545  if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
546  {
547  CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
548  DTesnorIsValid = false;
549  }
550  }
551  });
552 
553  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
554  {
555  if(kargs.stride_C % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
556  {
557  std::cerr << "Can't support N that is not a multiple of NPerBlock"
558  " without padding!"
559  << std::endl;
560  return false;
561  }
562  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
563  {
564  std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
565  return false;
566  }
567  }
568  else
569  {
570  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
571  {
572  std::cerr << "Can't support M that is not a multiple of MPerBlock"
573  " without padding!"
574  << std::endl;
575  return false;
576  }
577  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
578  {
579  std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
580  return false;
581  }
582  }
583  return DTesnorIsValid;
584  }
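Putting the host-side pieces together: a hedged sketch of how a caller might validate and launch this kernel. The Partitioner, Pipeline, and Epilogue types and the host_args object are assumptions standing in for whatever concrete types the surrounding example code instantiates; only the calls on Kernel come from this header.

#include <stdexcept>

// Hypothetical helper, not part of this header.
template <typename Partitioner, typename Pipeline, typename Epilogue, typename HostArgs>
void launch_moe_ffn_gemm2(const HostArgs& host_args)
{
    using Kernel = ck_tile::MoeFlatmmKernel<Partitioner,
                                            Pipeline,
                                            Epilogue,
                                            ck_tile::MoeFlatmmKind::kFFN_gemm2>;

    auto kargs = Kernel::MakeKernelArgs(host_args);
    if(!Kernel::IsSupportedArgument(kargs))
        throw std::runtime_error("unsupported MoE flatmm arguments");

    const dim3 grid   = Kernel::GridSize(kargs); // occupancy-capped in persistent mode
    const dim3 blocks = Kernel::BlockSize();
    // Launch through the usual ck_tile launcher (kentry) with (grid, blocks, lds = 0, kargs).
}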
585 
586  template <memory_operation_enum DstInMemOp = (IsInputGemm && !IsGemm1SplitK)
587  ? memory_operation_enum::set
588  : memory_operation_enum::atomic_add,
589  typename KernelArgs>
590  CK_TILE_DEVICE static auto
591  MakeGemmTensorViews(const ADataType* a_ptr,
592  const BDataType* b_flat_ptr,
593  EDataType* e_ptr,
594  [[maybe_unused]] const AccDataType* exp_weight_ptr,
595  [[maybe_unused]] const int expert_id,
596  const KernelArgs& kargs,
597  const SplitKBatchOffset& splitk_batch_offset)
598  {
599  const auto& a_tensor_view = [&]() {
600  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
601  {
602  return make_naive_tensor_view<address_space_enum::global>(
603  a_ptr,
604  make_tuple(IsInputGemm ? kargs.NumTokens : kargs.NumTokens * kargs.TopK,
605  splitk_batch_offset.splitted_k),
606  make_tuple(kargs.stride_A, 1),
607  number<FlatmmPipeline::GetVectorSizeA()>{},
608  number<1>{});
609  }
610  else
611  {
612  return make_naive_tensor_view<address_space_enum::global>(
613  a_ptr,
614  make_tuple(splitk_batch_offset.splitted_k,
615  IsInputGemm ? kargs.NumTokens : kargs.NumTokens * kargs.TopK),
616  make_tuple(kargs.stride_A, 1),
617  number<FlatmmPipeline::GetVectorSizeA()>{},
618  number<1>{});
619  }
620  }();
621 
622  const auto& b_flat_tensor_view = [&]() {
623  if constexpr(!FlatmmPipeline::BPreShufflePermute)
624  {
625  index_t kFlatK =
626  kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK)
627  index_t kFlatN = kargs.N * kargs.K / kFlatK;
628 
629  return make_naive_tensor_view<address_space_enum::global,
630  memory_operation_enum::set,
631  FlatmmPipeline::BMemNTType>(
632  b_flat_ptr,
633  make_tuple(kFlatN - kargs.n_padded_zeros / NPerXdl, kFlatK),
634  make_tuple(kFlatK, 1),
635  number<FlatmmPipeline::GetVectorSizeB()>{},
636  number<1>{});
637  }
638  else
639  {
640  index_t kFlatK = FlatmmPipeline::flatKPerWarp;
641  index_t kFlatN0 = (kargs.N >> 4);
642  index_t kFlatK0 = (kargs.K >> 7);
643 
644  auto b_tensor_view_naive = make_naive_tensor_view<address_space_enum::global,
645  memory_operation_enum::set,
646  FlatmmPipeline::BMemNTType>(
647  b_flat_ptr,
648  make_tuple(kFlatK0, kFlatN0 - kargs.n_padded_zeros / NPerXdl, kFlatK),
649  make_tuple(kFlatK * (kFlatN0 - kargs.n_padded_zeros / NPerXdl), kFlatK, 1),
650  number<FlatmmPipeline::GetVectorSizeB()>{},
651  number<1>{});
652  return transform_tensor_view(
653  b_tensor_view_naive,
654  make_tuple(
655  make_pass_through_transform(kFlatN0 - kargs.n_padded_zeros / NPerXdl),
659  }
660  }();
661 
662  // TODO: enable vector write for C in ColMajor
663  const auto& c_tensor_view = [&]() {
664  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
665  {
666  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
667  e_ptr,
668  make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens,
669  IsGateUp ? kargs.N / 2 : kargs.N),
670  make_tuple(kargs.stride_C, 1),
671  number<EpiloguePipeline::GetVectorSizeC()>{},
672  number<1>{});
673  }
674  else
675  {
676  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
677  e_ptr,
678  make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens,
679  IsGateUp ? kargs.N / 2 : kargs.N),
680  make_tuple(1, kargs.stride_C),
681  number<1>{},
682  number<1>{});
683  }
684  }();
685 
686  const auto& scale_a_tensor_view = [&]() {
687  auto scale_m_desc = kargs.scale_m;
688  if constexpr(AQUANT_Pipeline)
689  {
690  constexpr int AGranularityK = decltype(scale_m_desc)::GranularityK == 0
691  ? 1
692  : decltype(scale_m_desc)::GranularityK;
693 
694  constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
695  constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
696  index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
697  index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
698  // Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
699  const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
700  make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl));
701  const auto scale_a_desc = transform_tensor_descriptor(
702  scale_a_naive_desc,
703  make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)),
704  make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
707  return make_tensor_view<address_space_enum::global>(
708  reinterpret_cast<const int32_t*>(scale_m_desc.ptr), scale_a_desc);
709  }
710  else
711  {
712  constexpr int AGranularityK = 32;
713  constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
714  constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
715  index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
716  index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
717  return make_naive_tensor_view<address_space_enum::global>(
718  reinterpret_cast<const int32_t*>(scale_m_desc.ptr),
719  make_tuple(scale_m_packs * MThreadPerXdl, scale_k_packs * KThreadPerXdl),
720  make_tuple(scale_k_packs * KThreadPerXdl, 1),
721  number<8>{},
722  number<1>{});
723  }
724  }();
725 
726  const auto scale_b_flat_view = [&]() {
727  auto scale_n = kargs.scale_n;
728  constexpr int BGranularityK =
729  decltype(scale_n)::GranularityK == 0 ? 1 : decltype(scale_n)::GranularityK;
730  if constexpr(AQUANT_Pipeline)
731  {
732  index_t scale_k =
733  BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK;
734  constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(I1);
735  constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I1);
736  index_t scale_n_packs = kargs.N / (MXFP4N_Pack * NThreadPerXdl);
737  index_t scale_k_packs = kargs.K / (MXFP4K_Pack * BGranularityK * KThreadPerXdl);
738  const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
739  make_tuple(scale_n_packs, scale_k_packs, KThreadPerXdl, NThreadPerXdl));
740  const auto scale_b_desc = transform_tensor_descriptor(
741  scale_b_navie_desc,
742  make_tuple(make_merge_transform(make_tuple(scale_n_packs, NThreadPerXdl)),
743  make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
746 
747  return make_tensor_view<address_space_enum::global>(
748  reinterpret_cast<const int32_t*>(scale_n.ptr) +
749  expert_id * kargs.N * scale_k / 4,
750  scale_b_desc);
751  }
752  else
753  {
754  index_t scale_k =
755  BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK;
756  const auto scale_k_offset =
757  (splitk_batch_offset.b_k_split_offset / BGranularityK) * K_Pack;
758  index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
759  index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
760 
761  return make_naive_tensor_view<address_space_enum::global>(
762  scale_n.ptr + expert_id * kargs.N * scale_k + scale_k_offset,
763  make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
764  make_tuple(FlatScaleK, 1),
765  number<8>{},
766  number<1>{});
767  }
768  }();
769 
770  return make_tuple(a_tensor_view,
771  b_flat_tensor_view,
772  c_tensor_view,
773  scale_a_tensor_view,
774  scale_b_flat_view);
775  }
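The flat B view above reinterprets each expert's pre-shuffled weight block as a 2-D (kFlatN, kFlatK) tile. A small worked example of that shape arithmetic, with an assumed warp-tile N and no N/K padding:

// Assume N = 7168, K = 2048, BlockGemmShape::WarpTile N = 16, n_padded_zeros = 0:
//   kFlatK = K * WarpTileN  = 2048 * 16 = 32768   // one contiguous pre-shuffled warp row
//   kFlatN = N * K / kFlatK = 7168 / 16 = 448     // number of such rows per expert
// so one expert's weights are viewed as a 448 x 32768 row-major tile, and expert e
// starts at b_ptr + e * N * K / WeightPackedSize (see operator() below).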
776 
777  template <typename TensorView>
778  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
779  {
780  const auto& a_pad_view = [&]() {
781  const auto& a_tensor_view = views.at(I0);
782  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
783  {
784  return pad_tensor_view(a_tensor_view,
788  }
789  else
790  {
791  return pad_tensor_view(a_tensor_view,
795  }
796  }();
797 
798  // TODO: enable vector write for C in ColMajor
799  const auto& c_pad_view = [&]() {
800  const auto& c_tensor_view = views.at(I2);
801  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
802  {
803  return pad_tensor_view(
804  c_tensor_view,
807  }
808  else
809  {
810  return pad_tensor_view(
811  c_tensor_view,
814  }
815  }();
816 
817  return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3), views.at(I4));
818  }
819 
820  template <typename PadView>
821  CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
822  [[maybe_unused]] const index_t coord_m,
823  const index_t coord_n)
824  {
825  const auto& a_pad_view = views.at(number<0>{});
826  const auto& b_flat_pad_view = views.at(number<1>{});
827  const auto& c_pad_view = views.at(number<2>{});
828 
829  const auto& a_block_window = [&]() {
830  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
831  {
832  return make_tile_window(a_pad_view,
835  {coord_m, 0}); // NOTE!
836  }
837  else
838  {
839  return make_tile_window(a_pad_view,
842  {0, 0}); // NOTE!
843  }
844  }();
845 
846  constexpr bool isNonInterleaveGateUp = !IsGateUp || BMXFP4_Pipeline;
847 
848  const auto& b_flat_block_window =
849  make_tile_window(b_flat_pad_view,
852  {static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
853  (isNonInterleaveGateUp ? 1 : 2)),
854  0});
855 
856  const int output_N_offset = IsGateUp ? coord_n / 2 : coord_n;
857 
858  auto c_block_window = make_tile_window(
859  c_pad_view,
861  {0, // offset_m is included when construct C-scatter-window offsets
862  output_N_offset});
863 
864  constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline
865  auto a_scale_block_window = make_tile_window(
866  views.at(I3),
868  number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
869  {coord_m / M_Pack, 0});
870 
871  constexpr int XDLPerLoadScaleB =
872  BMXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4
873 
874  auto b_scale_block_window = [&]() {
875  if constexpr(MXF8F6F4MFMA)
876  {
877  return make_tile_window(
878  views.at(I4),
880  number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
881  {coord_n / N_Pack, 0});
882  }
883  else
884  {
885  return make_tile_window(
886  views.at(I4),
888  number<FlatmmPipeline::flatKPerWarp * N_Pack * K_Pack *
889  XDLPerLoadScaleB / GranularityK>{}),
890  {coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
891  }
892  }();
893 
894  return make_tuple(a_block_window,
895  b_flat_block_window,
896  c_block_window,
897  a_scale_block_window,
898  b_scale_block_window);
899  }
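For the gate-up (gemm1) variant the C window covers only half of the GEMM's N columns, since gate and up share one N range. A short worked example of the window offsets above, with assumed tile sizes (non-MXFP4 path):

// Assume NPerBlock = 128, BlockGemmShape::WarpTile N = 16, block column iN = 3:
//   coord_n            = 3 * 128 = 384
//   IsGateUp == true : output_N_offset = 384 / 2 = 192          // C stores act(gate) * up, N/2 wide
//                      b_flat window column = 384 / 16 / 2 = 12 // gate/up warp tiles interleaved
//   IsGateUp == false: output_N_offset = 384, b_flat window column = 384 / 16 = 24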
900 
901  template <class MoeFlatmmKernelArgs>
902  CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs kargs) const
903  {
904  int partition_idx = blockIdx.x;
905  int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
906  do
907  {
908  const auto [block_offset_m, block_offset_n] =
909  TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
910 
911  this->operator()(kargs, block_offset_m, block_offset_n);
912  partition_idx += gridDim.x;
913  } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
914  }
915 
916  template <class MoeFlatmmKernelArgs>
917  CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs kargs, index_t iM, index_t iN) const
918  {
919 
920  // const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
921  const index_t coord_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
922  const index_t coord_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
923  const index_t max_token_id = kargs.p_max_token_id[0];
924  // allocate LDS
925  __shared__ char smem_ptr_ping[GetSmemPingSize()];
926  __shared__ char smem_ptr_pong[GetSmemPongSize()];
927 
928  const index_t expert_id = kargs.p_sorted_expert_ids[iM];
929 
930  constexpr auto a_dram_dist = FlatmmPipeline::GetADramTileDistribution();
931  const auto a_coord = a_dram_dist.calculate_index(); // 2d thread offset, [i_row, i_col]
932 
933  constexpr ck_tile::index_t DramMRepeat =
934  decltype(a_dram_dist)::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}];
935  statically_indexed_array<index_t, DramMRepeat> a_offsets;
936 
937  constexpr index_t token_id_offset = 24;
938  constexpr index_t token_id_mask = (1 << token_id_offset) - 1;
939 
940  auto row_to_token_idx = [&](auto row_idx) {
941  const index_t fused_token =
942  kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]
943  index_t gather_token_id = fused_token & token_id_mask;
944  if constexpr(!IsInputGemm)
945  {
946  gather_token_id = gather_token_id * kargs.TopK + (fused_token >> token_id_offset);
947  }
948  return gather_token_id;
949  };
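// Worked example of the fused id decode above (values illustrative, TopK = 8):
//   p_sorted_token_ids[row] = 0x02000007      -> top-k slot 2, token 7
//   gather_token_id = 0x02000007 & 0x00FFFFFF = 7
//   gemm1 (IsInputGemm) : row 7 of A is gathered (the raw token activation)
//   gemm2 (!IsInputGemm): 7 * TopK + 2 = 58, i.e. the (token 7, slot 2) row of the
//                         intermediate tensor produced by gemm1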
950 
951  if(coord_m >= max_token_id)
952  return;
953  static_for<0, DramMRepeat, 1>{}([&](auto m0) {
954  const auto row_idx =
955  coord_m + m0 * (TilePartitioner::MPerBlock / DramMRepeat) + a_coord[I0];
956  index_t gather_token_id = row_to_token_idx(row_idx);
957  a_offsets[m0] = std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
958  ? gather_token_id * kargs.stride_A
959  : gather_token_id;
960  });
961 
962  const SplitKBatchOffset splitk_batch_offset(kargs);
963  const long_index_t expert_stride =
964  __builtin_amdgcn_readfirstlane(long_index_t(kargs.N) * kargs.K);
965 
966  const ADataType* a_ptr =
967  static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
968  const BDataType* b_flat_ptr =
969  static_cast<const BDataType*>(kargs.b_ptr) +
970  (splitk_batch_offset.b_k_split_offset + expert_stride * expert_id) / WeightPackedSize;
971  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
972 
973  const AccDataType* exp_weight_ptr =
974  static_cast<const AccDataType*>(kargs.p_sorted_expert_weights);
975 
976  const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(
977  a_ptr, b_flat_ptr, e_ptr, exp_weight_ptr, expert_id, kargs, splitk_batch_offset);
978  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
979 
980  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, coord_m, coord_n);
981 
982  const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
983 
984  // Run GEMM cooperatively by whole workgroup.
985  const auto& a_block_window = gemm_tile_windows.at(I0);
986  const auto& b_block_window = gemm_tile_windows.at(I1);
987  const auto& a_scale_block_window = gemm_tile_windows.at(I3);
988  const auto& b_scale_block_window = gemm_tile_windows.at(I4);
989 
990  auto a_gather_block_tile =
991  ck_tile::make_tile_scatter_gather(a_block_window.get_bottom_tensor_view(),
992  a_block_window.get_window_lengths(),
993  a_block_window.get_window_origin(),
994  a_dram_dist,
995  a_offsets); // K DRAM tile window for
996 
997  auto c_block_tile = [&] {
998  if constexpr(BMXFP4_Pipeline)
999  {
1000  // BMXFP4_Pipeline uses gate-up interleave 16 layout for weight
1001  // so don't need extra processing
1002  if constexpr(AQUANT_Pipeline)
1003  {
1004  return FlatmmPipeline{}(
1005  a_gather_block_tile,
1006  b_block_window,
1007  a_scale_block_window, // weight scale with granularityK = 32
1008  b_scale_block_window, // weight scale with granularityK = 32
1009  num_loop,
1010  smem_ptr_ping,
1011  smem_ptr_pong);
1012  }
1013  else
1014  {
1015  return FlatmmPipeline{}(
1016  a_gather_block_tile,
1017  b_block_window,
1018  b_scale_block_window, // weight scale with granularityK = 32
1019  num_loop,
1020  kargs.k_padded_zeros,
1021  smem_ptr_ping,
1022  smem_ptr_pong);
1023  }
1024  }
1025  else
1026  {
1027  return FlatmmPipeline{}(a_gather_block_tile,
1028  b_block_window,
1029  number<IsGateUp>{},
1030  num_loop,
1031  smem_ptr_ping,
1032  smem_ptr_pong);
1033  }
1034  }();
1035 
1036  auto& c_block_window = gemm_tile_windows.at(number<2>{});
1037 
1038  // Run EpiloguePipeline
1039  {
1040  using EpiProblem = typename EpiloguePipeline::Problem;
1041  using ODataType = typename EpiloguePipeline::ODataType;
1042  using CWarpDstr = typename EpiloguePipeline::CWarpDstr;
1043 
1044  constexpr index_t NumMXdlPerWavePerShuffle = EpiloguePipeline::NumMXdlPerWavePerShuffle;
1045  constexpr index_t NumNXdlPerWavePerShuffle = EpiloguePipeline::NumNXdlPerWavePerShuffle;
1046  constexpr index_t MPerIterationShuffle = EpiloguePipeline::MPerIterationShuffle;
1047  constexpr index_t NPerIterationShuffle = EpiloguePipeline::NPerIterationShuffle;
1048 
1049  constexpr index_t MRepeat = EpiloguePipeline::MRepeat;
1050  constexpr index_t NRepeat = EpiloguePipeline::NRepeat;
1051  constexpr index_t OutputNRepeat = IsGateUp ? NRepeat / 2 : NRepeat;
1052 
1053  [[maybe_unused]] constexpr index_t EpiVectorSizeC = EpiloguePipeline::GetVectorSizeC();
1054  [[maybe_unused]] constexpr index_t BlockedXDLN_PerWarp =
1055  EpiloguePipeline::BlockedXDLN_PerWarp;
1056 
1057  static_assert(!IsGateUp || NumNXdlPerWavePerShuffle % 2 == 0);
1058 
1059  constexpr index_t OutputNumNXdlPerWavePerShuffle =
1060  IsGateUp ? NumNXdlPerWavePerShuffle / 2 : NumNXdlPerWavePerShuffle;
1061  constexpr index_t LDS_NPerIterationShuffle =
1062  IsGateUp ? NPerIterationShuffle / 2 : NPerIterationShuffle;
1063 
1064  constexpr auto lds_block_desc = make_naive_tensor_descriptor(
1067 
1068  // EpiloguePipeline::template MakeLdsBlockDescriptor<EpiProblem>();
1069  auto o_lds_block = make_tensor_view<address_space_enum::lds>(
1070  reinterpret_cast<ODataType*>(smem_ptr_ping), lds_block_desc);
1071 
1072  constexpr int ScaleGranularityM = decltype(kargs.scale_m)::GranularityMN;
1073  constexpr int ScaleGranularityN = decltype(kargs.scale_n)::GranularityMN;
1074 
1075  constexpr index_t scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
1076  : 1; // per-token scale
1077  constexpr index_t scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
1078  : 1; // per-channel scale
1079 
1080  auto output_acc_tile_distr =
1083  sequence<>,
1088  sequence<0, 0>>{},
1089  typename CWarpDstr::DstrEncode{}));
1090 
1091  const auto scale_m_coord =
1092  output_acc_tile_distr.calculate_index(); // 2d thread offset, [i_row, i_col]
1093 
1094  constexpr index_t kM2 = 4; // Val-dim
1095  constexpr index_t kM1 = get_warp_size() / NPerXdl; // Thr-dim
1096  constexpr index_t kM0 = MPerXdl / kM1 / kM2; // Var-dim
1097 
1098  constexpr index_t ScaleMRepeat = MRepeat * kM0 * kM2;
1099  statically_indexed_array<index_t, ScaleMRepeat> scale_m_offsets;
1100 
1101  if constexpr(!BMXFP4_Pipeline)
1102  static_for<0, MRepeat, 1>{}([&](auto mIter) {
1103  static_for<0, kM0, 1>{}([&](auto m0) {
1104  static_for<0, kM2, 1>{}([&](auto m2) {
1105  const auto row_idx =
1106  coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0];
1107  scale_m_offsets[mIter * number<kM0 * kM2>{} + m0 * number<kM2>{} + m2] =
1108  row_to_token_idx(row_idx);
1109  });
1110  });
1111  });
1112 
1113  constexpr int DynamicTileOffsetFlag = 0;
1114 
1115  constexpr bool EnableBias = decltype(kargs.exp_bias)::GranularityMN != -1;
1116 
1117  auto permute_tensor_view = [&](auto naive_view, auto is_needed_to_permute_N_PACK) {
1118  if constexpr(!is_needed_to_permute_N_PACK)
1119  {
1120  return naive_view;
1121  }
1122  else
1123  {
1124  auto view1 = transform_tensor_view(
1125  naive_view,
1126  make_tuple(
1129  number<NRepeat / N_Pack>{},
1130  number<NWave>{},
1131  number<N_Pack>{},
1132  number<NPerXdl>{}))),
1135  return transform_tensor_view(
1136  view1,
1140  number<NRepeat / N_Pack>{},
1141  number<N_Pack>{},
1142  number<NWave>{},
1143  number<NPerXdl>{}))),
1146  }
1147  };
1148 
1149  auto scale_m_window =
1150  make_tile_scatter_gather(make_naive_tensor_view<address_space_enum::global>(
1151  kargs.scale_m.ptr,
1152  make_tuple(kargs.M, 1),
1153  make_tuple(scale_stride_m, 0),
1154  number<1>{}, // gather load can't vectorize
1155  number<1>{}),
1158  {0, 0}, // offset m is included in gather offsets
1159  output_acc_tile_distr,
1160  scale_m_offsets);
1161 
1162  auto scale_n_window = make_tile_window(
1163  make_naive_tensor_view<address_space_enum::global>(
1164  kargs.scale_n.ptr + expert_id * kargs.N,
1165  make_tuple(1, kargs.N),
1166  make_tuple(0, scale_stride_n),
1167  number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
1168  number<1>{}), // MXF4_Pipeline doesn't use scale_n, so there is no need to
1169  // permute as n_pack
1171  number < IsGateUp ? TilePartitioner::NPerBlock / 2
1172  : TilePartitioner::NPerBlock > {}),
1173  {0, IsGateUp ? coord_n / 2 : coord_n},
1174  output_acc_tile_distr);
1175 
1176  auto scale_n_up_window = make_tile_window(
1177  make_naive_tensor_view<address_space_enum::global>(
1178  kargs.scale_n.ptr + expert_id * kargs.N + kargs.N / 2,
1179  make_tuple(1, kargs.N),
1180  make_tuple(0, scale_stride_n),
1181  number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
1182  number<1>{}),
1184  number<TilePartitioner::NPerBlock / 2>{}),
1185  {0, coord_n / 2},
1186  output_acc_tile_distr);
1187 
1188  auto exp_bias_view = make_naive_tensor_view<address_space_enum::global>(
1189  kargs.exp_bias.ptr + expert_id * kargs.N,
1190  make_tuple(1, kargs.N),
1191  make_tuple(0, scale_stride_n),
1192  number<FlatmmPipeline::GetVectorSizeB()>{},
1193  number<1>{});
1194 
1195  auto exp_bias_window = make_tile_window(
1196  permute_tensor_view(exp_bias_view, number<(BMXFP4_Pipeline && !IsInputGemm)>{}),
1198  number < IsGateUp ? TilePartitioner::NPerBlock / 2
1199  : TilePartitioner::NPerBlock > {}),
1200  {0, IsGateUp ? coord_n / 2 : coord_n},
1201  output_acc_tile_distr);
1202 
1203  auto exp_bias_up_window =
1204  make_tile_window(make_naive_tensor_view<address_space_enum::global>(
1205  kargs.exp_bias.ptr + expert_id * kargs.N + kargs.N / 2,
1206  make_tuple(1, kargs.N),
1207  make_tuple(0, scale_stride_n),
1208  number<FlatmmPipeline::GetVectorSizeB()>{},
1209  number<1>{}),
1211  number<TilePartitioner::NPerBlock / 2>{}),
1212  {0, coord_n / 2},
1213  output_acc_tile_distr);
1214 
1215  auto exp_weight_window =
1216  make_tile_window(make_naive_tensor_view<address_space_enum::global>(
1217  static_cast<const float*>(kargs.p_sorted_expert_weights),
1218  make_tuple(kargs.M, 1),
1219  make_tuple(1, 0),
1220  number<FlatmmPipeline::GetVectorSizeA()>{},
1221  number<1>{}),
1224  {coord_m, 0},
1225  output_acc_tile_distr);
1226 
1227  using ScaleMBuffer = decltype(load_tile(scale_m_window));
1228  using ScaleNBuffer = decltype(load_tile(scale_n_window));
1229  using ExpBiasBuffer = decltype(load_tile(exp_bias_window));
1230  using ExpWeightBuffer = decltype(load_tile(exp_weight_window));
1231 
1232  ScaleMBuffer scale_m_buffer;
1233  ScaleNBuffer scale_n_buffer, scale_n_up_buffer;
1234 
1235  ExpBiasBuffer exp_bias_buffer, exp_bias_up_buffer;
1236  ExpWeightBuffer exp_weight_buffer;
1237 
1238  if constexpr(!BMXFP4_Pipeline)
1239  {
1240  scale_m_window.load(scale_m_buffer);
1241  scale_n_buffer = load_tile(scale_n_window);
1242  if constexpr(IsGateUp)
1243  scale_n_up_buffer = load_tile(scale_n_up_window);
1244  }
1245 
1246  if constexpr(EnableBias)
1247  {
1248  exp_bias_buffer = load_tile(exp_bias_window);
1249  if constexpr(IsGateUp)
1250  exp_bias_up_buffer = load_tile(exp_bias_up_window);
1251  }
1252  if constexpr(!IsInputGemm)
1253  exp_weight_buffer = load_tile(exp_weight_window);
1254 
1255  auto in_lds_window = make_tile_window(
1256  o_lds_block,
1258  {0, 0});
1259 
1260  auto out_lds_window = make_tile_window(
1261  o_lds_block,
1263  {0, 0});
1264 
1268 
1269  constexpr index_t num_access = SFC::get_num_of_access();
1270 
1271  static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
1272  "Currently, the CShuffle EpiloguePipeline only supports the Row Major "
1273  "Output layout");
1274 
1275  using TileEncodingPattern = tile_distribution_encoding_pattern_2d<
1276  kBlockSize,
1277  MPerIterationShuffle,
1278  LDS_NPerIterationShuffle,
1279  kind == MoeFlatmmKind::kFFN_gemm2 ? 2 : EpiloguePipeline::GetVectorSizeC(),
1280  tile_distribution_pattern::thread_raked,
1281  EpiProblem::kNumWaveGroups>;
1282 
1283  constexpr auto dram_tile_distribution =
1284  TileEncodingPattern::make_2d_static_tile_distribution();
1285 
1286  constexpr auto LdsTileDistr = [&] {
1287  if constexpr(IsGateUp)
1291  sequence<>,
1293  // merge two contiguous N
1298  sequence<0, 0>>{},
1299  typename CWarpDstr::DstrEncode{}));
1300  else
1302  EpiloguePipeline::MakeLdsDistributionEncode());
1303  }();
1304 
1305  using LDSTileTensor =
1306  decltype(make_static_distributed_tensor<AccDataType>(LdsTileDistr));
1307  LDSTileTensor lds_tile[2];
1308 
1309  constexpr auto c_warp_y_lengths =
1310  to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
1311  constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
1312  constexpr int ActVectorSize = c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle *
1313  OutputNumNXdlPerWavePerShuffle;
1314 
1315  auto epi_tile_idx_slice =
1316  [&](const auto& acc_tile_like_tensor, auto epi_m_idx, auto epi_n_idx) {
1317  return acc_tile_like_tensor.get_y_sliced_thread_data(
1318  merge_sequences(sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
1319  epi_n_idx * OutputNumNXdlPerWavePerShuffle>{},
1320  c_warp_y_index_zeros),
1323  c_warp_y_lengths));
1324  };
1325 
1326  auto gate_up_epi_tile_idx_interleave_slice = [&](auto& dest_gate_tensor,
1327  auto& dest_up_tensor,
1328  const auto& acc_tile_like_tensor,
1329  auto epi_m_idx,
1330  auto epi_n_idx) {
1332  dest_gate_tensor.set_y_sliced_thread_data(
1333  merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
1335  acc_tile_like_tensor.get_y_sliced_thread_data(
1337  sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
1338  epi_n_idx * NumNXdlPerWavePerShuffle + 2 * n_xdl>{},
1339  c_warp_y_index_zeros),
1341  c_warp_y_lengths)));
1342  dest_up_tensor.set_y_sliced_thread_data(
1343  merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
1345  acc_tile_like_tensor.get_y_sliced_thread_data(
1347  sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
1348  epi_n_idx * NumNXdlPerWavePerShuffle + 2 * n_xdl + 1>{},
1349  c_warp_y_index_zeros),
1351  c_warp_y_lengths)));
1352  });
1353  };
1354 
1355  auto process_epi_tile = [&](auto lds_stage, auto epi_m, auto epi_n) {
1356  if constexpr(IsGateUp)
1357  {
1358  LDSTileTensor gate_tensor, up_tensor;
1359 
1360  gate_up_epi_tile_idx_interleave_slice(
1361  gate_tensor, up_tensor, c_block_tile, epi_m, epi_n);
1362  auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, epi_m, epi_n);
1363  auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, epi_m, epi_n);
1364  auto epi_scale_n_up = epi_tile_idx_slice(scale_n_up_buffer, epi_m, epi_n);
1365 
1366  auto epi_exp_bias = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n);
1367  auto epi_exp_bias_up = epi_tile_idx_slice(exp_bias_up_buffer, epi_m, epi_n);
1368 
1369  static_for<0, ActVectorSize, 1>{}([&](auto idx) {
1370  if constexpr(!BMXFP4_Pipeline)
1371  {
1372  gate_tensor.get_thread_buffer()[idx] *=
1373  epi_scale_m[idx] * epi_scale_n[idx];
1374  up_tensor.get_thread_buffer()[idx] *=
1375  epi_scale_m[idx] * epi_scale_n_up[idx];
1376  }
1377  if constexpr(EnableBias)
1378  {
1379  gate_tensor.get_thread_buffer()[idx] += epi_exp_bias[idx];
1380  up_tensor.get_thread_buffer()[idx] += epi_exp_bias_up[idx];
1381  }
1382  lds_tile[lds_stage].get_thread_buffer().at(idx) =
1383  ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
1384  up_tensor.get_thread_buffer().at(idx));
1385  });
1386  }
1387  else
1388  {
1389  lds_tile[lds_stage].get_thread_buffer() =
1390  epi_tile_idx_slice(c_block_tile, epi_m, epi_n);
1391  auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, epi_m, epi_n);
1392  auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, epi_m, epi_n);
1393  auto epi_exp_weight = epi_tile_idx_slice(exp_weight_buffer, epi_m, epi_n);
1394  auto epi_exp_bias = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n);
1395 
1396  static_for<0, ActVectorSize, 1>{}([&](auto idx) {
1397  if constexpr(!BMXFP4_Pipeline)
1398  lds_tile[lds_stage].get_thread_buffer()[idx] *=
1399  epi_scale_m[idx] * epi_scale_n[idx];
1400  if(kind !=
1401  MoeFlatmmKind::kFFN_gemm1_split_k) // disable weight and bias for split-k
1402  {
1403  if constexpr(EnableBias)
1404  lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx];
1405  if constexpr(!IsInputGemm)
1406  lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx];
1407  }
1408  if constexpr(kind ==
1409  MoeFlatmmKind::kFFN_gemm1_gate_only) // for mlp1 gate-only
1410  lds_tile[lds_stage].get_thread_buffer()[idx] =
1411  ActivationOp{}(lds_tile[lds_stage].get_thread_buffer()[idx]);
1412  });
1413  }
1414  };
1415 
1416  constexpr int NumMEpiTile = MRepeat / NumMXdlPerWavePerShuffle;
1417  constexpr int MPerThread = TileEncodingPattern::Y2;
1418  statically_indexed_array<statically_indexed_array<index_t, MPerThread>, NumMEpiTile>
1419  c_scatter_offsets;
1420  statically_indexed_array<statically_indexed_array<bool, MPerThread>, NumMEpiTile>
1421  c_scatter_valids;
1422  auto c_coord = dram_tile_distribution.calculate_index();
1423  static_for<0, NumMEpiTile, 1>{}([&](auto mIter) {
1424  static_for<0, MPerThread, 1>{}([&](auto m0) {
1425  auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0;
1426  auto fused_token =
1427  kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]
1428 
1429  index_t scatter_token_id = fused_token & token_id_mask;
1430  c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
1431  if constexpr(IsInputGemm)
1432  scatter_token_id =
1433  scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
1434  c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
1435  });
1436  });
1437 
1438  //===----------------------------------------------------------------------===//
1439  // Pingpong process start
1440  //===----------------------------------------------------------------------===//
1441  process_epi_tile(number<0>{}, number<0>{}, number<0>{});
1442 
1443  static_for<0, num_access, 1>{}([&](auto iAccess) {
1444  constexpr int read_stage = iAccess % 2;
1445  constexpr int write_stage = read_stage ^ 1;
1446 
1447  block_sync_lds();
1448  constexpr auto idx_y_start = SFC::get_index(number<iAccess.value>{});
1449  constexpr auto mIter = number<idx_y_start.at(number<0>{}) / MPerIterationShuffle>{};
1450 
1451  const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile[read_stage]);
1452 
1453  store_tile(in_lds_window, c_warptile_in_tensor_casted);
1454 
1455  if constexpr(iAccess < num_access - 1)
1456  {
1457  constexpr auto idx_y_start_next = SFC::get_index(number<iAccess.value + 1>{});
1458  constexpr auto mIter_next =
1459  number<idx_y_start_next.at(number<0>{}) / MPerIterationShuffle>{};
1460  constexpr auto nIter_next =
1461  number<idx_y_start_next.at(number<1>{}) / NPerIterationShuffle>{};
1462 
1463  process_epi_tile(number<write_stage>{}, mIter_next, nIter_next);
1464  }
1465 
1466  block_sync_lds();
1467 
1468  auto c_out_tensor =
1469  load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
1470  auto c_scatter_tile_window =
1471  make_tile_scatter_gather(c_block_window.get_bottom_tensor_view(),
1472  c_block_window.get_window_lengths(),
1473  c_block_window.get_window_origin(),
1474  dram_tile_distribution,
1475  c_scatter_offsets[mIter],
1476  c_scatter_valids[mIter]);
1477 
1478  if constexpr(!IsInputGemm ||
1479  decltype(c_block_window.get_bottom_tensor_view())::DstInMemOp ==
1480  memory_operation_enum::atomic_add)
1481  c_scatter_tile_window.update(c_out_tensor);
1482  else
1483  c_scatter_tile_window.store(c_out_tensor);
1484 
1485  if constexpr(iAccess != num_access - 1)
1486  {
1487  constexpr auto step = SFC::get_forward_step(iAccess);
1488  // row_offset of out windows has been included in scatter offset
1489  move_tile_window(c_block_window,
1490  {0, step.at(number<1>{}) / number < IsGateUp ? 2 : 1 > {}});
1491  }
1492  });
1493  }
1494  }
1495 };
1496 
1497 } // namespace ck_tile
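A hedged sketch of how the kernel kinds above typically compose into one MoE FFN layer; the shape names (H for hidden size, I for intermediate size) and buffers are illustrative, not defined by this header.

// gemm1: kind = kFFN_gemm1_gate_up, A = [NumTokens, H] activations gathered per sorted row,
//        B = one [2*I, H] gate+up weight block per expert; the epilogue applies
//        ActivationOp(gate, up) and writes an intermediate [NumTokens * TopK, I] buffer.
// gemm2: kind = kFFN_gemm2, A = that intermediate buffer, B = one [H, I] block per expert;
//        each output row is scaled by its sorted routing weight and scattered back to
//        [NumTokens, H], accumulating over the TopK experts that served each token.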