Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp Source File
gridwise_moe_gemm.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
18 
19 #define DEBUG_LOG 0
20 
21 namespace ck {
22 
23 // Currently we do not have an elegant way to put a single-LDS-buffer and a double-LDS-buffer pipeline
24 // in the same kernel function. Blockers:
25 // 1. Two separate declarations of the __shared__ pointer are needed to make sure data accesses operate
26 // on two distinct LDS chunks.
27 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may not
28 // reuse the same LDS buffer if __shared__ is declared inside the block GEMM pipeline.
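// Illustrative pattern (not part of this header): the double-LDS variant below therefore declares two
// distinct buffers at kernel scope and hands both pointers to the pipeline, e.g.
//   __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
//   __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// rather than declaring __shared__ storage inside the block GEMM pipeline itself
// (see kernel_moe_gemm_2lds).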
29 
30 enum Activation
31 {
32  gelu_and_mul = 0,
33  silu_and_mul = 1
34 };
35 
36 template <typename GridwiseGemm,
37  bool HasMainKBlockLoop,
38  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
39  index_t MinimumOccupancy = 1,
40  TailNumber TailNum = TailNumber::Even>
41 __global__ void
42 #if CK_USE_LAUNCH_BOUNDS
43 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
44 #endif
45  // __attribute__((amdgpu_waves_per_eu(1, 1)))
46  kernel_moe_gemm(typename GridwiseGemm::Argument karg)
47 {
48 #if defined(__gfx9__)
49  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
50 
51  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
52 
53  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
54  karg.p_sorted_token_ids,
55  karg.p_sorted_expert_ids,
56  karg.p_max_token_id,
57  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
58  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
59  karg.p_ds_grid,
60  karg.p_c_grid,
61  p_shared,
62  karg,
63  karg.a_element_op,
64  karg.b_element_op,
65  karg.c_element_op);
66 #else
67  ignore = karg;
68 #endif // end of if (defined(__gfx9__))
69 }
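// Minimal host-side launch sketch (illustrative only; in practice the CK device-op layer performs the
// launch). It assumes a fully specialized GridwiseGemm and a prepared Argument `karg`, and that BlockSize
// matches the GridwiseGemm's BlockSize template parameter:
//   const auto [gx, gy, gz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); // gz is always 1 here
//   kernel_moe_gemm<GridwiseGemm, true, InMemoryDataOperationEnum::Set>
//       <<<dim3(gx, gy, karg.KBatch), dim3(BlockSize), 0, stream>>>(karg);
// blockIdx.z selects the split-K batch consumed by GridwiseGemm::SplitKBatchOffset.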
70 
71 template <typename GridwiseGemm,
72  bool HasMainKBlockLoop,
73  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
74  index_t MinimumOccupancy = 1,
75  TailNumber TailNum = TailNumber::Even>
76 __global__ void
77 #if CK_USE_LAUNCH_BOUNDS
78 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
79 #endif
80  // __attribute__((amdgpu_waves_per_eu(1, 1)))
81  kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
82 {
83 #if defined(__gfx9__)
84  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
85  __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
86 
87  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
88 
89  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
90  karg.p_sorted_token_ids,
91  karg.p_sorted_expert_ids,
92  karg.p_max_token_id,
93  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
94  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
95  karg.p_ds_grid,
96  karg.p_c_grid,
97  p_shared,
98  p_shared1,
99  karg,
100  karg.a_element_op,
101  karg.b_element_op,
102  karg.c_element_op);
103 #else
104  ignore = karg;
105 #endif // end of if (defined(__gfx9__))
106 }
107 
108 template <typename ALayout,
109  typename BLayout,
110  typename DsLayout,
111  typename CLayout,
112  typename ADataType,
113  typename BDataType,
114  typename AccDataType,
115  typename CShuffleDataType,
116  typename DsDataType,
117  typename CDataType,
118  typename AElementwiseOperation,
119  typename BElementwiseOperation,
120  typename CElementwiseOperation,
122  index_t BlockSize,
123  index_t MPerBlock,
124  index_t NPerBlock,
125  index_t KPerBlock,
126  index_t AK1Value,
127  index_t BK1Value,
128  index_t MPerXdl,
129  index_t NPerXdl,
130  index_t MXdlPerWave,
131  index_t NXdlPerWave,
132  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
133  typename ABlockTransferThreadClusterArrangeOrder,
134  typename ABlockTransferSrcAccessOrder,
135  index_t ABlockTransferSrcVectorDim,
136  index_t ABlockTransferSrcScalarPerVector,
137  index_t ABlockTransferDstScalarPerVector_AK1,
138  bool AThreadTransferSrcResetCoordinateAfterRun,
139  index_t ABlockLdsExtraM,
140  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
141  typename BBlockTransferThreadClusterArrangeOrder,
142  typename BBlockTransferSrcAccessOrder,
143  index_t BBlockTransferSrcVectorDim,
144  index_t BBlockTransferSrcScalarPerVector,
145  index_t BBlockTransferDstScalarPerVector_BK1,
146  bool BThreadTransferSrcResetCoordinateAfterRun,
147  index_t BBlockLdsExtraN,
148  index_t CShuffleMXdlPerWavePerShuffle,
149  index_t CShuffleNXdlPerWavePerShuffle,
150  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
151  typename CDEShuffleBlockTransferScalarPerVectors,
154  index_t ActivationOperation = 0,
155  bool NSwizzle = false,
156  bool IsInputGemm = true,
157  bool MulRoutedWeight = true,
158  bool PerTokenQuant = false,
159  typename IndexType = index_t,
160  typename ComputeTypeA = CDataType,
161  typename ComputeTypeB = ComputeTypeA,
162  typename LDSTypeA = ADataType,
163  typename LDSTypeB = BDataType>
165 {
166  static constexpr auto I0 = Number<0>{};
167  static constexpr auto I1 = Number<1>{};
168  static constexpr auto I2 = Number<2>{};
169  static constexpr auto I3 = Number<3>{};
170  static constexpr auto I4 = Number<4>{};
171  static constexpr auto I5 = Number<5>{};
172  static constexpr auto I6 = Number<6>{};
173  static constexpr auto I7 = Number<7>{};
174 
175  static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
176  CDEShuffleBlockTransferScalarPerVectors{}[I0];
177  // K1 should be Number<...>
178  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
179  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
180  static constexpr auto AK1Number = Number<AK1Value>{};
181  static constexpr auto BK1Number = Number<BK1Value>{};
182  static constexpr auto BlockSizeNumber = Number<BlockSize>{};
183 
184  static constexpr index_t NumDTensor = DsDataType::Size();
185 
187  static constexpr index_t KPack =
189  static constexpr index_t KLane =
191 
192  static constexpr index_t KGroup = []() {
194  // On gfx950 there is an mfma that requires 32 f8 elements as input,
195  // split into 2 groups of 16 f8 elements.
196  // The 2 groups are not contiguous in the preshuffled B layout,
197  // and we do not want them to be contiguous in the preshuffled B layout,
198  // because a memory instruction can only read 16 f8 elements at a time.
199  return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
200  else
201  return 1;
202  }();
203 
204  static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
205 
206  static constexpr index_t NLane = NPerXdl;
207  static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
208  // static constexpr index_t NumTokens = 1;
209  static constexpr index_t SortedTileSize = MPerBlock;
210 
211  static constexpr auto MakeDsGridPointer()
212  {
213  return generate_tuple(
214  [&](auto i) {
215  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
216 
217  return static_cast<const DDataType*>(nullptr);
218  },
220  }
221 
222  using DsGridPointer = decltype(MakeDsGridPointer());
223 
225 
226  static constexpr index_t APackedSize = []() {
228  return 2;
229  else
230  return 1;
231  }();
232 
233  static constexpr index_t BPackedSize = []() {
235  return 2;
236  else
237  return 1;
238  }();
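 // Assumption for readability: APackedSize / BPackedSize are 2 when the corresponding data type packs two
 // elements per byte (e.g. pk_i4_t) and 1 otherwise; element counts and byte offsets involving such types
 // are divided by these factors throughout (see SplitKBatchOffset and GetSharedMemoryNumberOfByte).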
239 
240  __host__ static auto CalculateGridSize(index_t M, index_t N)
241  {
242  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
243  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
244  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
245  const index_t gridy = NSwizzle ? 1 : mblock;
246 
247  return std::make_tuple(gridx, gridy, 1);
248  }
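 // Example (illustrative numbers): with MPerBlock = NPerBlock = 128, M = 512 and N = 1024 give
 // mblock = 4 and nblock = 8. With NSwizzle the grid collapses to (nblock * mblock, 1, 1) = (32, 1, 1);
 // otherwise it is (nblock, mblock, 1) = (8, 4, 1). The z dimension is left to the launcher, and the
 // kernels above read the split-K batch index from blockIdx.z.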
249 
250  __host__ __device__ static auto CalculateMPadded(index_t M)
251  {
252  return math::integer_least_multiple(M, MPerBlock);
253  }
254 
255  __host__ __device__ static auto CalculateNPadded(index_t N)
256  {
257  return math::integer_least_multiple(N, NPerBlock);
258  }
259 
260  __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
261  {
262  return math::integer_divide_ceil(N, NLane);
263  }
264  __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
265  {
267  }
268 
269  __host__ __device__ static auto CalculateKPadded(index_t K)
270  {
271  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
272  }
273 
274  __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
275  {
276  auto K_t = K_Batch * KPerBlock;
277  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
278  }
279 
280  __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
281  {
282  auto K_t = K_Batch * KPerBlock;
283  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
284  }
285 
286  __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
287  {
288  auto K_t = K_Batch * KPerBlock;
289  return (K + K_t - 1) / K_t * KPerBlock;
290  }
291 
292  __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
293  {
294  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
295  auto K_t = K_Batch * KReadVec;
296  return (K + K_t - 1) / K_t * KReadVec;
297  }
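 // Worked example (illustrative numbers): with KPerBlock = 64, AK1Value = BK1Value = 8 and K_Batch = 4,
 // a raw K = 1000 gives
 //   CalculateKPadded(1000, 4)   = ceil(1000 / 256) * 64       = 256
 //   CalculateAK0Padded(1000, 4) = ceil(1000 / 256) * (64 / 8) = 32
 //   CalculateKRead(1000, 4)     = ceil(1000 / 32) * 8         = 256   (KReadVec = lcm(8, 8) = 8)
 // so the first three split-K batches each read 256 K elements and the last batch is left with
 // 1000 - 3 * 256 = 232 (see SplitKBatchOffset below).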
298 
299  __host__ __device__ static auto CalculateMBlock(index_t M)
300  {
301  return math::integer_divide_ceil(M, MPerBlock);
302  }
303 
304  __host__ __device__ static auto CalculateNBlock(index_t N)
305  {
306  return math::integer_divide_ceil(N, NPerBlock);
307  }
308 
309  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
310  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
311  {
312  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
313  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
314 
316  TileDesc_K0_MN_K1{},
322  }
323 
324  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
325  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
326  {
327  const auto a_grid_desc_mraw_kraw = [&]() {
328  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
329  {
330  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
331  }
332  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
333  {
334  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
335  }
336  }();
337 
339 
340  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
341  GemmSpec == GemmSpecialization::MNKPadding)
342  {
343  // pad both M and K
344  const auto a_grid_desc_m_k =
345  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
347  make_right_pad_transform(K, KPad - K)),
350 
351  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
352  a_grid_desc_m_k,
357 
358  return a_grid_desc_ak0_m_ak1;
359  }
360  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
361  GemmSpec == GemmSpecialization::MNPadding)
362  {
363  // pad M, but not K
364  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
365  a_grid_desc_mraw_kraw,
367  make_right_pad_transform(M, MPad - M)),
370 
371  return a_grid_desc_ak0_m_ak1;
372  }
373  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
374  GemmSpec == GemmSpecialization::NKPadding)
375  {
376  // pad K, but not M
377  const auto a_grid_desc_m_k = transform_tensor_descriptor(
378  a_grid_desc_mraw_kraw,
382 
383  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
384  a_grid_desc_m_k,
389 
390  return a_grid_desc_ak0_m_ak1;
391  }
392  else
393  {
394  // not pad M or K
395  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
396  a_grid_desc_mraw_kraw,
401 
402  return a_grid_desc_ak0_m_ak1;
403  }
404  }
405 
406  __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
407  {
408  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
409  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
410  constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack / KGroup>{};
412  make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
413  make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
414  }
415 
416  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
417  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
418  {
419  const auto b_grid_desc_nraw_kraw = [&]() {
421  {
422  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
423  }
425  {
426  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
427  }
428  }();
429 
431 
432  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
433  GemmSpec != GemmSpecialization::Default),
434  "pk_i4_t does not support padding");
435 
436  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
437  GemmSpec == GemmSpecialization::MNKPadding)
438  {
439  // pad both N and K
440  const auto b_grid_desc_n_k =
441  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
443  make_right_pad_transform(K, KPad - K)),
446 
447  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
448  b_grid_desc_n_k,
453 
454  return b_grid_desc_bk0_n_bk1;
455  }
456  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
457  GemmSpec == GemmSpecialization::MNPadding)
458  {
459  // pad N, but not K
460  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
461  b_grid_desc_nraw_kraw,
463  make_right_pad_transform(N, NPad - N)),
466 
467  return b_grid_desc_bk0_n_bk1;
468  }
469  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
470  GemmSpec == GemmSpecialization::MKPadding)
471  {
472  // pad K, but not N
473  const auto b_grid_desc_n_k = transform_tensor_descriptor(
474  b_grid_desc_nraw_kraw,
478 
479  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
480  b_grid_desc_n_k,
485 
486  return b_grid_desc_bk0_n_bk1;
487  }
488  else
489  {
490  // not pad N or K
491  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
492  b_grid_desc_nraw_kraw,
497 
498  return b_grid_desc_bk0_n_bk1;
499  }
500  }
501 
502  template <typename ABlockDesc_AK0_M_AK1>
503  __host__ __device__ static constexpr auto
504  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
505  {
506  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
507 
508  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
509  }
510 
511  template <typename BBlockDesc_BK0_N_BK1>
512  __host__ __device__ static constexpr auto
513  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
514  {
515  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
516  }
517 
518  template <typename ELayout>
519  __host__ __device__ static auto MakeCGridDescriptor_M_N(
520  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
521  {
522  const auto c_grid_desc_mraw_nraw = [&]() {
524  {
525  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
526  }
528  {
529  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
530  }
531  }();
532 
533  // pad M and N
534  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
536  make_right_pad_transform(N, NPad - N)),
539  }
540 
541  template <typename DLayout>
542  __host__ __device__ static auto
544  {
545  const auto c_grid_desc_mraw_nraw = [&]() {
547  {
548  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
549  }
551  {
552  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
553  }
554  }();
555 
556  // pad M and N
557  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
559  make_right_pad_transform(N, NPad - N)),
562  }
563 
564  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
565  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
566  {
567  return generate_tuple(
568  [&](auto i) {
569  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
570  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
571  },
573  }
574 
575  template <typename DsGridDesc>
577  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
578  {
579  return generate_tuple(
580  [&](auto i) {
582  ds_grid_desc_m_n[i], MBlock, NBlock);
583  },
585  }
586 
587  struct Problem
588  {
589  __host__ __device__ Problem(index_t NumTokens_,
590  index_t TopK_,
591  index_t M_,
592  index_t N_,
593  index_t K_,
594  index_t StrideA_,
595  index_t StrideB_,
596  std::array<index_t, NumDTensor> StrideDs_,
597  index_t StrideC_,
598  index_t KBatch_)
599  : NumTokens{NumTokens_},
600  TopK{TopK_},
601  M{M_},
602  N{N_},
603  K{K_},
604  StrideA{StrideA_},
605  StrideB{StrideB_},
606  StrideDs{StrideDs_},
607  StrideC{StrideC_},
608  KBatch{KBatch_},
611  KRead{CalculateKRead(K_, KBatch_)},
612  KPadded{CalculateKPadded(K_, KBatch_)},
613  AK0{CalculateAK0Padded(K_, KBatch_)},
614  BK0{CalculateBK0Padded(K_, KBatch_)},
615  MBlock{CalculateMBlock(M_)},
617  {
618  }
619 
620  __host__ void Print() const
621  {
622  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
623  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
624  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
625  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
626  << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
627  << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
628  << "NBlock: " << NBlock << "}" << std::endl;
629  }
630 
638  std::array<index_t, NumDTensor> StrideDs;
649  };
650 
651  // Argument
653  {
654  __host__ Argument(const index_t* p_sorted_token_ids_,
655  const index_t* p_sorted_expert_ids_,
656  const index_t* p_max_token_id_,
657  const ADataType* p_a_grid_,
658  const BDataType* p_b_grid_,
659  std::array<const void*, NumDTensor> p_ds_grid_,
660  CDataType* p_c_grid_,
661  index_t NumTokens_,
662  index_t TopK_,
663  index_t M_,
664  index_t N_,
665  index_t K_,
666  index_t StrideA_,
667  index_t StrideB_,
668  std::array<index_t, NumDTensor> StrideDs_,
669  index_t StrideC_,
670  index_t k_batch_,
671  AElementwiseOperation a_element_op_,
672  BElementwiseOperation b_element_op_,
673  CElementwiseOperation c_element_op_)
674  : Problem{NumTokens_,
675  TopK_,
676  M_,
677  N_,
678  K_,
679  StrideA_,
680  StrideB_,
681  StrideDs_,
682  StrideC_,
683  k_batch_},
684  p_sorted_token_ids{p_sorted_token_ids_},
685  p_sorted_expert_ids{p_sorted_expert_ids_},
686  p_max_token_id{p_max_token_id_},
687  p_a_grid{p_a_grid_},
688  p_b_grid{p_b_grid_},
689  p_ds_grid{},
690  p_c_grid{p_c_grid_},
691  a_element_op{a_element_op_},
692  b_element_op{b_element_op_},
693  c_element_op{c_element_op_}
694  {
695 
696  // populate pointer, desc for Ds
697  static_for<0, NumDTensor, 1>{}([&](auto i) {
698  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
699 
700  // D pointer
701  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
702  });
703  }
704 
708  const ADataType* p_a_grid;
709  const BDataType* p_b_grid;
711  CDataType* p_c_grid;
712 
713  const AElementwiseOperation a_element_op;
714  const BElementwiseOperation b_element_op;
715  const CElementwiseOperation c_element_op;
716  };
717 
719  {
720  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
721  {
722  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
723  {
724  a_k_split_offset = k_id * karg.KRead / APackedSize;
725  }
726  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
727  {
728  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
729  }
730 
731  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
732  {
733  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
734  }
735  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
736  {
737  // KPack * NLane * KLane * K0 * N0
738  b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
739  }
740 
741  if(k_id < karg.KBatch - 1)
742  {
743  karg.K = karg.KRead;
744  }
745  else
746  {
747  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
748  }
749  }
750 
753  };
754 
755  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
756  {
757  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
758  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
759 
760  // A matrix in LDS memory, dst of blockwise copy
761  if constexpr(ABlockLdsExtraM)
762  {
766  }
767  // The xor tensor transformation requires additional, unnecessary VGPR usage and can cause register
768  // spills in some cases.
770  {
771  constexpr auto a_lds_block_desc =
774 
775  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
776  a_lds_block_desc,
782 
783  return a_lds_block_desc_permuted;
784  }
785  else // ColumnMajor A
786  {
787  // The kfold and mpair dimensions are not always required.
788  // More dimensions in merge_transform make it harder for the compiler to generate immediate-argument
789  // offsets.
790  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
791  constexpr auto M1 = MPerBlock / M0;
792 
793  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
794  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
795  constexpr auto KThreadRead = WaveSize / MPerXdl;
796  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
797 
798  constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
799  ? 1
800  : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
801  constexpr auto KThreadReadPerm =
802  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
803  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
804  : KThreadRead;
805 
806  // 1 <= mpair <= M0
807  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
808  ? 1
809  : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
810  ? M0
811  : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
812 
813  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
817  Number<kfold * M0 / mpair>{},
818  Number<mpair>{},
819  AK1Number));
820 
821  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
822  a_lds_block_desc,
823  make_tuple(
827  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
830  make_tuple(
832  make_tuple(
834 
835  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
836  a_lds_block_desc_permuted,
837  make_tuple(
845  Sequence<1>{},
846  Sequence<2>{},
847  Sequence<3>{},
848  Sequence<4>{},
849  Sequence<5>{}),
851  Sequence<2>{},
852  Sequence<0, 3>{},
853  Sequence<4, 5>{},
854  Sequence<6>{},
855  Sequence<7>{}));
856 
857  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
858  a_lds_block_desc_unmerged,
861  Number<KThreadWrite / kfold / KThreadReadPerm>{},
862  Number<kfold>{},
869 
870  return a_lds_block_desc_ak0_m_ak1;
871  }
872  }
873 
874  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
875  {
876  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
879  }
880 
882  {
883  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
884 
885  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
887  make_tuple(I1,
889  I1,
891 
892  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
893  }
894 
897  BlkGemmPipelineVer,
898  BlkGemmPipeSched,
899  BlockSize,
900  ADataType,
901  BDataType,
902  ComputeTypeA,
903  AccDataType,
910  ABlockTransferSrcScalarPerVector,
911  BBlockTransferSrcScalarPerVector,
912  MPerBlock,
913  NPerBlock,
914  KPerBlock,
915  MPerXdl,
916  NPerXdl,
917  MXdlPerWave,
918  NXdlPerWave,
919  KPack,
920  IsInputGemm>())>;
921 
922  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
923  {
924  // LDS allocation for A and B: be careful of alignment
925  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
926  // lds max alignment
927  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
928 
929  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
930  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
931 
932  // LDS allocation for C shuffle in LDS
933  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
935 
936  constexpr auto c_block_size =
937  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
938 
939  return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize,
940  c_block_size * sizeof(CShuffleDataType));
941  }
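 // Note: the A staging buffer and the C shuffle buffer reuse the same LDS allocation (the C shuffle stage
 // only runs after the main loop has consumed A from LDS), so the requirement is the maximum of the two
 // sizes rather than their sum. B is streamed through VGPRs via the preshuffled layout and needs no LDS.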
942 
943  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
944  __host__ static constexpr bool CheckValidity(const Argument& karg)
945  {
946  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
947  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
948  "Invalid tuning param!");
949 
950  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
955  {
956  if(!(karg.M % MPerBlock == 0))
957  {
958 #if DEBUG_LOG
959  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
960  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
961  << std::endl;
962 
963 #endif // DEBUG_LOG
964  return false;
965  }
966  }
967 
968  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
973  {
974  if(!(karg.N % NPerBlock == 0))
975  {
976 #if DEBUG_LOG
977  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
978  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
979  << std::endl;
980 
981 #endif // DEBUG_LOG
982  return false;
983  }
984  }
985 
986  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
990  {
991 
992  auto K_t = karg.KBatch * KPerBlock;
993  if(!(karg.K % K_t == 0))
994  {
995 #if DEBUG_LOG
996  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
997  << karg.K << " " << __FILE__ << ":" << __LINE__
998  << ", in function: " << __func__ << std::endl;
999 
1000 #endif // DEBUG_LOG
1001  return false;
1002  }
1003  }
1004  else
1005  {
1006  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1007  auto K_t = karg.KBatch * KReadVec;
1008  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1009  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1010  {
1011  return false;
1012  }
1013  }
1014 
1016  {
1017  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1018  {
1019 #if DEBUG_LOG
1020  std::cout << "Arg K (" << karg.K
1021  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1022  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1023  << __LINE__ << ", in function: " << __func__ << std::endl;
1024 
1025 #endif // DEBUG_LOG
1026  return false;
1027  }
1028  }
1029  else
1030  {
1031  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1032  {
1033 #if DEBUG_LOG
1034  std::cout << "Arg M (" << karg.M
1035  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1036  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1037  << __LINE__ << ", in function: " << __func__ << std::endl;
1038 
1039 #endif // DEBUG_LOG
1040  return false;
1041  }
1042  }
1043 
1045  {
1046  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1047  {
1048 #if DEBUG_LOG
1049  std::cout << "Arg N (" << karg.N
1050  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1051  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1052  << __LINE__ << ", in function: " << __func__ << std::endl;
1053 
1054 #endif // DEBUG_LOG
1055  return false;
1056  }
1057  }
1058  else
1059  {
1060  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1061  {
1062 #if DEBUG_LOG
1063  std::cout << "Arg K (" << karg.K
1064  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1065  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1066  << __LINE__ << ", in function: " << __func__ << std::endl;
1067 
1068 #endif // DEBUG_LOG
1069  return false;
1070  }
1071  }
1072 
1074  {
1076  {
1077 #if DEBUG_LOG
1078  std::cout << "Arg N (" << karg.N
1079  << ") value is not a multiple of "
1080  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1081  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1082  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1083 
1084 #endif // DEBUG_LOG
1085  return false;
1086  }
1087  }
1088  else
1089  {
1091  {
1092 #if DEBUG_LOG
1093  std::cout << "Arg M (" << karg.M
1094  << ") value is not a multiple of "
1095  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1096  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1097  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1098 
1099 #endif // DEBUG_LOG
1100  return false;
1101  }
1102  }
1103 
1104  // check gridwise gemm pipeline
1105 #if 0
1106  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1107 
1108  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1109  {
1110  return false;
1111  }
1112 #endif
1113  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1114  return true;
1115  }
1116 
1117  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1118  {
1119  const index_t num_loop = K / KPerBlock;
1120 
1121  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1122  }
1123 
1124  __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1125  {
1126  const index_t num_loop = K / KPerBlock;
1127 
1128  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1129  }
1130 
1131  template <typename CGridDesc>
1133  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1134  {
1135  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1136  c_grid_desc_m_n,
1141 
1142  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1143  }
1144 
1145  // return block_id to C matrix tile idx (m0, n0) mapping
1146  // if arch = gfx942
1147  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1148  // NPerBlock>;
1149 
1150  template <bool HasMainKBlockLoop,
1151  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1152  TailNumber TailNum = TailNumber::Odd>
1153  __device__ static void Run(const index_t* p_sorted_token_ids,
1154  const index_t* p_sorted_expert_ids,
1155  const index_t* p_max_token_id,
1156  const ADataType* p_a_grid,
1157  const BDataType* p_b_grid,
1158  DsGridPointer& p_ds_grid,
1159  CDataType* p_c_grid,
1160  void* p_shared,
1161  const Problem& problem,
1162  AElementwiseOperation a_element_op,
1163  BElementwiseOperation b_element_op,
1164  CElementwiseOperation c_element_op)
1165  {
1166  ignore = b_element_op;
1167  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1168  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1169  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1170  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1171  problem.MPadded,
1172  problem.K,
1173  problem.KPadded,
1174  problem.StrideA,
1175  problem.AK0);
1176  const auto b_grid_desc_bpreshuffled =
1177  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1178  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1179  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1180  problem.MPadded,
1181  problem.N,
1182  problem.NPadded,
1183  problem.StrideC);
1184  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1186  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1187  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1188  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1189  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1190  if(expert_block_id * MPerBlock >= max_token_id)
1191  return;
1192  const index_t expert_id =
1193  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1194  const auto block_mn = [&]() -> std::pair<int, int> {
1195  if constexpr(NSwizzle)
1196  {
1197  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1198  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1199  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1200  const index_t expert_swizzle =
1201  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1202  const index_t bid_new = blockIdx.x - prefix_block;
1203  const index_t nid = __builtin_amdgcn_readfirstlane(
1204  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1205  const index_t mid =
1206  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1207  return {nid, mid};
1208  }
1209  else
1210  {
1211  return {blockIdx.x, blockIdx.y};
1212  }
1213  }();
1214 
1215  const index_t block_n_id = block_mn.first;
1216  const index_t block_m_id = block_mn.second;
1217  const index_t token0 =
1218  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1219 
1220  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1221  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1222  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1223  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1224  constexpr auto AKThreads = AK0Threads * AK1Threads;
1225  constexpr auto AMRepeats = MPerBlock / AMThreads;
1226  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1227 
1228  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1229  return;
1231  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1232  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1233  index_t token_offset = fused_token & 0xffffff;
1234  if constexpr(!IsInputGemm)
1235  {
1236  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1237  }
1238  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1239  });
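 // Note on the encoding above: each entry of p_sorted_token_ids packs the token id in its low 24 bits and
 // the top-k slot in its high 8 bits, so `fused_token & 0xffffff` recovers the token while
 // `fused_token >> 24` selects which of the token's TopK rows it maps to. For example (illustrative value
 // only), fused_token = 0x02000123 decodes to token 0x123 in top-k slot 2.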
1240  const IndexType expert_stride =
1241  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1242  const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
1243  // N0, K0, Blocksize*KPack
1244  const index_t n_block_data_idx_on_grid =
1245  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1246 
1247  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1248  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1249  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1250  p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1251  // A matrix in LDS memory, dst of blockwise copy
1252  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1253 
1254  // B matrix in LDS memory, dst of blockwise copy
1255  // dummy
1256  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1257  // A matrix blockwise copy
1258  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1260  AElementwiseOperation,
1264  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1265  ABlockTransferThreadClusterArrangeOrder,
1266  ADataType,
1267  LDSTypeA,
1268  decltype(a_grid_desc_ak0_m_ak1),
1269  decltype(a_block_desc_ak0_m_ak1),
1270  ABlockTransferSrcAccessOrder,
1272  ABlockTransferSrcVectorDim,
1273  2,
1274  ABlockTransferSrcScalarPerVector,
1275  ABlockTransferDstScalarPerVector_AK1,
1276  1,
1277  1,
1278  AThreadTransferSrcResetCoordinateAfterRun,
1279  true,
1280  IndexType,
1281  1,
1282  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1283  make_multi_index(0, 0, 0),
1284  a_element_op,
1285  a_block_desc_ak0_m_ak1,
1286  make_multi_index(0, 0, 0),
1288  gather_offsets);
1289 
1290  // Thread-wise copy
1291  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1292  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1293  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1294 
1295  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1296  BDataType,
1297  BDataType,
1298  decltype(b_grid_desc_bpreshuffled),
1299  decltype(b_block_desc_bk0_n_bk1),
1302  3,
1303  BBlockTransferSrcScalarPerVector,
1304  BThreadTransferSrcResetCoordinateAfterRun,
1305  true>(b_grid_desc_bpreshuffled,
1306  make_multi_index(n_block_data_idx_on_grid,
1308  0,
1309  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1310 
1311  // LDS allocation for A and B: be careful of alignment
1312  // Cast after lds
1313  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1314  static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1315 
1316  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1317  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1318 
1319  // Blockwise GEMM pipeline
1320  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1321  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1322  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1323  decltype(c_thread_buf) c_thread_buf_up;
1324 
1326  float,
1327  c_thread_buf.num_of_v_,
1328  c_thread_buf.s_per_v,
1329  true>
1330  c_thread_buf_fp32;
1331 
1332  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1333  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1334  KPerBlock);
1335  if constexpr(IsInputGemm)
1336  {
1337  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1338  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1339  p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1340  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1341  BDataType,
1342  BDataType,
1343  decltype(b_grid_desc_bpreshuffled),
1344  decltype(b_block_desc_bk0_n_bk1),
1347  3,
1348  BBlockTransferSrcScalarPerVector,
1349  BThreadTransferSrcResetCoordinateAfterRun,
1350  true>(b_grid_desc_bpreshuffled,
1351  make_multi_index(n_block_data_idx_on_grid,
1353  0,
1354  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1355 
1356  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1357  a_grid_desc_ak0_m_ak1,
1358  a_block_desc_ak0_m_ak1,
1359  a_blockwise_copy,
1360  a_grid_buf,
1361  a_block_buf,
1362  a_block_slice_copy_step,
1363  b_grid_desc_bpreshuffled,
1364  b_blockwise_copy,
1365  b_blockwise_copy_up,
1366  b_grid_buf,
1367  b_grid_buf_up,
1368  b_block_buf,
1369  b_block_slice_copy_step,
1370  c_thread_buf,
1371  c_thread_buf_up,
1372  num_k_block_main_loop);
1373  }
1374  else
1375  {
1376  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1377  a_grid_desc_ak0_m_ak1,
1378  a_block_desc_ak0_m_ak1,
1379  a_blockwise_copy,
1380  a_grid_buf,
1381  a_block_buf,
1382  a_block_slice_copy_step,
1383  b_grid_desc_bpreshuffled,
1384  b_blockwise_copy,
1385  b_grid_buf,
1386  b_block_buf,
1387  b_block_slice_copy_step,
1388  c_thread_buf,
1389  num_k_block_main_loop);
1390  }
1391 
1392  // shuffle C and write out
1393  {
1394  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1395  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1396  "wrong!");
1397 
1398  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1399 
1400  // TODO: hacky, fix it!
1401  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1402  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1403 
1404  // TODO: hacky, fix it!
1405  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1406  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1407  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1408 
1409  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1410  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1411  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1412  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1413  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1414  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1415  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1416  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1417 
1418  // mul scales
1419  const float* p_sorted_weights_0 = p_ds_grid[I0];
1420  const float* p_scale_b = p_ds_grid[I1];
1421 
1422  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
1423  static_assert(M4 == 4 || M4 == 8);
1424  const index_t m1 = get_warp_local_1d_id() / NWave;
1425  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
1426 
1427  if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
1428  {
1429  if constexpr(PerTokenQuant)
1430  {
1431  constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
1432  p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
1433  get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
1434  }
1435  else
1436  {
1437  p_scale_b += expert_id;
1438  }
1439 
1440  vector_type<int32_t, M4> scale_token_ids;
1441  vector_type<float, M4> topk_weights;
1442  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1443  const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
1444  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1445  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
1446  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1447  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1448  if constexpr(PerTokenQuant)
1449  {
1450  scale_token_ids =
1451  *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
1452  p_sorted_token_ids + m_pos);
1453  }
1454  if constexpr(MulRoutedWeight)
1455  {
1456  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
1457  p_ds_grid[I2] + m_pos);
1458  }
1459  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
1460  float scale_a = [&]() {
1461  if constexpr(PerTokenQuant)
1462  {
1463  index_t fused_token =
1464  scale_token_ids.template AsType<index_t>()[m4];
1465  const index_t token_offset = fused_token & 0xffffff;
1466  return token_offset < problem.NumTokens
1467  ? p_sorted_weights_0[IsInputGemm
1468  ? token_offset
1469  : token_offset *
1470  problem.TopK +
1471  (fused_token >>
1472  24)]
1473  : 0.0;
1474  }
1475  else
1476  {
1477  return p_sorted_weights_0[0];
1478  }
1479  }();
1480  constexpr index_t c_offset =
1481  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1482  make_tuple(m0, n0, m2 * M4 + m4));
1483  constexpr auto cidx = Number<c_offset>{};
1484  if constexpr(IsInputGemm) // gu fusion
1485  {
1486  if constexpr(ActivationOperation == Activation::silu_and_mul)
1487  {
1488  const float scale_up =
1489  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
1490  PerTokenQuant];
1491  float gate = scale_a * scale_b * c_thread_buf[cidx];
1492  float up = scale_a * scale_up * c_thread_buf_up[cidx];
1493  if constexpr(MulRoutedWeight)
1494  {
1495  gate = gate * topk_weights.template AsType<float>()[m4];
1496  up = up * topk_weights.template AsType<float>()[m4];
1497  }
1499  {
1500  gate *= 16;
1501  up *= 16;
1502  }
1504  c_thread_buf_fp32(cidx) = gate * up;
1505  }
1506  else if(ActivationOperation == Activation::gelu_and_mul)
1507  {
1508  const float scale_up =
1509  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
1510  PerTokenQuant];
1511  float gate = scale_a * scale_b * c_thread_buf[cidx];
1512  float up = scale_a * scale_up * c_thread_buf_up[cidx];
1513  if constexpr(MulRoutedWeight)
1514  {
1515  gate = gate * topk_weights.template AsType<float>()[m4];
1516  up = up * topk_weights.template AsType<float>()[m4];
1517  }
1519  {
1520  gate *= 16;
1521  up *= 16;
1522  }
1524  c_thread_buf_fp32(cidx) = gate * up;
1525  }
1526  }
1527  else
1528  {
1529  c_thread_buf_fp32(cidx) =
1530  scale_a * scale_b * c_thread_buf[cidx];
1531  if constexpr(MulRoutedWeight)
1532  {
1533  c_thread_buf_fp32(cidx) =
1534  c_thread_buf_fp32(cidx) *
1535  topk_weights.template AsType<float>()[m4];
1536  }
1537  }
1538  });
1539  });
1540  });
1541  });
1542  }
1543  else
1544  {
1545  vector_type<float, M4> topk_weights; // for gemm2 only
1546  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1547  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1548  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
1549  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1550  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1551  if constexpr(MulRoutedWeight)
1552  {
1553  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
1554  p_ds_grid[I2] + m_pos);
1555  }
1556  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
1557  constexpr index_t c_offset =
1558  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1559  make_tuple(m0, n0, m2 * M4 + m4));
1560  constexpr auto cidx = Number<c_offset>{};
1561 
1562  if constexpr(IsInputGemm) // gu fusion
1563  {
1564  if constexpr(ActivationOperation == Activation::silu_and_mul)
1565  {
1566  float gate = c_thread_buf[cidx];
1567  float up = c_thread_buf_up[cidx];
1568  if constexpr(MulRoutedWeight)
1569  {
1570  gate = gate * topk_weights.template AsType<float>()[m4];
1571  up = up * topk_weights.template AsType<float>()[m4];
1572  }
1574  c_thread_buf_fp32(cidx) = gate * up;
1575  }
1576  else if(ActivationOperation == Activation::gelu_and_mul)
1577  {
1578  float gate = c_thread_buf[cidx];
1579  float up = c_thread_buf_up[cidx];
1580  if constexpr(MulRoutedWeight)
1581  {
1582  gate = gate * topk_weights.template AsType<float>()[m4];
1583  up = up * topk_weights.template AsType<float>()[m4];
1584  }
1586  c_thread_buf_fp32(cidx) = gate * up;
1587  }
1588  }
1589  else
1590  {
1591  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1592  if constexpr(MulRoutedWeight)
1593  {
1594  c_thread_buf_fp32(cidx) =
1595  topk_weights.template AsType<float>()[m4] *
1596  c_thread_buf_fp32[cidx];
1597  }
1598  }
1599  });
1600  });
1601  });
1602  });
1603  }
1604 
1605  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1607 
1608  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1609  static_cast<CShuffleDataType*>(p_shared),
1610  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1611 
1612  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1613  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1614  make_tuple(
1617  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1618  M1, // M1 = MWave
1619  M2, // M2 * M3 * M4 = MPerXdl
1620  M3,
1621  M4)),
1624  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1625  N1, // N1 = NWave
1626  N2))), // N2 = NPerXdl
1628  make_tuple(
1630 
1631  // calculate origin of thread output tensor on global memory
1632  // blockwise GEMM c matrix starting index
1633  const auto c_thread_mtx_on_block =
1634  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1635 
1636  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1637  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1638 
1639  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1641  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1643  make_tuple(Sequence<0>{}));
1644 
1645  const auto m_thread_data_on_block_idx =
1646  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1647  make_multi_index(m_thread_data_on_block));
1648 
1649  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1653  make_tuple(Sequence<0>{}));
1654 
1655  const auto n_thread_data_on_block_idx =
1656  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1657  make_multi_index(n_thread_data_on_block));
1658 
1659  // shuffle: threadwise copy C from VGPR to LDS
1660  auto c_thread_copy_vgpr_to_lds =
1662  CShuffleDataType,
1663  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1664  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1666  Sequence<CShuffleMXdlPerWavePerShuffle,
1667  CShuffleNXdlPerWavePerShuffle,
1668  I1,
1669  I1,
1670  M2,
1671  I1,
1672  M4,
1673  I1>,
1675  7,
1676  1,
1678  1,
1679  true>{
1680  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1681  make_multi_index(0,
1682  0,
1683  m_thread_data_on_block_idx[I1],
1684  n_thread_data_on_block_idx[I1],
1685  m_thread_data_on_block_idx[I2],
1686  m_thread_data_on_block_idx[I3],
1687  m_thread_data_on_block_idx[I4],
1688  n_thread_data_on_block_idx[I2]),
1690 
1691  using EDataType = CDataType;
1692 
1693  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1694  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1695 
1696  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1698  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1699 
1700  const auto ds_grid_buf = generate_tuple(
1701  [&](auto i) {
1702  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1703  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1704  },
1705  Number<NumDTensor>{});
1706 
1707  // tuple of reference to C/Ds tensor descriptors
1708  const auto c_ds_desc_refs = concat_tuple_of_reference(
1709  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1710  generate_tie([&](auto i) -> const auto& // return type should be reference
1711  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1712  Number<NumDTensor>{}));
1713 
1714  // tuple of reference to C/Ds tensor descriptors
1715  const auto c_ds_buf_refs = concat_tuple_of_reference(
1716  tie(c_shuffle_block_buf),
1717  generate_tie([&](auto i) -> const auto& // return type should be reference
1718  { return ds_grid_buf[i]; },
1719  Number<NumDTensor>{}));
1720 
1721  // tuple of starting index of C/Ds blockwise copy
1722  const auto idx_c_ds_block_begin =
1725  [&](auto) {
1726  return make_multi_index(block_m_id, 0, block_n_id, 0);
1727  // return make_multi_index(block_work_idx[I0], 0,
1728  // block_work_idx[I1], 0);
1729  },
1730  Number<NumDTensor>{}));
1731 
1732  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1733  c_grid_desc_mblock_mperblock_nblock_nperblock;
1734 
1735  using CDEBlockTransferCluster =
1736  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1737  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1738  constexpr index_t scatter_weight_idx = 3; // FIXME: hard-coded scatter-weight index
1739  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1741  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1743  decltype(c_ds_desc_refs),
1744  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1745  CElementwiseOperation,
1746  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1747  // support arbitray type
1748  Sequence<1,
1749  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1750  1,
1751  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1752  CDEBlockTransferCluster,
1753  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1754  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1755  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1756  3, // index_t SrcVectorDim,
1757  3, // index_t DstVectorDim,
1758  CDEShuffleBlockTransferScalarPerVectors,
1763  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1764  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1765  IndexType,
1766  1, // ScatterDim
1767  true, // OutputScatter: false, only use scatter weights
1768  scatter_weight_idx // ScatterWeightIdx: ascale
1769  >{c_ds_desc_refs,
1770  idx_c_ds_block_begin,
1771  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1772  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1773  c_element_op};
1774 
1775  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1776  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1777  constexpr auto sfc_c_vgpr =
1780  Sequence<CShuffleMXdlPerWavePerShuffle,
1781  CShuffleNXdlPerWavePerShuffle,
1782  1,
1783  1,
1784  M2,
1785  1,
1786  M4,
1787  1>>{};
1788 
1789  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1790 
1791  // space filling curve for shuffled blockwise C/D/E
1792  constexpr auto sfc_cde_block =
1795  Sequence<1,
1796  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1797  1,
1798  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1799 
1800  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1801  constexpr auto EMThreads =
1802  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1803  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1804  constexpr auto ENThreads =
1805  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1806  static_for<0, num_access, 1>{}([&](auto access_id) {
1807  // make sure it's safe to write to LDS
1809 
1810  auto dstidx = sfc_cde_block.GetIndex(access_id);
1811  const index_t c_token_pos =
1812  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1813  static_for<0, EMRepeats, 1>{}([&](auto m0) {
1814  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1815  IndexType token_offset = fused_token & 0xffffff;
1816  if constexpr(IsInputGemm)
1817  {
1818  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1819  }
1820  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
1821  });
1822 
1823  block_sync_lds();
1824 
1825  // each thread write its data from VGPR to LDS
1826  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1827  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1828  c_thread_buf_fp32,
1829  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1830  c_shuffle_block_buf);
1831 
1832  // make sure it's safe to read from LDS
1833  block_sync_lds();
1834 
1835  // each block copy its data from LDS to global
1836  cde_block_copy_lds_and_global.Run(
1837  c_ds_desc_refs,
1838  c_ds_buf_refs,
1839  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1840  tie(c_grid_buf),
1841  scatter_offsets);
1842 
1843  if constexpr(access_id < num_access - 1)
1844  {
1845  constexpr auto cde_lds_and_global_step =
1846  sfc_cde_block.GetForwardStep(access_id);
1847 
1848  // move on Ds
1849  static_for<0, NumDTensor, 1>{}([&](auto i) {
1850  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1851  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1852  });
1853 
1854  // move on E
1855  cde_block_copy_lds_and_global.MoveDstSliceWindow(
1856  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1857  I0,
1858  cde_lds_and_global_step);
1859  }
1860  });
1861  }
1862  }
1863 
1864  template <bool HasMainKBlockLoop,
1865  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1866  TailNumber TailNum = TailNumber::Odd>
1867  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1868  const index_t* p_sorted_expert_ids,
1869  const index_t* p_max_token_id,
1870  const ADataType* p_a_grid,
1871  const BDataType* p_b_grid,
1872  DsGridPointer& p_ds_grid,
1873  CDataType* p_c_grid,
1874  void* p_shared,
1875  void* p_shared1,
1876  const Problem& problem,
1877  AElementwiseOperation a_element_op,
1878  BElementwiseOperation b_element_op,
1879  CElementwiseOperation c_element_op)
1880  {
1881  ignore = b_element_op;
1882  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1883  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1884  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1885  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1886  problem.MPadded,
1887  problem.K,
1888  problem.KPadded,
1889  problem.StrideA,
1890  problem.AK0);
1891  const auto b_grid_desc_bpreshuffled =
1892  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1893  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1894  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1895  problem.MPadded,
1896  problem.N,
1897  problem.NPadded,
1898  problem.StrideC);
1899  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1900  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1901  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1902  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1903  // static_assert(NSwizzle == false, "TODO: fix; needs another PR in sorting to be merged");
1904  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1905  if(expert_block_id * MPerBlock >= max_token_id)
1906  return;
1907  const index_t expert_id =
1908  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
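      // __builtin_amdgcn_readfirstlane marks these values as wavefront-uniform so they can be
      // kept in scalar registers instead of being replicated across VGPRs.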
1909  const auto block_mn = [&]() -> std::pair<int, int> {
1910  if constexpr(NSwizzle)
1911  {
1912  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1913  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1914  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1915  const index_t expert_swizzle =
1916  ecnt > 0 ? ecnt : 1; // avoid division by zero when this expert has no blocks
1917  const index_t bid_new = blockIdx.x - prefix_block;
1918  const index_t nid = __builtin_amdgcn_readfirstlane(
1919  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1920  const index_t mid =
1921  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1922  return {nid, mid};
1923  }
1924  else
1925  {
1926  return {blockIdx.x, blockIdx.y};
1927  }
1928  }();
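      // Illustrative walk-through with hypothetical values: for expert_swizzle = 4 and
      // bid_new = 13, nid = 13 % 8 + 13 / (8 * 4) * 8 = 5 and
      // mid = ecnt_prefix + (13 / 8) % 4 = ecnt_prefix + 1. Consecutive blocks therefore sweep
      // 8 N-tiles of one M-tile, revisit the same 8 N-tiles for the expert's next M-tiles, and
      // only then move to the next group of 8 N-tiles, improving reuse of A and B tiles across
      // the blocks of one expert.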
1929 
1930  const index_t block_n_id = block_mn.first;
1931  const index_t block_m_id = block_mn.second;
1932  const index_t token0 =
1933  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1934 
1935  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1936  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1937  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1938  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1939  constexpr auto AKThreads = AK0Threads * AK1Threads;
1940  constexpr auto AMRepeats = MPerBlock / AMThreads;
1941  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1942 
1943  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1944  return;
1945  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
1946  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1947  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1948  index_t token_offset = fused_token & 0xffffff;
1949  if constexpr(!IsInputGemm)
1950  {
1951  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1952  }
1953  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1954  });
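      // These gather offsets pick the A rows loaded by this block: the low 24 bits of each
      // sorted token id select the row, and for the down-projection GEMM (!IsInputGemm) the
      // index is expanded to token * TopK + slot, since A is then the (NumTokens * TopK, K)
      // intermediate activation; the offset is expressed in elements, i.e. scaled by problem.K.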
1955  const IndexType expert_stride =
1956  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1957  const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
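      // For the input GEMM each expert stores its gate and up projection weights back to back,
      // hence the factor of 2 in expert_stride (N * K * 2); the division by BPackedSize
      // accounts for packed element types where several values share one storage element.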
1958  // N0, K0, Blocksize*KPack
1959  const index_t n_block_data_idx_on_grid =
1960  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1961 
1962  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1963  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1964  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1965  p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1966 
1967  // A matrix in LDS memory, dst of blockwise copy
1968  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1969 
1970  // B matrix in LDS memory, dst of blockwise copy
1971  // dummy: B is preshuffled and kept in VGPRs; this descriptor only supplies the per-thread tile shape
1972  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1973  // A matrix blockwise copy
1974  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1976  AElementwiseOperation,
1980  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1981  ABlockTransferThreadClusterArrangeOrder,
1982  ADataType,
1983  LDSTypeA,
1984  decltype(a_grid_desc_ak0_m_ak1),
1985  decltype(a_block_desc_ak0_m_ak1),
1986  ABlockTransferSrcAccessOrder,
1988  ABlockTransferSrcVectorDim,
1989  2,
1990  ABlockTransferSrcScalarPerVector,
1991  ABlockTransferDstScalarPerVector_AK1,
1992  1,
1993  1,
1994  AThreadTransferSrcResetCoordinateAfterRun,
1995  true,
1996  IndexType,
1997  1,
1998  2>(a_grid_desc_ak0_m_ak1,
1999  make_multi_index(0, 0, 0),
2000  a_element_op,
2001  a_block_desc_ak0_m_ak1,
2002  make_multi_index(0, 0, 0),
2004  gather_offsets);
2005 
2006  // Thread-wise copy
2007  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2008  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2009  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2010  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2011  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2012  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2013 
2014  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
2015  BDataType,
2016  BDataType,
2017  decltype(b_grid_desc_bpreshuffled),
2018  decltype(b_block_desc_bk0_n_bk1),
2021  3,
2022  BBlockTransferSrcScalarPerVector,
2023  BThreadTransferSrcResetCoordinateAfterRun,
2024  true>(b_grid_desc_bpreshuffled,
2025  make_multi_index(n_block_data_idx_on_grid,
2027  0,
2028  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2029 
2030  // LDS allocation for A (B stays in VGPRs): be careful of alignment
2031  // cast the raw LDS pointers to the A data type
2032  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2033  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2034  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2035  static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2036  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
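      // A uses the two LDS allocations (p_shared / p_shared1) as ping-pong buffers, and B keeps
      // a matching VGPR ping-pong pair above, so the pipeline can overlap global loads of the
      // next K-chunk with MFMA compute on the current one.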
2037 
2038  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2039  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
2040 
2041  // Blockwise GEMM pipeline
2042  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2043  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2044  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2045  decltype(c_thread_buf) c_thread_buf_up;
2046 
2048  float,
2049  c_thread_buf.num_of_v_,
2050  c_thread_buf.s_per_v,
2051  true>
2052  c_thread_buf_fp32;
2053 
2054  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2055  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2056  KPerBlock);
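      // GetLength(I0) * GetLength(I2) is AK0 * AK1, i.e. the K extent handled by this split-K
      // batch; e.g. (hypothetical values) K = 4096 with KPerBlock = 256 gives 16 main-loop
      // iterations.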
2057 
2058  if constexpr(IsInputGemm)
2059  {
2060  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2061  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2062  p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2063  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2064  BDataType,
2065  BDataType,
2066  decltype(b_grid_desc_bpreshuffled),
2067  decltype(b_block_desc_bk0_n_bk1),
2070  3,
2071  BBlockTransferSrcScalarPerVector,
2072  BThreadTransferSrcResetCoordinateAfterRun,
2073  true>(b_grid_desc_bpreshuffled,
2074  make_multi_index(n_block_data_idx_on_grid,
2076  0,
2077  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2078  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2079  a_grid_desc_ak0_m_ak1,
2080  a_block_desc_ak0_m_ak1,
2081  a_blockwise_copy,
2082  a_grid_buf,
2083  a_block_bufs,
2084  a_block_slice_copy_step,
2085  b_grid_desc_bpreshuffled,
2086  b_blockwise_copy,
2087  b_blockwise_copy_up,
2088  b_grid_buf,
2089  b_grid_buf_up,
2090  b_block_bufs,
2091  b_block_slice_copy_step,
2092  c_thread_buf,
2093  c_thread_buf_up,
2094  num_k_block_main_loop);
2095  }
2096  else
2097  {
2098 
2099  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2100  a_grid_desc_ak0_m_ak1,
2101  a_block_desc_ak0_m_ak1,
2102  a_blockwise_copy,
2103  a_grid_buf,
2104  a_block_bufs,
2105  a_block_slice_copy_step,
2106  b_grid_desc_bpreshuffled,
2107  b_blockwise_copy,
2108  b_grid_buf,
2109  b_block_bufs,
2110  b_block_slice_copy_step,
2111  c_thread_buf,
2112  num_k_block_main_loop);
2113  }
2114 
2115  // shuffle C and write out
2116  {
2117  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2118  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2119  "wrong!");
2120 
2121  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2122 
2123  // TODO: hacky, fix it!
2124  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2125  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2126 
2127  // TODO: hacky, fix it!
2128  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2129  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2130  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2131 
2132  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2133  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2134  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2135  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2136  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2137  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2138  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2139  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2140 
2141  // multiply by the A/B dequantization scales
2142  const float* p_sorted_weights_0 = p_ds_grid[I0];
2143  const float* p_scale_b = p_ds_grid[I1];
2144 
2145  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
2146  static_assert(M4 == 4 || M4 == 8);
2147  const index_t m1 = get_warp_local_1d_id() / NWave;
2148  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
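      // m1 is this wave's M index within the MWave x NWave wave grid (N fastest), and m3 is,
      // roughly, this lane's row group inside one MPerXdl x NPerXdl accumulator tile; e.g. with
      // a 64-lane wave and MPerXdl = 32 (hypothetical values) m3 is 0 or 1, matching the M3
      // dimension of the accumulator layout.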
2149 
2150  if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
2151  {
2152  if constexpr(PerTokenQuant)
2153  {
2154  constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
2155  p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
2156  get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
2157  }
2158  else
2159  {
2160  p_scale_b += expert_id;
2161  }
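      // With per-token quantization the B scale is per expert and per output column: the
      // pointer is advanced to this expert's slice (doubled for the fused gate/up weights) and
      // to the column owned by this lane; otherwise a single per-expert scale is used.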
2162 
2163  vector_type<int32_t, M4> scale_token_ids;
2164  vector_type<float, M4> topk_weights;
2165  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2166  const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
2167  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2168  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
2169  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2170  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2171  if constexpr(PerTokenQuant)
2172  {
2173  scale_token_ids =
2174  *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
2175  p_sorted_token_ids + m_pos);
2176  }
2177  if constexpr(MulRoutedWeight)
2178  {
2179  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
2180  p_ds_grid[I2] + m_pos);
2181  }
2182  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
2183  float scale_a = [&]() {
2184  if constexpr(PerTokenQuant)
2185  {
2186  index_t fused_token =
2187  scale_token_ids.template AsType<index_t>()[m4];
2188  const index_t token_offset = fused_token & 0xffffff;
2189  return token_offset < problem.NumTokens
2190  ? p_sorted_weights_0[IsInputGemm
2191  ? token_offset
2192  : token_offset *
2193  problem.TopK +
2194  (fused_token >>
2195  24)]
2196  : 0.0;
2197  }
2198  else
2199  {
2200  return p_sorted_weights_0[0];
2201  }
2202  }();
2203  constexpr index_t c_offset =
2204  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2205  make_tuple(m0, n0, m2 * M4 + m4));
2206  constexpr auto cidx = Number<c_offset>{};
2207  if constexpr(IsInputGemm) // gate-up fusion
2208  {
2209  if constexpr(ActivationOperation == Activation::silu_and_mul)
2210  {
2211  const float scale_up =
2212  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
2213  PerTokenQuant];
2214  float gate = scale_a * scale_b * c_thread_buf[cidx];
2215  float up = scale_a * scale_up * c_thread_buf_up[cidx];
2216  if constexpr(MulRoutedWeight)
2217  {
2218  gate = gate * topk_weights.template AsType<float>()[m4];
2219  up = up * topk_weights.template AsType<float>()[m4];
2220  }
2222  {
2223  gate *= 16;
2224  up *= 16;
2225  }
2227  c_thread_buf_fp32(cidx) = gate * up;
2228  }
2229  else if(ActivationOperation == Activation::gelu_and_mul)
2230  {
2231  const float scale_up =
2232  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
2233  PerTokenQuant];
2234  float gate = scale_a * scale_b * c_thread_buf[cidx];
2235  float up = scale_a * scale_up * c_thread_buf_up[cidx];
2236  if constexpr(MulRoutedWeight)
2237  {
2238  gate = gate * topk_weights.template AsType<float>()[m4];
2239  up = up * topk_weights.template AsType<float>()[m4];
2240  }
2242  {
2243  gate *= 16;
2244  up *= 16;
2245  }
2247  c_thread_buf_fp32(cidx) = gate * up;
2248  }
2249  }
2250  else
2251  {
2252  c_thread_buf_fp32(cidx) =
2253  scale_a * scale_b * c_thread_buf[cidx];
2254  if constexpr(MulRoutedWeight)
2255  {
2256  c_thread_buf_fp32(cidx) =
2257  c_thread_buf_fp32(cidx) *
2258  topk_weights.template AsType<float>()[m4];
2259  }
2260  }
2261  });
2262  });
2263  });
2264  });
2265  }
2266  else
2267  {
2268  vector_type<float, M4> topk_weights; // for gemm2 only
2269  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2270  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2271  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
2272  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2273  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2274  if constexpr(MulRoutedWeight)
2275  {
2276  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
2277  p_ds_grid[I2] + m_pos);
2278  }
2279  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
2280  constexpr index_t c_offset =
2281  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2282  make_tuple(m0, n0, m2 * M4 + m4));
2283  constexpr auto cidx = Number<c_offset>{};
2284 
2285  if constexpr(IsInputGemm) // gate-up fusion
2286  {
2287  if constexpr(ActivationOperation == Activation::silu_and_mul)
2288  {
2289  float gate = c_thread_buf[cidx];
2290  float up = c_thread_buf_up[cidx];
2291  if constexpr(MulRoutedWeight)
2292  {
2293  gate = gate * topk_weights.template AsType<float>()[m4];
2294  up = up * topk_weights.template AsType<float>()[m4];
2295  }
2297  c_thread_buf_fp32(cidx) = gate * up;
2298  }
2299  else if(ActivationOperation == Activation::gelu_and_mul)
2300  {
2301  float gate = c_thread_buf[cidx];
2302  float up = c_thread_buf_up[cidx];
2303  if constexpr(MulRoutedWeight)
2304  {
2305  gate = gate * topk_weights.template AsType<float>()[m4];
2306  up = up * topk_weights.template AsType<float>()[m4];
2307  }
2309  c_thread_buf_fp32(cidx) = gate * up;
2310  }
2311  }
2312  else
2313  {
2314  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2315  if constexpr(MulRoutedWeight)
2316  {
2317  c_thread_buf_fp32(cidx) =
2318  topk_weights.template AsType<float>()[m4] *
2319  c_thread_buf_fp32[cidx];
2320  }
2321  }
2322  });
2323  });
2324  });
2325  });
2326  }
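      // Summary of the fusion epilogue: c_thread_buf holds the gate-projection accumulators and
      // c_thread_buf_up the up-projection accumulators. As the enum names suggest, silu_and_mul
      // presumably yields silu(gate) * up (silu(x) = x * sigmoid(x)) and gelu_and_mul yields
      // gelu(gate) * up, with the routed top-k weight folded into both operands when
      // MulRoutedWeight is set and the dequantization scales applied in the branch above; for
      // the down-projection GEMM the accumulator is only rescaled.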
2327 
2328  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2329  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2330 
2331  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2332  static_cast<CShuffleDataType*>(p_shared),
2333  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2334 
2335  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2336  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2337  make_tuple(
2340  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2341  M1, // M1 = MWave
2342  M2, // M2 * M3 * M4 = MPerXdl
2343  M3,
2344  M4)),
2347  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2348  N1, // N1 = NWave
2349  N2))), // N2 = NPerXdl
2351  make_tuple(
2353 
2354  // calculate origin of thread output tensor on global memory
2355  // blockwise GEMM c matrix starting index
2356  const auto c_thread_mtx_on_block =
2357  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2358 
2359  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2360  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2361 
2362  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2364  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2366  make_tuple(Sequence<0>{}));
2367 
2368  const auto m_thread_data_on_block_idx =
2369  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2370  make_multi_index(m_thread_data_on_block));
2371 
2372  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2376  make_tuple(Sequence<0>{}));
2377 
2378  const auto n_thread_data_on_block_idx =
2379  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2380  make_multi_index(n_thread_data_on_block));
2381 
2382  // shuffle: threadwise copy C from VGPR to LDS
2383  auto c_thread_copy_vgpr_to_lds =
2385  CShuffleDataType,
2386  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2387  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2389  Sequence<CShuffleMXdlPerWavePerShuffle,
2390  CShuffleNXdlPerWavePerShuffle,
2391  I1,
2392  I1,
2393  M2,
2394  I1,
2395  M4,
2396  I1>,
2398  7,
2399  1,
2401  1,
2402  true>{
2403  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2404  make_multi_index(0,
2405  0,
2406  m_thread_data_on_block_idx[I1],
2407  n_thread_data_on_block_idx[I1],
2408  m_thread_data_on_block_idx[I2],
2409  m_thread_data_on_block_idx[I3],
2410  m_thread_data_on_block_idx[I4],
2411  n_thread_data_on_block_idx[I2]),
2413 
2414  using EDataType = CDataType;
2415 
2416  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2417  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2418 
2419  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2420  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2421  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2422 
2423  const auto ds_grid_buf = generate_tuple(
2424  [&](auto i) {
2425  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2426  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2427  },
2428  Number<NumDTensor>{});
2429 
2430  // tuple of reference to C/Ds tensor descriptors
2431  const auto c_ds_desc_refs = concat_tuple_of_reference(
2432  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2433  generate_tie([&](auto i) -> const auto& // return type should be reference
2434  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2435  Number<NumDTensor>{}));
2436 
2437  // tuple of reference to C/Ds buffers
2438  const auto c_ds_buf_refs = concat_tuple_of_reference(
2439  tie(c_shuffle_block_buf),
2440  generate_tie([&](auto i) -> const auto& // return type should be reference
2441  { return ds_grid_buf[i]; },
2442  Number<NumDTensor>{}));
2443 
2444  // tuple of starting index of C/Ds blockwise copy
2445  const auto idx_c_ds_block_begin =
2448  [&](auto) {
2449  return make_multi_index(block_m_id, 0, block_n_id, 0);
2450  // return make_multi_index(block_work_idx[I0], 0,
2451  // block_work_idx[I1], 0);
2452  },
2453  Number<NumDTensor>{}));
2454 
2455  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2456  c_grid_desc_mblock_mperblock_nblock_nperblock;
2457 
2458  using CDEBlockTransferCluster =
2459  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2460  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2461  constexpr index_t scatter_weight_idx = 3; // FIXME: hard-coded scatter-weight index
2462  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2464  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2466  decltype(c_ds_desc_refs),
2467  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2468  CElementwiseOperation,
2469  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
2470  // support arbitrary types
2471  Sequence<1,
2472  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2473  1,
2474  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2475  CDEBlockTransferCluster,
2476  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2477  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2478  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2479  3, // index_t SrcVectorDim,
2480  3, // index_t DstVectorDim,
2481  CDEShuffleBlockTransferScalarPerVectors,
2486  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2487  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2488  IndexType,
2489  1, // ScatterDim
2490  true, // OutputScatter: scatter output rows by token offset (false would use only the scatter weights)
2491  scatter_weight_idx // ScatterWeightIdx: the a-scale slot
2492  >{c_ds_desc_refs,
2493  idx_c_ds_block_begin,
2494  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2495  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2496  c_element_op};
2497 
2498  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2499  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2500  constexpr auto sfc_c_vgpr =
2503  Sequence<CShuffleMXdlPerWavePerShuffle,
2504  CShuffleNXdlPerWavePerShuffle,
2505  1,
2506  1,
2507  M2,
2508  1,
2509  M4,
2510  1>>{};
2511 
2512  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2513 
2514  // space filling curve for shuffled blockwise C/D/E
2515  constexpr auto sfc_cde_block =
2518  Sequence<1,
2519  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2520  1,
2521  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2522 
2523  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2524  constexpr auto EMThreads =
2525  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2526  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2527  constexpr auto ENThreads =
2528  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
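      // The CDE transfer cluster is split into EMThreads * ENThreads threads over the shuffled
      // tile; each thread covers EMRepeats consecutive output rows per access and therefore
      // prepares EMRepeats scatter offsets below. E.g. (hypothetical values) a 1 x 32 x 1 x 8
      // cluster with a 64-row shuffle tile gives EMThreads = 32, ENThreads = 8, EMRepeats = 2.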
2529  static_for<0, num_access, 1>{}([&](auto access_id) {
2530  // make sure it's safe to write to LDS
2531  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2532 
2533  auto dstidx = sfc_cde_block.GetIndex(access_id);
2534  const index_t c_token_pos =
2535  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2536  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2537  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2538  IndexType token_offset = fused_token & 0xffffff;
2539  if constexpr(IsInputGemm)
2540  {
2541  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2542  }
2543  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2544  });
2545 
2546  block_sync_lds();
2547 
2548  // each thread write its data from VGPR to LDS
2549  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2550  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2551  c_thread_buf_fp32,
2552  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2553  c_shuffle_block_buf);
2554 
2555  // make sure it's safe to read from LDS
2556  block_sync_lds();
2557 
2558  // each block copy its data from LDS to global
2559  cde_block_copy_lds_and_global.Run(
2560  c_ds_desc_refs,
2561  c_ds_buf_refs,
2562  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2563  tie(c_grid_buf),
2564  scatter_offsets);
2565 
2566  if constexpr(access_id < num_access - 1)
2567  {
2568  constexpr auto cde_lds_and_global_step =
2569  sfc_cde_block.GetForwardStep(access_id);
2570 
2571  // move on Ds
2572  static_for<0, NumDTensor, 1>{}([&](auto i) {
2573  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2574  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2575  });
2576 
2577  // move on E
2578  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2579  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2580  I0,
2581  cde_lds_and_global_step);
2582  }
2583  });
2584  }
2585  }
2586 };
2587 
2588 } // namespace ck