Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp Source File

gridwise_moe_gemm.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
18 
19 #define DEBUG_LOG 0
20 
21 namespace ck {
22 
23 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
24 // pipelines in the same kernel function. Blockers:
25 // 1. Two separate declarations of the __shared__ pointer are the key to making sure data access
26 // operates on two distinct LDS chunks.
27 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
28 // not use the same LDS buffer when __shared__ is declared inside the block gemm pipeline.
29 
30 enum Activation
31 {
32  gelu_and_mul = 0,
33  silu_and_mul = 1
34 };
35 
36 template <typename GridwiseGemm,
37  bool HasMainKBlockLoop,
38  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
39  index_t MinimumOccupancy = 1,
40  TailNumber TailNum = TailNumber::Even>
41 __global__ void
42 #if CK_USE_LAUNCH_BOUNDS
43 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
44 #endif
45  // __attribute__((amdgpu_waves_per_eu(1, 1)))
46  kernel_moe_gemm(typename GridwiseGemm::Argument karg)
47 {
48 #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
49  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
50  {
51  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52 
53  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
54 
55  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
56  karg.p_sorted_token_ids,
57  karg.p_sorted_expert_ids,
58  karg.p_max_token_id,
59  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
60  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
61  karg.p_ds_grid,
62  karg.p_c_grid,
63  p_shared,
64  karg,
65  karg.a_element_op,
66  karg.b_element_op,
67  karg.c_element_op);
68  }
69 #else
70  ignore = karg;
71 #endif // end of #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
72 }
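For orientation, here is a hedged sketch of how a host might launch this kernel. It is not code from this file: MoeGemm stands for some concrete GridwiseGemm instantiation, karg for a populated MoeGemm::Argument, stream for a hipStream_t, and the block size must match that instantiation's BlockSize. The grid follows CalculateGridSize below, with blockIdx.z indexing the split-K batch consumed by SplitKBatchOffset.

// Illustrative host-side launch (all names here are assumptions, see above).
const auto [gdx, gdy, gdz] = MoeGemm::CalculateGridSize(karg.M, karg.N);
dim3 grid(gdx, gdy, karg.KBatch); // z dimension selects the split-K batch
dim3 block(256);                  // must equal the instantiation's BlockSize
kernel_moe_gemm<MoeGemm,
                /*HasMainKBlockLoop=*/true,
                InMemoryDataOperationEnum::Set>
    <<<grid, block, 0, stream>>>(karg);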
73 
74 template <typename GridwiseGemm,
75  bool HasMainKBlockLoop,
76  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
77  index_t MinimumOccupancy = 1,
78  TailNumber TailNum = TailNumber::Even>
79 __global__ void
80 #if CK_USE_LAUNCH_BOUNDS
81 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
82 #endif
83  // __attribute__((amdgpu_waves_per_eu(1, 1)))
84  kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
85 {
86 #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
87  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
88  {
89  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
90  __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
91 
92  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
93 
94  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
95  karg.p_sorted_token_ids,
96  karg.p_sorted_expert_ids,
97  karg.p_max_token_id,
98  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
99  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
100  karg.p_ds_grid,
101  karg.p_c_grid,
102  p_shared,
103  p_shared1,
104  karg,
105  karg.a_element_op,
106  karg.b_element_op,
107  karg.c_element_op);
108  }
109 #else
110  ignore = karg;
111 #endif // end of #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
112 }
113 
114 template <typename ALayout,
115  typename BLayout,
116  typename DsLayout,
117  typename CLayout,
118  typename ADataType,
119  typename BDataType,
120  typename AccDataType,
121  typename CShuffleDataType,
122  typename DsDataType,
123  typename CDataType,
124  typename AElementwiseOperation,
125  typename BElementwiseOperation,
126  typename CElementwiseOperation,
127  tensor_operation::device::GemmSpecialization GemmSpec,
128  index_t BlockSize,
129  index_t MPerBlock,
130  index_t NPerBlock,
131  index_t KPerBlock,
132  index_t AK1Value,
133  index_t BK1Value,
134  index_t MPerXdl,
135  index_t NPerXdl,
136  index_t MXdlPerWave,
137  index_t NXdlPerWave,
138  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
139  typename ABlockTransferThreadClusterArrangeOrder,
140  typename ABlockTransferSrcAccessOrder,
141  index_t ABlockTransferSrcVectorDim,
142  index_t ABlockTransferSrcScalarPerVector,
143  index_t ABlockTransferDstScalarPerVector_AK1,
144  bool AThreadTransferSrcResetCoordinateAfterRun,
145  index_t ABlockLdsExtraM,
146  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
147  typename BBlockTransferThreadClusterArrangeOrder,
148  typename BBlockTransferSrcAccessOrder,
149  index_t BBlockTransferSrcVectorDim,
150  index_t BBlockTransferSrcScalarPerVector,
151  index_t BBlockTransferDstScalarPerVector_BK1,
152  bool BThreadTransferSrcResetCoordinateAfterRun,
153  index_t BBlockLdsExtraN,
154  index_t CShuffleMXdlPerWavePerShuffle,
155  index_t CShuffleNXdlPerWavePerShuffle,
156  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
157  typename CDEShuffleBlockTransferScalarPerVectors,
160  index_t ActivationOperation = 0,
161  bool NSwizzle = false,
162  bool IsInputGemm = true,
163  bool MulRoutedWeight = true,
164  bool PerTokenQuant = false,
165  typename IndexType = index_t,
166  typename ComputeTypeA = CDataType,
167  typename ComputeTypeB = ComputeTypeA,
168  typename LDSTypeA = ADataType,
169  typename LDSTypeB = BDataType>
171 {
172  static constexpr auto I0 = Number<0>{};
173  static constexpr auto I1 = Number<1>{};
174  static constexpr auto I2 = Number<2>{};
175  static constexpr auto I3 = Number<3>{};
176  static constexpr auto I4 = Number<4>{};
177  static constexpr auto I5 = Number<5>{};
178  static constexpr auto I6 = Number<6>{};
179  static constexpr auto I7 = Number<7>{};
180 
181  static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
182  CDEShuffleBlockTransferScalarPerVectors{}[I0];
183  // K1 should be Number<...>
184  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
185  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
186  static constexpr auto AK1Number = Number<AK1Value>{};
187  static constexpr auto BK1Number = Number<BK1Value>{};
188  static constexpr auto BlockSizeNumber = Number<BlockSize>{};
189 
190  static constexpr index_t NumDTensor = DsDataType::Size();
191 
193  static constexpr index_t KPack =
195  static constexpr index_t KLane =
197 
198  static constexpr index_t KGroup = []() {
200  // On gfx950, we have an mfma that requires 32 f8 elements as input,
201  // split into 2 groups of 16 f8 elements.
202  // The 2 groups are not contiguous in the B preshuffled layout,
203  // and we do not want them to be contiguous in the B preshuffled layout,
204  // because a memory instruction can only read 16 f8 elements at a time.
205  return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
206  else
207  return 1;
208  }();
209 
210  static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
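As a concrete check of the KPack / KGroup / KRepeat arithmetic, the snippet below plugs in hypothetical numbers (KPerBlock = 64, KLane = 2, KPack = 32 with the gfx950 f8 case giving KGroup = 2); the values are assumptions for illustration, not the ones selected by this header.

// Illustrative only: hypothetical tile parameters, not taken from this file.
constexpr int KPerBlock_ex = 64;
constexpr int KLane_ex     = 2;
constexpr int KPack_ex     = 32; // mfma consuming 32 f8 elements per lane
constexpr int KGroup_ex    = 2;  // split into 2 groups of 16 f8 (gfx950 case)
// Number of K iterations each wave performs inside one K-block:
constexpr int KRepeat_ex   = KPerBlock_ex / KLane_ex / (KPack_ex / KGroup_ex);
static_assert(KRepeat_ex == 2, "64 / 2 / (32 / 2) == 2");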
211 
212  static constexpr index_t NLane = NPerXdl;
213  static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
214  // static constexpr index_t NumTokens = 1;
215  static constexpr index_t SortedTileSize = MPerBlock;
216 
217  static constexpr auto MakeDsGridPointer()
218  {
219  return generate_tuple(
220  [&](auto i) {
221  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
222 
223  return static_cast<const DDataType*>(nullptr);
224  },
226  }
227 
228  using DsGridPointer = decltype(MakeDsGridPointer());
229 
231 
232  static constexpr index_t APackedSize = []() {
233  if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
234  return 2;
235  else
236  return 1;
237  }();
238 
239  static constexpr index_t BPackedSize = []() {
240  if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
241  return 2;
242  else
243  return 1;
244  }();
245 
246  __host__ static auto CalculateGridSize(index_t M, index_t N)
247  {
248  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
249  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
250  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
251  const index_t gridy = NSwizzle ? 1 : mblock;
252 
253  return std::make_tuple(gridx, gridy, 1);
254  }
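For example, with hypothetical tile sizes MPerBlock = NPerBlock = 128 and a problem of M = 512, N = 1024, the helper yields a 2-D grid when NSwizzle is off and folds both tile counts into gridx when it is on:

// Hypothetical sizes, only to illustrate CalculateGridSize.
constexpr int MPerBlock_ex = 128, NPerBlock_ex = 128;
constexpr int M_ex = 512, N_ex = 1024;
constexpr int mblock = (M_ex + MPerBlock_ex - 1) / MPerBlock_ex; // 4
constexpr int nblock = (N_ex + NPerBlock_ex - 1) / NPerBlock_ex; // 8
// NSwizzle == false: grid = {nblock, mblock, 1} = {8, 4, 1}
// NSwizzle == true : grid = {nblock * mblock, 1, 1} = {32, 1, 1}
static_assert(mblock == 4 && nblock == 8, "integer_divide_ceil done by hand");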
255 
256  __host__ __device__ static auto CalculateMPadded(index_t M)
257  {
258  return math::integer_least_multiple(M, MPerBlock);
259  }
260 
261  __host__ __device__ static auto CalculateNPadded(index_t N)
262  {
263  return math::integer_least_multiple(N, NPerBlock);
264  }
265 
266  __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
267  {
268  return math::integer_divide_ceil(N, NLane);
269  }
270  __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
271  {
273  }
274 
275  __host__ __device__ static auto CalculateKPadded(index_t K)
276  {
277  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
278  }
279 
280  __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
281  {
282  auto K_t = K_Batch * KPerBlock;
283  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
284  }
285 
286  __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
287  {
288  auto K_t = K_Batch * KPerBlock;
289  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
290  }
291 
292  __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
293  {
294  auto K_t = K_Batch * KPerBlock;
295  return (K + K_t - 1) / K_t * KPerBlock;
296  }
297 
298  __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
299  {
300  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
301  auto K_t = K_Batch * KReadVec;
302  return (K + K_t - 1) / K_t * KReadVec;
303  }
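The K helpers above all use the same ceil-divide pattern. A hand-worked instance, with assumed KPerBlock = 64, AK1 = BK1 = 8, K = 200 and K_Batch = 2 (illustrative numbers only):

// Hypothetical values, spelling out the ceil-divide pattern used above.
constexpr int KPerBlock_ex = 64, AK1_ex = 8, BK1_ex = 8;
constexpr int K_ex = 200, KBatch_ex = 2;
// CalculateKPadded(K, K_Batch): K covered by one split-K batch, rounded to KPerBlock.
constexpr int K_t     = KBatch_ex * KPerBlock_ex;                     // 128
constexpr int KPadded = (K_ex + K_t - 1) / K_t * KPerBlock_ex;        // ceil(200/128) * 64 = 128
// CalculateAK0Padded: the same quantity expressed in AK1-wide K0 slices.
constexpr int AK0 = (K_ex + K_t - 1) / K_t * (KPerBlock_ex / AK1_ex); // 2 * 8 = 16
// CalculateKRead: K read per batch, rounded to the lcm(AK1, BK1) vector width.
constexpr int KReadVec = 8; // lcm(AK1_ex, BK1_ex)
constexpr int KRead =
    (K_ex + KBatch_ex * KReadVec - 1) / (KBatch_ex * KReadVec) * KReadVec; // 13 * 8 = 104
static_assert(KPadded == 128 && AK0 == 16 && KRead == 104, "");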
304 
305  __host__ __device__ static auto CalculateMBlock(index_t M)
306  {
307  return math::integer_divide_ceil(M, MPerBlock);
308  }
309 
310  __host__ __device__ static auto CalculateNBlock(index_t N)
311  {
312  return math::integer_divide_ceil(N, NPerBlock);
313  }
314 
315  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
316  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
317  {
318  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
319  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
320 
322  TileDesc_K0_MN_K1{},
328  }
329 
330  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
331  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
332  {
333  const auto a_grid_desc_mraw_kraw = [&]() {
334  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
335  {
336  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
337  }
338  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
339  {
340  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
341  }
342  }();
343 
345 
346  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
347  GemmSpec == GemmSpecialization::MNKPadding)
348  {
349  // pad both M and K
350  const auto a_grid_desc_m_k =
351  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
353  make_right_pad_transform(K, KPad - K)),
356 
357  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
358  a_grid_desc_m_k,
363 
364  return a_grid_desc_ak0_m_ak1;
365  }
366  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
367  GemmSpec == GemmSpecialization::MNPadding)
368  {
369  // pad M, but not K
370  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
371  a_grid_desc_mraw_kraw,
373  make_right_pad_transform(M, MPad - M)),
376 
377  return a_grid_desc_ak0_m_ak1;
378  }
379  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
380  GemmSpec == GemmSpecialization::NKPadding)
381  {
382  // pad K, but not M
383  const auto a_grid_desc_m_k = transform_tensor_descriptor(
384  a_grid_desc_mraw_kraw,
388 
389  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
390  a_grid_desc_m_k,
395 
396  return a_grid_desc_ak0_m_ak1;
397  }
398  else
399  {
400  // not pad M or K
401  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
402  a_grid_desc_mraw_kraw,
407 
408  return a_grid_desc_ak0_m_ak1;
409  }
410  }
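Several of the transform calls above lost lines in this listing; for reference, the "pad both M and K" path follows the usual two-step CK pattern (pad the raw M x K view, then unmerge K into AK0 x AK1). The sketch below is a hedged reconstruction of that pattern under the standard CK transform helpers, not a verbatim copy of the elided lines:

// Sketch of the MKPadding branch, assuming the standard CK transform signatures.
const auto a_grid_desc_m_k = transform_tensor_descriptor(
    a_grid_desc_mraw_kraw,
    make_tuple(make_right_pad_transform(M, MPad - M),
               make_right_pad_transform(K, KPad - K)),
    make_tuple(Sequence<0>{}, Sequence<1>{}),
    make_tuple(Sequence<0>{}, Sequence<1>{}));

const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
    a_grid_desc_m_k,
    make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Number)),
               make_pass_through_transform(MPad)),
    make_tuple(Sequence<1>{}, Sequence<0>{}),
    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));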
411 
412  __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
413  {
414  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
415  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
416  constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack / KGroup>{};
418  make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
419  make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
420  }
421 
422  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
423  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
424  {
425  const auto b_grid_desc_nraw_kraw = [&]() {
427  {
428  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
429  }
431  {
432  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
433  }
434  }();
435 
437 
438  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
439  GemmSpec != GemmSpecialization::Default),
440  "pk_i4_t does not support padding");
441 
442  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
443  GemmSpec == GemmSpecialization::MNKPadding)
444  {
445  // pad both N and K
446  const auto b_grid_desc_n_k =
447  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
449  make_right_pad_transform(K, KPad - K)),
452 
453  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
454  b_grid_desc_n_k,
459 
460  return b_grid_desc_bk0_n_bk1;
461  }
462  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
463  GemmSpec == GemmSpecialization::MNPadding)
464  {
465  // pad N, but not K
466  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
467  b_grid_desc_nraw_kraw,
469  make_right_pad_transform(N, NPad - N)),
472 
473  return b_grid_desc_bk0_n_bk1;
474  }
475  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
476  GemmSpec == GemmSpecialization::MKPadding)
477  {
478  // pad K, but not N
479  const auto b_grid_desc_n_k = transform_tensor_descriptor(
480  b_grid_desc_nraw_kraw,
484 
485  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
486  b_grid_desc_n_k,
491 
492  return b_grid_desc_bk0_n_bk1;
493  }
494  else
495  {
496  // not pad N or K
497  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
498  b_grid_desc_nraw_kraw,
503 
504  return b_grid_desc_bk0_n_bk1;
505  }
506  }
507 
508  template <typename ABlockDesc_AK0_M_AK1>
509  __host__ __device__ static constexpr auto
510  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
511  {
512  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
513 
514  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
515  }
516 
517  template <typename BBlockDesc_BK0_N_BK1>
518  __host__ __device__ static constexpr auto
519  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
520  {
521  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
522  }
523 
524  template <typename ELayout>
525  __host__ __device__ static auto MakeCGridDescriptor_M_N(
526  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
527  {
528  const auto c_grid_desc_mraw_nraw = [&]() {
530  {
531  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
532  }
534  {
535  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
536  }
537  }();
538 
539  // pad M and N
540  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
542  make_right_pad_transform(N, NPad - N)),
545  }
546 
547  template <typename DLayout>
548  __host__ __device__ static auto
550  {
551  const auto c_grid_desc_mraw_nraw = [&]() {
553  {
554  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
555  }
557  {
558  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
559  }
560  }();
561 
562  // pad M and N
563  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
565  make_right_pad_transform(N, NPad - N)),
568  }
569 
570  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
571  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
572  {
573  return generate_tuple(
574  [&](auto i) {
575  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
576  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
577  },
579  }
580 
581  template <typename DsGridDesc>
583  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
584  {
585  return generate_tuple(
586  [&](auto i) {
588  ds_grid_desc_m_n[i], MBlock, NBlock);
589  },
591  }
592 
593  struct Problem
594  {
595  __host__ __device__ Problem(index_t NumTokens_,
596  index_t TopK_,
597  index_t M_,
598  index_t N_,
599  index_t K_,
600  index_t StrideA_,
601  index_t StrideB_,
602  std::array<index_t, NumDTensor> StrideDs_,
603  index_t StrideC_,
604  index_t KBatch_)
605  : NumTokens{NumTokens_},
606  TopK{TopK_},
607  M{M_},
608  N{N_},
609  K{K_},
610  StrideA{StrideA_},
611  StrideB{StrideB_},
612  StrideDs{StrideDs_},
613  StrideC{StrideC_},
614  KBatch{KBatch_},
617  KRead{CalculateKRead(K_, KBatch_)},
618  KPadded{CalculateKPadded(K_, KBatch_)},
619  AK0{CalculateAK0Padded(K_, KBatch_)},
620  BK0{CalculateBK0Padded(K_, KBatch_)},
621  MBlock{CalculateMBlock(M_)},
623  {
624  }
625 
626  __host__ void Print() const
627  {
628  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
629  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
630  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
631  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
632  << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
633  << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
634  << "NBlock: " << NBlock << "}" << std::endl;
635  }
636 
644  std::array<index_t, NumDTensor> StrideDs;
655  };
656 
657  // Argument
658  struct Argument : public Problem
659  {
660  __host__ Argument(const index_t* p_sorted_token_ids_,
661  const index_t* p_sorted_expert_ids_,
662  const index_t* p_max_token_id_,
663  const ADataType* p_a_grid_,
664  const BDataType* p_b_grid_,
665  std::array<const void*, NumDTensor> p_ds_grid_,
666  CDataType* p_c_grid_,
667  index_t NumTokens_,
668  index_t TopK_,
669  index_t M_,
670  index_t N_,
671  index_t K_,
672  index_t StrideA_,
673  index_t StrideB_,
674  std::array<index_t, NumDTensor> StrideDs_,
675  index_t StrideC_,
676  index_t k_batch_,
677  AElementwiseOperation a_element_op_,
678  BElementwiseOperation b_element_op_,
679  CElementwiseOperation c_element_op_)
680  : Problem{NumTokens_,
681  TopK_,
682  M_,
683  N_,
684  K_,
685  StrideA_,
686  StrideB_,
687  StrideDs_,
688  StrideC_,
689  k_batch_},
690  p_sorted_token_ids{p_sorted_token_ids_},
691  p_sorted_expert_ids{p_sorted_expert_ids_},
692  p_max_token_id{p_max_token_id_},
693  p_a_grid{p_a_grid_},
694  p_b_grid{p_b_grid_},
695  p_ds_grid{},
696  p_c_grid{p_c_grid_},
697  a_element_op{a_element_op_},
698  b_element_op{b_element_op_},
699  c_element_op{c_element_op_}
700  {
701 
702  // populate pointer, desc for Ds
703  static_for<0, NumDTensor, 1>{}([&](auto i) {
704  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
705 
706  // D pointer
707  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
708  });
709  }
710 
711  const index_t* p_sorted_token_ids;
712  const index_t* p_sorted_expert_ids;
713  const index_t* p_max_token_id;
714  const ADataType* p_a_grid;
715  const BDataType* p_b_grid;
716  DsGridPointer p_ds_grid;
717  CDataType* p_c_grid;
718 
719  const AElementwiseOperation a_element_op;
720  const BElementwiseOperation b_element_op;
721  const CElementwiseOperation c_element_op;
722  };
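To make the argument flow concrete, a host could populate Argument roughly as below. This is a hedged sketch: Gemm is assumed to be a concrete instantiation of this gridwise GEMM with three D tensors and PassThrough element-wise ops, and every d_* pointer, size and stride is a placeholder owned by a surrounding MoE runner, not something prescribed by this header.

// Hypothetical host-side setup; all d_* buffers are assumed device allocations.
typename Gemm::Argument karg(
    d_sorted_token_ids,  // index_t*: token ids grouped by expert, padded to MPerBlock tiles
    d_sorted_expert_ids, // index_t*: one expert id per sorted M-tile
    d_max_token_id,      // index_t*: [0] holds the number of valid sorted rows
    d_a,                 // activations (ADataType*)
    d_b,                 // preshuffled expert weights (BDataType*)
    std::array<const void*, 3>{d_scale_a, d_scale_b, d_topk_weights}, // Ds, see Run()
    d_c,                 // output (CDataType*)
    NumTokens, TopK, M, N, K,
    StrideA, StrideB, StrideDs, StrideC,
    /*k_batch=*/1,
    ck::tensor_operation::element_wise::PassThrough{},
    ck::tensor_operation::element_wise::PassThrough{},
    ck::tensor_operation::element_wise::PassThrough{});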
723 
724  struct SplitKBatchOffset
725  {
726  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
727  {
728  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
729  {
730  a_k_split_offset = k_id * karg.KRead / APackedSize;
731  }
732  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
733  {
734  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
735  }
736 
737  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
738  {
739  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
740  }
741  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
742  {
743  // KPack * NLane * KLane * K0 * N0
744  b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
745  }
746 
747  if(k_id < karg.KBatch - 1)
748  {
749  karg.K = karg.KRead;
750  }
751  else
752  {
753  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
754  }
755  }
756 
757  index_t a_k_split_offset;
758  index_t b_k_split_offset;
759  };
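A short walk-through of the split-K bookkeeping above, reusing the K = 200, KBatch = 2, KRead = 104 numbers from the padding example and assuming row-major A with APackedSize == 1 (illustrative only):

// Hypothetical split-K example mirroring SplitKBatchOffset for row-major A.
constexpr int K_ex = 200, KBatch_ex = 2, KRead_ex = 104;
for(int k_id = 0; k_id < KBatch_ex; ++k_id)
{
    const int a_k_split_offset = k_id * KRead_ex; // element offset along A's K axis
    const int k_this_batch =
        (k_id < KBatch_ex - 1) ? KRead_ex                           // leading batches: full KRead
                               : K_ex - KRead_ex * (KBatch_ex - 1); // last batch: the remainder
    // k_id == 0 -> offset 0,   K = 104
    // k_id == 1 -> offset 104, K = 96
    (void)a_k_split_offset;
    (void)k_this_batch;
}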
760 
761  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
762  {
763  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
764  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
765 
766  // A matrix in LDS memory, dst of blockwise copy
767  if constexpr(ABlockLdsExtraM)
768  {
772  }
773  // The xor tensor transformation requires extra, unnecessary vgpr usage and can cause register
774  // spills in some cases.
776  {
777  constexpr auto a_lds_block_desc =
780 
781  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
782  a_lds_block_desc,
788 
789  return a_lds_block_desc_permuted;
790  }
791  else // ColumnMajor A
792  {
793  // The kfold and mpair dimensions are not always required.
794  // More dimensions in merge_transform increase the difficulty of generating immarg offsets
795  // for the compiler.
796  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
797  constexpr auto M1 = MPerBlock / M0;
798 
799  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
800  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
801  constexpr auto KThreadRead = WaveSize / MPerXdl;
802  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
803 
804  constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
805  ? 1
806  : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
807  constexpr auto KThreadReadPerm =
808  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
809  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
810  : KThreadRead;
811 
812  // 1<=mpair<=n0
813  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
814  ? 1
815  : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
816  ? M0
817  : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
818 
819  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
823  Number<kfold * M0 / mpair>{},
824  Number<mpair>{},
825  AK1Number));
826 
827  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
828  a_lds_block_desc,
829  make_tuple(
833  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
836  make_tuple(
838  make_tuple(
840 
841  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
842  a_lds_block_desc_permuted,
843  make_tuple(
851  Sequence<1>{},
852  Sequence<2>{},
853  Sequence<3>{},
854  Sequence<4>{},
855  Sequence<5>{}),
857  Sequence<2>{},
858  Sequence<0, 3>{},
859  Sequence<4, 5>{},
860  Sequence<6>{},
861  Sequence<7>{}));
862 
863  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
864  a_lds_block_desc_unmerged,
867  Number<KThreadWrite / kfold / KThreadReadPerm>{},
868  Number<kfold>{},
875 
876  return a_lds_block_desc_ak0_m_ak1;
877  }
878  }
879 
880  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
881  {
882  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
885  }
886 
888  {
889  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
890 
891  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
893  make_tuple(I1,
895  I1,
897 
898  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
899  }
900 
903  BlkGemmPipelineVer,
904  BlkGemmPipeSched,
905  BlockSize,
906  ADataType,
907  BDataType,
908  ComputeTypeA,
909  AccDataType,
916  ABlockTransferSrcScalarPerVector,
917  BBlockTransferSrcScalarPerVector,
918  MPerBlock,
919  NPerBlock,
920  KPerBlock,
921  MPerXdl,
922  NPerXdl,
923  MXdlPerWave,
924  NXdlPerWave,
925  KPack,
926  IsInputGemm>())>;
927 
928  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
929  {
930  // LDS allocation for A and B: be careful of alignment
931  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
932  // lds max alignment
933  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
934 
935  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
936  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
937 
938  // LDS allocation for C shuffle in LDS
939  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
941 
942  constexpr auto c_block_size =
943  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
944 
945  return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize,
946  c_block_size * sizeof(CShuffleDataType));
947  }
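Because the A tile and the C shuffle tile reuse the same LDS allocation, the byte count is the larger of the two, each scaled by its element size (with A additionally divided by its packed size). A self-contained numeric illustration, with assumed shapes and types:

// Hypothetical LDS budget showing the max(A tile, C shuffle tile) rule above.
constexpr int a_tile_elems    = 128 * 64;            // MPerBlock x KPerBlock (assumed)
constexpr int a_tile_bytes    = a_tile_elems * 2;    // LDSTypeA = fp16, APackedSize = 1
constexpr int c_shuffle_elems = 32 * 128;            // one C-shuffle slab (assumed shape)
constexpr int c_shuffle_bytes = c_shuffle_elems * 4; // CShuffleDataType = fp32
constexpr int lds_bytes =
    a_tile_bytes > c_shuffle_bytes ? a_tile_bytes : c_shuffle_bytes;
static_assert(lds_bytes == 16384, "16 KiB in this assumed configuration");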
948 
950 
951  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
952  __host__ static constexpr bool CheckValidity(const Argument& karg)
953  {
954  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
955  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
956  "Invalid tuning param!");
957 
958  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
963  {
964  if(!(karg.M % MPerBlock == 0))
965  {
966 #if DEBUG_LOG
967  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
968  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
969  << std::endl;
970 
971 #endif // DEBUG_LOG
972  return false;
973  }
974  }
975 
976  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
981  {
982  if(!(karg.N % NPerBlock == 0))
983  {
984 #if DEBUG_LOG
985  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
986  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
987  << std::endl;
988 
989 #endif // DEBUG_LOG
990  return false;
991  }
992  }
993 
994  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
998  {
999 
1000  auto K_t = karg.KBatch * KPerBlock;
1001  if(!(karg.K % K_t == 0))
1002  {
1003 #if DEBUG_LOG
1004  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1005  << karg.K << " " << __FILE__ << ":" << __LINE__
1006  << ", in function: " << __func__ << std::endl;
1007 
1008 #endif // DEBUG_LOG
1009  return false;
1010  }
1011  }
1012  else
1013  {
1014  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1015  auto K_t = karg.KBatch * KReadVec;
1016  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1017  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1018  {
1019  return false;
1020  }
1021  }
1022 
1024  {
1025  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1026  {
1027 #if DEBUG_LOG
1028  std::cout << "Arg K (" << karg.K
1029  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1030  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1031  << __LINE__ << ", in function: " << __func__ << std::endl;
1032 
1033 #endif // DEBUG_LOG
1034  return false;
1035  }
1036  }
1037  else
1038  {
1039  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1040  {
1041 #if DEBUG_LOG
1042  std::cout << "Arg M (" << karg.M
1043  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1044  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1045  << __LINE__ << ", in function: " << __func__ << std::endl;
1046 
1047 #endif // DEBUG_LOG
1048  return false;
1049  }
1050  }
1051 
1053  {
1054  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1055  {
1056 #if DEBUG_LOG
1057  std::cout << "Arg N (" << karg.N
1058  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1059  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1060  << __LINE__ << ", in function: " << __func__ << std::endl;
1061 
1062 #endif // DEBUG_LOG
1063  return false;
1064  }
1065  }
1066  else
1067  {
1068  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1069  {
1070 #if DEBUG_LOG
1071  std::cout << "Arg K (" << karg.K
1072  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1073  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1074  << __LINE__ << ", in function: " << __func__ << std::endl;
1075 
1076 #endif // DEBUG_LOG
1077  return false;
1078  }
1079  }
1080 
1082  {
1084  {
1085 #if DEBUG_LOG
1086  std::cout << "Arg N (" << karg.N
1087  << ") value is not a multiple of "
1088  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1089  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1090  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1091 
1092 #endif // DEBUG_LOG
1093  return false;
1094  }
1095  }
1096  else
1097  {
1099  {
1100 #if DEBUG_LOG
1101  std::cout << "Arg M (" << karg.M
1102  << ") value is not a multiple of "
1103  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1104  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1105  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1106 
1107 #endif // DEBUG_LOG
1108  return false;
1109  }
1110  }
1111 
1112  // check gridwise gemm pipeline
1113 #if 0
1114  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1115 
1116  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1117  {
1118  return false;
1119  }
1120 #endif
1121  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1122  return true;
1123  }
1124 
1125  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1126  {
1127  const index_t num_loop = K / KPerBlock;
1128 
1129  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1130  }
1131 
1132  __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1133  {
1134  const index_t num_loop = K / KPerBlock;
1135 
1136  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1137  }
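These helpers drive the host-side kernel selection: after CheckValidity passes, the host derives HasMainKBlockLoop and the TailNumber and instantiates kernel_moe_gemm (or kernel_moe_gemm_2lds) accordingly. A hedged sketch, where Gemm is this gridwise GEMM type, karg its Argument, and launch() a hypothetical wrapper around the HIP kernel launch; only two tail cases are shown:

// Hedged dispatch sketch; real device ops cover more TailNumber cases.
if(!Gemm::CheckValidity(karg))
    return; // configuration not supported by this instance

const bool has_main_loop = Gemm::CalculateHasMainKBlockLoop(karg.K);
const auto tail          = Gemm::CalculateKBlockLoopTailNum(karg.K);

if(has_main_loop && tail == TailNumber::Even)
    launch(kernel_moe_gemm<Gemm, true, InMemoryDataOperationEnum::Set, 1, TailNumber::Even>);
else if(has_main_loop && tail == TailNumber::Odd)
    launch(kernel_moe_gemm<Gemm, true, InMemoryDataOperationEnum::Set, 1, TailNumber::Odd>);
else
    launch(kernel_moe_gemm<Gemm, false, InMemoryDataOperationEnum::Set>);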
1138 
1139  template <typename CGridDesc>
1141  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1142  {
1143  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1144  c_grid_desc_m_n,
1149 
1150  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1151  }
1152 
1153  // return block_id to C matrix tile idx (m0, n0) mapping
1154  // if arch = gfx942
1155  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1156  // NPerBlock>;
1157 
1158  template <bool HasMainKBlockLoop,
1159  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1160  TailNumber TailNum = TailNumber::Odd>
1161  __device__ static void Run(const index_t* p_sorted_token_ids,
1162  const index_t* p_sorted_expert_ids,
1163  const index_t* p_max_token_id,
1164  const ADataType* p_a_grid,
1165  const BDataType* p_b_grid,
1166  DsGridPointer& p_ds_grid,
1167  CDataType* p_c_grid,
1168  void* p_shared,
1169  const Problem& problem,
1170  AElementwiseOperation a_element_op,
1171  BElementwiseOperation b_element_op,
1172  CElementwiseOperation c_element_op)
1173  {
1174  ignore = b_element_op;
1175  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1176  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1177  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1178  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1179  problem.MPadded,
1180  problem.K,
1181  problem.KPadded,
1182  problem.StrideA,
1183  problem.AK0);
1184  const auto b_grid_desc_bpreshuffled =
1185  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1186  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1187  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1188  problem.MPadded,
1189  problem.N,
1190  problem.NPadded,
1191  problem.StrideC);
1192  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1194  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1195  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1196  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1197  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1198  if(expert_block_id * MPerBlock >= max_token_id)
1199  return;
1200  const index_t expert_id =
1201  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1202  const auto block_mn = [&]() -> std::pair<int, int> {
1203  if constexpr(NSwizzle)
1204  {
1205  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1206  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1207  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1208  const index_t expert_swizzle =
1209  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1210  const index_t bid_new = blockIdx.x - prefix_block;
1211  const index_t nid = __builtin_amdgcn_readfirstlane(
1212  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1213  const index_t mid =
1214  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1215  return {nid, mid};
1216  }
1217  else
1218  {
1219  return {blockIdx.x, blockIdx.y};
1220  }
1221  }();
1222 
1223  const index_t block_n_id = block_mn.first;
1224  const index_t block_m_id = block_mn.second;
1225  const index_t token0 =
1226  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1227 
1228  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1229  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1230  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1231  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1232  constexpr auto AKThreads = AK0Threads * AK1Threads;
1233  constexpr auto AMRepeats = MPerBlock / AMThreads;
1234  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1235 
1236  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1237  return;
1239  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1240  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1241  index_t token_offset = fused_token & 0xffffff;
1242  if constexpr(!IsInputGemm)
1243  {
1244  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1245  }
1246  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1247  });
1248  const IndexType expert_stride =
1249  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1250  const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
1251  // N0, K0, Blocksize*KPack
1252  const index_t n_block_data_idx_on_grid =
1253  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1254 
1255  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1256  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1257  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1258  p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1259  // A matrix in LDS memory, dst of blockwise copy
1260  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1261 
1262  // B matrix in LDS memory, dst of blockwise copy
1263  // dummy
1264  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1265  // A matrix blockwise copy
1266  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1268  AElementwiseOperation,
1272  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1273  ABlockTransferThreadClusterArrangeOrder,
1274  ADataType,
1275  LDSTypeA,
1276  decltype(a_grid_desc_ak0_m_ak1),
1277  decltype(a_block_desc_ak0_m_ak1),
1278  ABlockTransferSrcAccessOrder,
1280  ABlockTransferSrcVectorDim,
1281  2,
1282  ABlockTransferSrcScalarPerVector,
1283  ABlockTransferDstScalarPerVector_AK1,
1284  1,
1285  1,
1286  AThreadTransferSrcResetCoordinateAfterRun,
1287  true,
1288  IndexType,
1289  1,
1290  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1291  make_multi_index(0, 0, 0),
1292  a_element_op,
1293  a_block_desc_ak0_m_ak1,
1294  make_multi_index(0, 0, 0),
1296  gather_offsets);
1297 
1298  // Thread-wise copy
1299  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1300  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1301  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1302 
1303  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1304  BDataType,
1305  BDataType,
1306  decltype(b_grid_desc_bpreshuffled),
1307  decltype(b_block_desc_bk0_n_bk1),
1310  3,
1311  BBlockTransferSrcScalarPerVector,
1312  BThreadTransferSrcResetCoordinateAfterRun,
1313  true>(b_grid_desc_bpreshuffled,
1314  make_multi_index(n_block_data_idx_on_grid,
1316  0,
1317  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1318 
1319  // LDS allocation for A and B: be careful of alignment
1320  // Cast after lds
1321  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1322  static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1323 
1324  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1325  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1326 
1327  // Blockwise GEMM pipeline
1328  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1329  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1330  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1331  decltype(c_thread_buf) c_thread_buf_up;
1332 
1334  float,
1335  c_thread_buf.num_of_v_,
1336  c_thread_buf.s_per_v,
1337  true>
1338  c_thread_buf_fp32;
1339 
1340  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1341  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1342  KPerBlock);
1343  if constexpr(IsInputGemm)
1344  {
1345  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1346  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1347  p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1348  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1349  BDataType,
1350  BDataType,
1351  decltype(b_grid_desc_bpreshuffled),
1352  decltype(b_block_desc_bk0_n_bk1),
1355  3,
1356  BBlockTransferSrcScalarPerVector,
1357  BThreadTransferSrcResetCoordinateAfterRun,
1358  true>(b_grid_desc_bpreshuffled,
1359  make_multi_index(n_block_data_idx_on_grid,
1361  0,
1362  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1363 
1364  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1365  a_grid_desc_ak0_m_ak1,
1366  a_block_desc_ak0_m_ak1,
1367  a_blockwise_copy,
1368  a_grid_buf,
1369  a_block_buf,
1370  a_block_slice_copy_step,
1371  b_grid_desc_bpreshuffled,
1372  b_blockwise_copy,
1373  b_blockwise_copy_up,
1374  b_grid_buf,
1375  b_grid_buf_up,
1376  b_block_buf,
1377  b_block_slice_copy_step,
1378  c_thread_buf,
1379  c_thread_buf_up,
1380  num_k_block_main_loop);
1381  }
1382  else
1383  {
1384  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1385  a_grid_desc_ak0_m_ak1,
1386  a_block_desc_ak0_m_ak1,
1387  a_blockwise_copy,
1388  a_grid_buf,
1389  a_block_buf,
1390  a_block_slice_copy_step,
1391  b_grid_desc_bpreshuffled,
1392  b_blockwise_copy,
1393  b_grid_buf,
1394  b_block_buf,
1395  b_block_slice_copy_step,
1396  c_thread_buf,
1397  num_k_block_main_loop);
1398  }
1399 
1400  // shuffle C and write out
1401  {
1402  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1403  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1404  "wrong!");
1405 
1406  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1407 
1408  // TODO: hacky, fix it!
1409  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1410  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1411 
1412  // TODO: hacky, fix it!
1413  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1414  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1415  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1416 
1417  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1418  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1419  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1420  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1421  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1422  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1423  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1424  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1425 
1426  // mul scales
1427  const float* p_sorted_weights_0 = p_ds_grid[I0];
1428  const float* p_scale_b = p_ds_grid[I1];
1429 
1430  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
1431  static_assert(M4 == 4 || M4 == 8);
1432  const index_t m1 = get_warp_local_1d_id() / NWave;
1433  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
1434 
1435  if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
1436  {
1437  if constexpr(PerTokenQuant)
1438  {
1439  constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
1440  p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
1441  get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
1442  }
1443  else
1444  {
1445  p_scale_b += expert_id;
1446  }
1447 
1448  vector_type<int32_t, M4> scale_token_ids;
1449  vector_type<float, M4> topk_weights;
1450  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1451  const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
1452  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1453  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
1454  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1455  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1456  if constexpr(PerTokenQuant)
1457  {
1458  scale_token_ids =
1459  *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
1460  p_sorted_token_ids + m_pos);
1461  }
1462  if constexpr(MulRoutedWeight)
1463  {
1464  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
1465  p_ds_grid[I2] + m_pos);
1466  }
1467  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
1468  float scale_a = [&]() {
1469  if constexpr(PerTokenQuant)
1470  {
1471  index_t fused_token =
1472  scale_token_ids.template AsType<index_t>()[m4];
1473  const index_t token_offset = fused_token & 0xffffff;
1474  return token_offset < problem.NumTokens
1475  ? p_sorted_weights_0[IsInputGemm
1476  ? token_offset
1477  : token_offset *
1478  problem.TopK +
1479  (fused_token >>
1480  24)]
1481  : 0.0;
1482  }
1483  else
1484  {
1485  return p_sorted_weights_0[0];
1486  }
1487  }();
1488  constexpr index_t c_offset =
1489  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1490  make_tuple(m0, n0, m2 * M4 + m4));
1491  constexpr auto cidx = Number<c_offset>{};
1492  if constexpr(IsInputGemm) // gate/up (gu) fusion
1493  {
1494  if constexpr(ActivationOperation == Activation::silu_and_mul)
1495  {
1496  const float scale_up =
1497  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
1498  PerTokenQuant];
1499  float gate = scale_a * scale_b * c_thread_buf[cidx];
1500  float up = scale_a * scale_up * c_thread_buf_up[cidx];
1501  if constexpr(MulRoutedWeight)
1502  {
1503  gate = gate * topk_weights.template AsType<float>()[m4];
1504  up = up * topk_weights.template AsType<float>()[m4];
1505  }
1507  {
1508  gate *= 16;
1509  up *= 16;
1510  }
1512  c_thread_buf_fp32(cidx) = gate * up;
1513  }
1514  else if(ActivationOperation == Activation::gelu_and_mul)
1515  {
1516  const float scale_up =
1517  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
1518  PerTokenQuant];
1519  float gate = scale_a * scale_b * c_thread_buf[cidx];
1520  float up = scale_a * scale_up * c_thread_buf_up[cidx];
1521  if constexpr(MulRoutedWeight)
1522  {
1523  gate = gate * topk_weights.template AsType<float>()[m4];
1524  up = up * topk_weights.template AsType<float>()[m4];
1525  }
1527  {
1528  gate *= 16;
1529  up *= 16;
1530  }
1532  c_thread_buf_fp32(cidx) = gate * up;
1533  }
1534  }
1535  else
1536  {
1537  c_thread_buf_fp32(cidx) =
1538  scale_a * scale_b * c_thread_buf[cidx];
1539  if constexpr(MulRoutedWeight)
1540  {
1541  c_thread_buf_fp32(cidx) =
1542  c_thread_buf_fp32(cidx) *
1543  topk_weights.template AsType<float>()[m4];
1544  }
1545  }
1546  });
1547  });
1548  });
1549  });
1550  }
1551  else
1552  {
1553  vector_type<float, M4> topk_weights; // for gemm2 only
1554  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1555  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1556  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
1557  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1558  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1559  if constexpr(MulRoutedWeight)
1560  {
1561  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
1562  p_ds_grid[I2] + m_pos);
1563  }
1564  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
1565  constexpr index_t c_offset =
1566  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1567  make_tuple(m0, n0, m2 * M4 + m4));
1568  constexpr auto cidx = Number<c_offset>{};
1569 
1570  if constexpr(IsInputGemm) // gate/up (gu) fusion
1571  {
1572  if constexpr(ActivationOperation == Activation::silu_and_mul)
1573  {
1574  float gate = c_thread_buf[cidx];
1575  float up = c_thread_buf_up[cidx];
1576  if constexpr(MulRoutedWeight)
1577  {
1578  gate = gate * topk_weights.template AsType<float>()[m4];
1579  up = up * topk_weights.template AsType<float>()[m4];
1580  }
1582  c_thread_buf_fp32(cidx) = gate * up;
1583  }
1584  else if(ActivationOperation == Activation::gelu_and_mul)
1585  {
1586  float gate = c_thread_buf[cidx];
1587  float up = c_thread_buf_up[cidx];
1588  if constexpr(MulRoutedWeight)
1589  {
1590  gate = gate * topk_weights.template AsType<float>()[m4];
1591  up = up * topk_weights.template AsType<float>()[m4];
1592  }
1594  c_thread_buf_fp32(cidx) = gate * up;
1595  }
1596  }
1597  else
1598  {
1599  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1600  if constexpr(MulRoutedWeight)
1601  {
1602  c_thread_buf_fp32(cidx) =
1603  topk_weights.template AsType<float>()[m4] *
1604  c_thread_buf_fp32[cidx];
1605  }
1606  }
1607  });
1608  });
1609  });
1610  });
1611  }
1612 
1613  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1615 
1616  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1617  static_cast<CShuffleDataType*>(p_shared),
1618  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1619 
1620  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1621  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1622  make_tuple(
1625  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1626  M1, // M1 = MWave
1627  M2, // M2 * M3 * M4 = MPerXdl
1628  M3,
1629  M4)),
1632  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1633  N1, // N1 = NWave
1634  N2))), // N2 = NPerXdl
1636  make_tuple(
1638 
1639  // calculate origin of thread output tensor on global memory
1640  // blockwise GEMM c matrix starting index
1641  const auto c_thread_mtx_on_block =
1642  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1643 
1644  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1645  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1646 
1647  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1649  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1651  make_tuple(Sequence<0>{}));
1652 
1653  const auto m_thread_data_on_block_idx =
1654  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1655  make_multi_index(m_thread_data_on_block));
1656 
1657  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1661  make_tuple(Sequence<0>{}));
1662 
1663  const auto n_thread_data_on_block_idx =
1664  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1665  make_multi_index(n_thread_data_on_block));
1666 
1667  // shuffle: threadwise copy C from VGPR to LDS
1668  auto c_thread_copy_vgpr_to_lds =
1670  CShuffleDataType,
1671  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1672  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1674  Sequence<CShuffleMXdlPerWavePerShuffle,
1675  CShuffleNXdlPerWavePerShuffle,
1676  I1,
1677  I1,
1678  M2,
1679  I1,
1680  M4,
1681  I1>,
1683  7,
1684  1,
1686  1,
1687  true>{
1688  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1689  make_multi_index(0,
1690  0,
1691  m_thread_data_on_block_idx[I1],
1692  n_thread_data_on_block_idx[I1],
1693  m_thread_data_on_block_idx[I2],
1694  m_thread_data_on_block_idx[I3],
1695  m_thread_data_on_block_idx[I4],
1696  n_thread_data_on_block_idx[I2]),
1698 
1699  using EDataType = CDataType;
1700 
1701  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1702  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1703 
1704  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1706  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1707 
1708  const auto ds_grid_buf = generate_tuple(
1709  [&](auto i) {
1710  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1711  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1712  },
1713  Number<NumDTensor>{});
1714 
1715  // tuple of reference to C/Ds tensor descriptors
1716  const auto c_ds_desc_refs = concat_tuple_of_reference(
1717  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1718  generate_tie([&](auto i) -> const auto& // return type should be reference
1719  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1720  Number<NumDTensor>{}));
1721 
1722  // tuple of reference to C/Ds tensor descriptors
1723  const auto c_ds_buf_refs = concat_tuple_of_reference(
1724  tie(c_shuffle_block_buf),
1725  generate_tie([&](auto i) -> const auto& // return type should be reference
1726  { return ds_grid_buf[i]; },
1727  Number<NumDTensor>{}));
1728 
1729  // tuple of starting index of C/Ds blockwise copy
1730  const auto idx_c_ds_block_begin =
1733  [&](auto) {
1734  return make_multi_index(block_m_id, 0, block_n_id, 0);
1735  // return make_multi_index(block_work_idx[I0], 0,
1736  // block_work_idx[I1], 0);
1737  },
1738  Number<NumDTensor>{}));
1739 
1740  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1741  c_grid_desc_mblock_mperblock_nblock_nperblock;
1742 
1743  using CDEBlockTransferCluster =
1744  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1745  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1746  constexpr index_t scatter_weight_idx = 3; // hack fix felix
1747  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1749  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1751  decltype(c_ds_desc_refs),
1752  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1753  CElementwiseOperation,
1754  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1755  // support arbitrary type
1756  Sequence<1,
1757  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1758  1,
1759  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1760  CDEBlockTransferCluster,
1761  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1762  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1763  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1764  3, // index_t SrcVectorDim,
1765  3, // index_t DstVectorDim,
1766  CDEShuffleBlockTransferScalarPerVectors,
1771  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1772  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1773  IndexType,
1774  1, // ScatterDim
1775  true, // OutputScatter: false, only use scatter weights
1776  scatter_weight_idx // ScatterWeightIdx: ascale
1777  >{c_ds_desc_refs,
1778  idx_c_ds_block_begin,
1779  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1780  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1781  c_element_op};
1782 
1783  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1784  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1785  constexpr auto sfc_c_vgpr =
1788  Sequence<CShuffleMXdlPerWavePerShuffle,
1789  CShuffleNXdlPerWavePerShuffle,
1790  1,
1791  1,
1792  M2,
1793  1,
1794  M4,
1795  1>>{};
1796 
1797  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1798 
1799  // space filling curve for shuffled blockwise C/D/E
1800  constexpr auto sfc_cde_block =
1803  Sequence<1,
1804  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1805  1,
1806  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1807 
1808  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1809  constexpr auto EMThreads =
1810  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1811  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1812  constexpr auto ENThreads =
1813  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1814  static_for<0, num_access, 1>{}([&](auto access_id) {
1815  // make sure it's safe to write to LDS
1817 
1818  auto dstidx = sfc_cde_block.GetIndex(access_id);
1819  const index_t c_token_pos =
1820  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1821  static_for<0, EMRepeats, 1>{}([&](auto m0) {
1822  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1823  IndexType token_offset = fused_token & 0xffffff;
1824  if constexpr(IsInputGemm)
1825  {
1826  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1827  }
1828  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
1829  });
1830 
1831  block_sync_lds();
1832 
1833  // each thread write its data from VGPR to LDS
1834  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1835  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1836  c_thread_buf_fp32,
1837  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1838  c_shuffle_block_buf);
1839 
1840  // make sure it's safe to read from LDS
1841  block_sync_lds();
1842 
1843  // each block copy its data from LDS to global
1844  cde_block_copy_lds_and_global.Run(
1845  c_ds_desc_refs,
1846  c_ds_buf_refs,
1847  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1848  tie(c_grid_buf),
1849  scatter_offsets);
1850 
1851  if constexpr(access_id < num_access - 1)
1852  {
1853  constexpr auto cde_lds_and_global_step =
1854  sfc_cde_block.GetForwardStep(access_id);
1855 
1856  // move on Ds
1857  static_for<0, NumDTensor, 1>{}([&](auto i) {
1858  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1859  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1860  });
1861 
1862  // move on E
1863  cde_block_copy_lds_and_global.MoveDstSliceWindow(
1864  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1865  I0,
1866  cde_lds_and_global_step);
1867  }
1868  });
1869  }
1870  }
1871 
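A recurring detail inside Run (and Run_2Lds below) is the packed format of the p_sorted_token_ids entries: the low 24 bits hold the token index and the high 8 bits the top-k slot, which is folded into the expanded row index for the per-token views. A self-contained decode with a made-up value:

// Illustrative decode of one sorted-token entry (the packed value is made up).
constexpr unsigned fused_token  = (3u << 24) | 1234u;     // top-k slot 3, token 1234
constexpr unsigned token_id     = fused_token & 0xffffff; // 1234
constexpr unsigned topk_slot    = fused_token >> 24;      // 3
constexpr unsigned TopK_ex      = 8;                      // assumed top-k
// Row in the token*TopK expanded A/C view, as used for gather/scatter offsets:
constexpr unsigned expanded_row = token_id * TopK_ex + topk_slot; // 9875
static_assert(token_id == 1234 && topk_slot == 3 && expanded_row == 9875, "");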
1872  template <bool HasMainKBlockLoop,
1873  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1874  TailNumber TailNum = TailNumber::Odd>
1875  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1876  const index_t* p_sorted_expert_ids,
1877  const index_t* p_max_token_id,
1878  const ADataType* p_a_grid,
1879  const BDataType* p_b_grid,
1880  DsGridPointer& p_ds_grid,
1881  CDataType* p_c_grid,
1882  void* p_shared,
1883  void* p_shared1,
1884  const Problem& problem,
1885  AElementwiseOperation a_element_op,
1886  BElementwiseOperation b_element_op,
1887  CElementwiseOperation c_element_op)
1888  {
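        // Double-LDS ("ping/pong") variant of the MoE GEMM body: decode the sorted token
        // ids, gather the corresponding A rows from global memory, stream the selected
        // expert's preshuffled B tiles through VGPRs, run the blockwise GEMM pipeline while
        // alternating the A tile between the two LDS buffers, then shuffle the results
        // through LDS and scatter them back to their token positions.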
1889  ignore = b_element_op;
1890  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1891  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1892  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1893  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1894  problem.MPadded,
1895  problem.K,
1896  problem.KPadded,
1897  problem.StrideA,
1898  problem.AK0);
1899  const auto b_grid_desc_bpreshuffled =
1900  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1901  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1902  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1903  problem.MPadded,
1904  problem.N,
1905  problem.NPadded,
1906  problem.StrideC);
1907  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1908  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1909  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1910  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1911  // static_assert(NSwizzle == false, "TODO: fix; needs another PR in which the sorting is merged");
1912  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1913  if(expert_block_id * MPerBlock >= max_token_id)
1914  return;
1915  const index_t expert_id =
1916  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1917  const auto block_mn = [&]() -> std::pair<int, int> {
1918  if constexpr(NSwizzle)
1919  {
1920  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1921  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1922  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1923  const index_t expert_swizzle =
1924  ecnt > 0 ? ecnt : 1; // guard against a zero divisor when the expert owns no blocks
1925  const index_t bid_new = blockIdx.x - prefix_block;
1926  const index_t nid = __builtin_amdgcn_readfirstlane(
1927  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1928  const index_t mid =
1929  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1930  return {nid, mid};
1931  }
1932  else
1933  {
1934  return {blockIdx.x, blockIdx.y};
1935  }
1936  }();
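        // With NSwizzle the flat blockIdx.x is decomposed into an (n, m) block pair so that
        // the blocks belonging to one expert are walked in 8-wide bands along N, cycling
        // through the expert's M blocks within each band (ecnt_prefix/ecnt come from the
        // per-expert block counts in p_max_token_id); without swizzling, blockIdx.x and
        // blockIdx.y map directly to the N and M block ids.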
1937 
1938  const index_t block_n_id = block_mn.first;
1939  const index_t block_m_id = block_mn.second;
1940  const index_t token0 =
1941  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1942 
1943  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1944  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1945  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1946  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1947  constexpr auto AKThreads = AK0Threads * AK1Threads;
1948  constexpr auto AMRepeats = MPerBlock / AMThreads;
1949  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1950 
1951  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1952  return;
1954  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1955  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1956  index_t token_offset = fused_token & 0xffffff;
1957  if constexpr(!IsInputGemm)
1958  {
1959  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1960  }
1961  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1962  });
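        // Each entry of p_sorted_token_ids packs the token id in the low 24 bits and the
        // top-k slot in the high 8 bits. For the second (down) GEMM the A rows live in a
        // [NumTokens * TopK, K] buffer, so the slot is folded into the row index;
        // gather_offsets holds the resulting row offsets (in elements of A) consumed by
        // the gather-enabled A blockwise copy below.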
1963  const IndexType expert_stride =
1964  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1965  const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
1966  // N0, K0, Blocksize*KPack
1967  const index_t n_block_data_idx_on_grid =
1968  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1969 
1970  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1971  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1972  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1973  p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1974 
1975  // A matrix in LDS memory, dst of blockwise copy
1976  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1977 
1978  // B "block" descriptor: B is preshuffled and kept in VGPRs, so this
1979  // descriptor (a dummy LDS dst) is only used to size the per-thread buffers below
1980  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1981  // A matrix blockwise copy
1982  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1984  AElementwiseOperation,
1988  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1989  ABlockTransferThreadClusterArrangeOrder,
1990  ADataType,
1991  LDSTypeA,
1992  decltype(a_grid_desc_ak0_m_ak1),
1993  decltype(a_block_desc_ak0_m_ak1),
1994  ABlockTransferSrcAccessOrder,
1996  ABlockTransferSrcVectorDim,
1997  2,
1998  ABlockTransferSrcScalarPerVector,
1999  ABlockTransferDstScalarPerVector_AK1,
2000  1,
2001  1,
2002  AThreadTransferSrcResetCoordinateAfterRun,
2003  true,
2004  IndexType,
2005  1,
2006  2>(a_grid_desc_ak0_m_ak1,
2007  make_multi_index(0, 0, 0),
2008  a_element_op,
2009  a_block_desc_ak0_m_ak1,
2010  make_multi_index(0, 0, 0),
2012  gather_offsets);
2013 
2014  // Thread-wise copy
2015  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2016  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2017  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2018  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2019  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2020  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
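        // B is preshuffled in global memory, so it is fetched with a plain threadwise
        // transfer directly into per-thread VGPR ping/pong buffers; no LDS staging is
        // needed on the B side.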
2021 
2022  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
2023  BDataType,
2024  BDataType,
2025  decltype(b_grid_desc_bpreshuffled),
2026  decltype(b_block_desc_bk0_n_bk1),
2029  3,
2030  BBlockTransferSrcScalarPerVector,
2031  BThreadTransferSrcResetCoordinateAfterRun,
2032  true>(b_grid_desc_bpreshuffled,
2033  make_multi_index(n_block_data_idx_on_grid,
2035  0,
2036  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2037 
2038  // LDS allocation for A (ping/pong): be careful of alignment
2039  // cast the raw __shared__ pointers to the A data type
2040  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2041  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2042  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2043  static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2044  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
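        // The two __shared__ allocations passed in as p_shared/p_shared1 back the A tile
        // ping/pong buffers, letting the global->LDS prefetch of the next K tile overlap
        // with the MFMA work on the current one.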
2045 
2046  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2047  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
2048 
2049  // Blockwise GEMM pipeline
2050  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2051  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2052  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2053  decltype(c_thread_buf) c_thread_buf_up;
2054 
2056  float,
2057  c_thread_buf.num_of_v_,
2058  c_thread_buf.s_per_v,
2059  true>
2060  c_thread_buf_fp32;
2061 
2062  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2063  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2064  KPerBlock);
2065 
2066  if constexpr(IsInputGemm)
2067  {
2068  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2069  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2070  p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2071  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2072  BDataType,
2073  BDataType,
2074  decltype(b_grid_desc_bpreshuffled),
2075  decltype(b_block_desc_bk0_n_bk1),
2078  3,
2079  BBlockTransferSrcScalarPerVector,
2080  BThreadTransferSrcResetCoordinateAfterRun,
2081  true>(b_grid_desc_bpreshuffled,
2082  make_multi_index(n_block_data_idx_on_grid,
2084  0,
2085  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2086  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2087  a_grid_desc_ak0_m_ak1,
2088  a_block_desc_ak0_m_ak1,
2089  a_blockwise_copy,
2090  a_grid_buf,
2091  a_block_bufs,
2092  a_block_slice_copy_step,
2093  b_grid_desc_bpreshuffled,
2094  b_blockwise_copy,
2095  b_blockwise_copy_up,
2096  b_grid_buf,
2097  b_grid_buf_up,
2098  b_block_bufs,
2099  b_block_slice_copy_step,
2100  c_thread_buf,
2101  c_thread_buf_up,
2102  num_k_block_main_loop);
2103  }
2104  else
2105  {
2106 
2107  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2108  a_grid_desc_ak0_m_ak1,
2109  a_block_desc_ak0_m_ak1,
2110  a_blockwise_copy,
2111  a_grid_buf,
2112  a_block_bufs,
2113  a_block_slice_copy_step,
2114  b_grid_desc_bpreshuffled,
2115  b_blockwise_copy,
2116  b_grid_buf,
2117  b_block_bufs,
2118  b_block_slice_copy_step,
2119  c_thread_buf,
2120  num_k_block_main_loop);
2121  }
2122 
2123  // shuffle C and write out
2124  {
2125  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2126  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2127  "wrong!");
2128 
2129  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2130 
2131  // TODO: hacky, fix it!
2132  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2133  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2134 
2135  // TODO: hacky, fix it!
2136  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2137  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2138  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2139 
2140  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2141  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2142  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2143  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2144  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2145  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2146  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2147  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
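            // Decomposition of the accumulator tile as produced by the MFMA layout:
            // M = M0 (MXdlPerWave) * M1 (MWave) * M2 * M3 * M4 (together MPerXdl) and
            // N = N0 (NXdlPerWave) * N1 (NWave) * N2 (NPerXdl); M4 is the contiguous
            // per-thread group size checked by the static_assert below.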
2148 
2149  // apply A-side (p_ds_grid[I0]) and B-side (p_ds_grid[I1]) dequantization scales
2150  const float* p_sorted_weights_0 = p_ds_grid[I0];
2151  const float* p_scale_b = p_ds_grid[I1];
2152 
2153  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
2154  static_assert(M4 == 4 || M4 == 8);
2155  const index_t m1 = get_warp_local_1d_id() / NWave;
2156  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
2157 
2158  if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
2159  {
2160  if constexpr(PerTokenQuant)
2161  {
2162  constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
2163  p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
2164  get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
2165  }
2166  else
2167  {
2168  p_scale_b += expert_id;
2169  }
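                // When PerTokenQuant is enabled the B scales are per expert and per output
                // channel (with IsInputGemm the up-projection scales follow at an extra
                // offset of problem.N, read below); otherwise a single per-expert scale is
                // used.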
2170 
2171  vector_type<int32_t, M4> scale_token_ids;
2172  vector_type<float, M4> topk_weights;
2173  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2174  const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
2175  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2176  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
2177  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2178  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2179  if constexpr(PerTokenQuant)
2180  {
2181  scale_token_ids =
2182  *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
2183  p_sorted_token_ids + m_pos);
2184  }
2185  if constexpr(MulRoutedWeight)
2186  {
2187  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
2188  p_ds_grid[I2] + m_pos);
2189  }
2190  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
2191  float scale_a = [&]() {
2192  if constexpr(PerTokenQuant)
2193  {
2194  index_t fused_token =
2195  scale_token_ids.template AsType<index_t>()[m4];
2196  const index_t token_offset = fused_token & 0xffffff;
2197  return token_offset < problem.NumTokens
2198  ? p_sorted_weights_0[IsInputGemm
2199  ? token_offset
2200  : token_offset *
2201  problem.TopK +
2202  (fused_token >>
2203  24)]
2204  : 0.0;
2205  }
2206  else
2207  {
2208  return p_sorted_weights_0[0];
2209  }
2210  }();
2211  constexpr index_t c_offset =
2212  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2213  make_tuple(m0, n0, m2 * M4 + m4));
2214  constexpr auto cidx = Number<c_offset>{};
2215  if constexpr(IsInputGemm) // gate+up (gu) fusion
2216  {
2217  if constexpr(ActivationOperation == Activation::silu_and_mul)
2218  {
2219  const float scale_up =
2220  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
2221  PerTokenQuant];
2222  float gate = scale_a * scale_b * c_thread_buf[cidx];
2223  float up = scale_a * scale_up * c_thread_buf_up[cidx];
2224  if constexpr(MulRoutedWeight)
2225  {
2226  gate = gate * topk_weights.template AsType<float>()[m4];
2227  up = up * topk_weights.template AsType<float>()[m4];
2228  }
2230  {
2231  gate *= 16;
2232  up *= 16;
2233  }
2235  c_thread_buf_fp32(cidx) = gate * up;
2236  }
2237  else if(ActivationOperation == Activation::gelu_and_mul)
2238  {
2239  const float scale_up =
2240  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
2241  PerTokenQuant];
2242  float gate = scale_a * scale_b * c_thread_buf[cidx];
2243  float up = scale_a * scale_up * c_thread_buf_up[cidx];
2244  if constexpr(MulRoutedWeight)
2245  {
2246  gate = gate * topk_weights.template AsType<float>()[m4];
2247  up = up * topk_weights.template AsType<float>()[m4];
2248  }
2250  {
2251  gate *= 16;
2252  up *= 16;
2253  }
2255  c_thread_buf_fp32(cidx) = gate * up;
2256  }
2257  }
2258  else
2259  {
2260  c_thread_buf_fp32(cidx) =
2261  scale_a * scale_b * c_thread_buf[cidx];
2262  if constexpr(MulRoutedWeight)
2263  {
2264  c_thread_buf_fp32(cidx) =
2265  c_thread_buf_fp32(cidx) *
2266  topk_weights.template AsType<float>()[m4];
2267  }
2268  }
2269  });
2270  });
2271  });
2272  });
2273  }
2274  else
2275  {
2276  vector_type<float, M4> topk_weights; // for gemm2 only
2277  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2278  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2279  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
2280  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2281  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2282  if constexpr(MulRoutedWeight)
2283  {
2284  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
2285  p_ds_grid[I2] + m_pos);
2286  }
2287  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
2288  constexpr index_t c_offset =
2289  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2290  make_tuple(m0, n0, m2 * M4 + m4));
2291  constexpr auto cidx = Number<c_offset>{};
2292 
2293  if constexpr(IsInputGemm) // gate+up (gu) fusion
2294  {
2295  if constexpr(ActivationOperation == Activation::silu_and_mul)
2296  {
2297  float gate = c_thread_buf[cidx];
2298  float up = c_thread_buf_up[cidx];
2299  if constexpr(MulRoutedWeight)
2300  {
2301  gate = gate * topk_weights.template AsType<float>()[m4];
2302  up = up * topk_weights.template AsType<float>()[m4];
2303  }
2305  c_thread_buf_fp32(cidx) = gate * up;
2306  }
2307  else if(ActivationOperation == Activation::gelu_and_mul)
2308  {
2309  float gate = c_thread_buf[cidx];
2310  float up = c_thread_buf_up[cidx];
2311  if constexpr(MulRoutedWeight)
2312  {
2313  gate = gate * topk_weights.template AsType<float>()[m4];
2314  up = up * topk_weights.template AsType<float>()[m4];
2315  }
2317  c_thread_buf_fp32(cidx) = gate * up;
2318  }
2319  }
2320  else
2321  {
2322  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2323  if constexpr(MulRoutedWeight)
2324  {
2325  c_thread_buf_fp32(cidx) =
2326  topk_weights.template AsType<float>()[m4] *
2327  c_thread_buf_fp32[cidx];
2328  }
2329  }
2330  });
2331  });
2332  });
2333  });
2334  }
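            // At this point c_thread_buf_fp32 holds the final per-thread results (scaled,
            // activation-fused for the gate/up GEMM, and optionally weighted by the routed
            // top-k weights from p_ds_grid[I2]); the rest of the epilogue shuffles them
            // through LDS and scatters them to global memory.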
2335 
2336  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2337  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2338 
2339  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2340  static_cast<CShuffleDataType*>(p_shared),
2341  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2342 
2343  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2344  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2345  make_tuple(
2348  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2349  M1, // M1 = MWave
2350  M2, // M2 * M3 * M4 = MPerXdl
2351  M3,
2352  M4)),
2355  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2356  N1, // N1 = NWave
2357  N2))), // N2 = NPerXdl
2359  make_tuple(
2361 
2362  // calculate origin of thread output tensor on global memory
2363  // blockwise GEMM c matrix starting index
2364  const auto c_thread_mtx_on_block =
2365  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2366 
2367  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2368  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2369 
2370  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2372  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2374  make_tuple(Sequence<0>{}));
2375 
2376  const auto m_thread_data_on_block_idx =
2377  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2378  make_multi_index(m_thread_data_on_block));
2379 
2380  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2384  make_tuple(Sequence<0>{}));
2385 
2386  const auto n_thread_data_on_block_idx =
2387  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2388  make_multi_index(n_thread_data_on_block));
2389 
2390  // shuffle: threadwise copy C from VGPR to LDS
2391  auto c_thread_copy_vgpr_to_lds =
2393  CShuffleDataType,
2394  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2395  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2397  Sequence<CShuffleMXdlPerWavePerShuffle,
2398  CShuffleNXdlPerWavePerShuffle,
2399  I1,
2400  I1,
2401  M2,
2402  I1,
2403  M4,
2404  I1>,
2406  7,
2407  1,
2409  1,
2410  true>{
2411  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2412  make_multi_index(0,
2413  0,
2414  m_thread_data_on_block_idx[I1],
2415  n_thread_data_on_block_idx[I1],
2416  m_thread_data_on_block_idx[I2],
2417  m_thread_data_on_block_idx[I3],
2418  m_thread_data_on_block_idx[I4],
2419  n_thread_data_on_block_idx[I2]),
2421 
2422  using EDataType = CDataType;
2423 
2424  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2425  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2426 
2427  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2428  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2429  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2430 
2431  const auto ds_grid_buf = generate_tuple(
2432  [&](auto i) {
2433  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2434  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2435  },
2436  Number<NumDTensor>{});
2437 
2438  // tuple of reference to C/Ds tensor descriptors
2439  const auto c_ds_desc_refs = concat_tuple_of_reference(
2440  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2441  generate_tie([&](auto i) -> const auto& // return type should be reference
2442  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2443  Number<NumDTensor>{}));
2444 
2445  // tuple of reference to C/Ds tensor descriptors
2446  const auto c_ds_buf_refs = concat_tuple_of_reference(
2447  tie(c_shuffle_block_buf),
2448  generate_tie([&](auto i) -> const auto& // return type should be reference
2449  { return ds_grid_buf[i]; },
2450  Number<NumDTensor>{}));
2451 
2452  // tuple of starting index of C/Ds blockwise copy
2453  const auto idx_c_ds_block_begin =
2456  [&](auto) {
2457  return make_multi_index(block_m_id, 0, block_n_id, 0);
2458  // return make_multi_index(block_work_idx[I0], 0,
2459  // block_work_idx[I1], 0);
2460  },
2461  Number<NumDTensor>{}));
2462 
2463  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2464  c_grid_desc_mblock_mperblock_nblock_nperblock;
2465 
2466  using CDEBlockTransferCluster =
2467  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2468  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2469  constexpr index_t scatter_weight_idx = 3; // FIXME: hard-coded scatter-weight index
2470  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2472  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2474  decltype(c_ds_desc_refs),
2475  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2476  CElementwiseOperation,
2477  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
2478  // support arbitrary types
2479  Sequence<1,
2480  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2481  1,
2482  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2483  CDEBlockTransferCluster,
2484  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2485  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2486  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2487  3, // index_t SrcVectorDim,
2488  3, // index_t DstVectorDim,
2489  CDEShuffleBlockTransferScalarPerVectors,
2494  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2495  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2496  IndexType,
2497  1, // ScatterDim
2498  true, // OutputScatter: false, only use scatter weights
2499  scatter_weight_idx // ScatterWeightIdx: ascale
2500  >{c_ds_desc_refs,
2501  idx_c_ds_block_begin,
2502  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2503  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2504  c_element_op};
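            // Scatter-style epilogue writer: each row of the shuffled C tile in LDS is
            // written to global memory at a per-row offset (scatter_offsets, filled from the
            // sorted token ids below), so results land at their original token positions;
            // the D tensors are read alongside as extra sources of the C elementwise op.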
2505 
2506  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2507  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2508  constexpr auto sfc_c_vgpr =
2511  Sequence<CShuffleMXdlPerWavePerShuffle,
2512  CShuffleNXdlPerWavePerShuffle,
2513  1,
2514  1,
2515  M2,
2516  1,
2517  M4,
2518  1>>{};
2519 
2520  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2521 
2522  // space filling curve for shuffled blockwise C/D/E
2523  constexpr auto sfc_cde_block =
2526  Sequence<1,
2527  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2528  1,
2529  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2530 
2531  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2532  constexpr auto EMThreads =
2533  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2534  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2535  constexpr auto ENThreads =
2536  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
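            // EMThreads/ENThreads are the M- and N-side factors of the CDE transfer thread
            // cluster; each thread owns EMRepeats consecutive rows of the shuffled tile and
            // computes one scatter offset per owned row.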
2537  static_for<0, num_access, 1>{}([&](auto access_id) {
2538  // make sure it's safe to write to LDS
2540 
2541  auto dstidx = sfc_cde_block.GetIndex(access_id);
2542  const index_t c_token_pos =
2543  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2544  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2545  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2546  IndexType token_offset = fused_token & 0xffffff;
2547  if constexpr(IsInputGemm)
2548  {
2549  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2550  }
2551  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2552  });
2553 
2554  block_sync_lds();
2555 
2556  // each thread write its data from VGPR to LDS
2557  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2558  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2559  c_thread_buf_fp32,
2560  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2561  c_shuffle_block_buf);
2562 
2563  // make sure it's safe to read from LDS
2564  block_sync_lds();
2565 
2566  // each block copy its data from LDS to global
2567  cde_block_copy_lds_and_global.Run(
2568  c_ds_desc_refs,
2569  c_ds_buf_refs,
2570  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2571  tie(c_grid_buf),
2572  scatter_offsets);
2573 
2574  if constexpr(access_id < num_access - 1)
2575  {
2576  constexpr auto cde_lds_and_global_step =
2577  sfc_cde_block.GetForwardStep(access_id);
2578 
2579  // move on Ds
2580  static_for<0, NumDTensor, 1>{}([&](auto i) {
2581  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2582  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2583  });
2584 
2585  // move on E
2586  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2587  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2588  I0,
2589  cde_lds_and_global_step);
2590  }
2591  });
2592  }
2593  }
2594 };
2595 
2596 } // namespace ck
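For orientation, below is a minimal standalone sketch of two conventions the epilogue above relies on: the 24-bit fused token encoding used for the gather/scatter offsets, and a scalar reference of the silu_and_mul combine of the gate/up accumulators (assuming the usual SiLU definition x * sigmoid(x), which the excerpted code does not show explicitly). The names and helpers here are illustrative only and are not part of this header.

#include <cmath>
#include <cstdint>

// Decoded view of one entry of p_sorted_token_ids: the low 24 bits carry the
// token id, the high 8 bits the top-k slot the token was routed through.
struct FusedToken
{
    std::int32_t token_id;
    std::int32_t topk_slot;
};

inline FusedToken decode_fused_token(std::int32_t fused)
{
    return FusedToken{fused & 0xffffff, (fused >> 24) & 0xff};
}

// Row offset (in elements) of a gathered A row for the down-projection GEMM,
// where activations are laid out as [NumTokens * TopK, K].
inline std::int64_t gather_row_offset(FusedToken t, std::int64_t top_k, std::int64_t K)
{
    return (t.token_id * top_k + t.topk_slot) * K;
}

// Scalar reference of the silu_and_mul combine applied to one gate/up pair
// of accumulator values (illustrative, assuming SiLU(x) = x * sigmoid(x)).
inline float silu_and_mul_ref(float gate, float up)
{
    const float silu = gate / (1.0f + std::exp(-gate));
    return silu * up;
}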