Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp Source File
gridwise_moe_mx_gemm_bns.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
18 
20 
21 #define DEBUG_LOG 0
22 
23 namespace ck {
24 
25 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
26 // pipelines in the same kernel function. Blockers:
27 // 1. Two separate declarations of the __shared__ pointer are required to make sure data accesses
28 //    operate on two distinct LDS chunks.
29 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C cannot
30 //    reuse the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
31 
32 enum Activation
33 {
34  gelu_and_mul = 0,
35  silu_and_mul = 1
36 };
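// Usage note (illustrative): the ActivationOperation template parameter of the gridwise GEMM below
// is an index_t that is compared against these enumerators, e.g.
//   constexpr index_t act = Activation::silu_and_mul; // select the SiLU-and-mul fused activation
// and `act` would then be passed as the ActivationOperation template argument.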
37 
38 template <typename GridwiseGemm,
39  bool HasMainKBlockLoop,
40  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
41  index_t MinimumOccupancy = 1,
42  TailNumber TailNum = TailNumber::Even>
43 __global__ void
44 #if CK_USE_LAUNCH_BOUNDS
45 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
46 #endif
47  // __attribute__((amdgpu_waves_per_eu(1, 1)))
48  kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
49 {
50 #if defined(__gfx9__)
51  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
52  {
53  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
54 
55  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
56 
57  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
58  karg.p_sorted_token_ids,
59  karg.p_sorted_expert_ids,
60  karg.p_max_token_id,
61  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
62  karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
63  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
64  karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
65  karg.p_ds_grid,
66  karg.p_c_grid,
67  p_shared,
68  karg,
69  karg.a_element_op,
70  karg.b_element_op,
71  karg.c_element_op);
72  }
73 #else
74  ignore = karg;
75 #endif // end of if (defined(__gfx9__))
76 }
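// Host-side launch sketch (hedged): a concrete GridwiseGemm instantiation, a populated
// GridwiseGemm::Argument `karg`, and a HIP stream are assumed; grid extents come from
// CalculateGridSize() below, and the z-dimension indexes the split-K batch consumed by
// SplitKBatchOffset(karg, blockIdx.z) above.
//   dim3 grid(gridx, gridy, karg.KBatch);
//   dim3 block(BlockSize);
//   kernel_moe_mxgemm<GridwiseGemm,
//                     true /*HasMainKBlockLoop*/,
//                     InMemoryDataOperationEnum::Set>
//       <<<grid, block, 0, stream>>>(karg);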
77 
78 #if 0
79 template <typename GridwiseGemm,
80  bool HasMainKBlockLoop,
81  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
82  index_t MinimumOccupancy = 1,
83  TailNumber TailNum = TailNumber::Even>
84 __global__ void
85 #if CK_USE_LAUNCH_BOUNDS
86 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
87 #endif
88  // __attribute__((amdgpu_waves_per_eu(1, 1)))
89  kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
90 {
91 #if defined(__gfx9__)
92  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
93  {
94  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
95  __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
96 
97  // auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
98 
99  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
100  karg.p_sorted_token_ids,
101  karg.p_sorted_expert_ids,
102  karg.p_max_token_id,
103  karg.p_a_grid,
104  karg.p_a_scale_grid,
105  karg.p_b_grid,
106  karg.p_b_scale_grid,
107  karg.p_ds_grid,
108  karg.p_c_grid,
109  p_shared,
110  p_shared1,
111  karg,
112  karg.a_element_op,
113  karg.b_element_op,
114  karg.c_element_op);
115  }
116 #else
117  ignore = karg;
118 #endif // end of if (defined(__gfx9__))
119 }
120 #endif
121 
122 template <typename ALayout,
123  typename BLayout,
124  typename DsLayout,
125  typename CLayout,
126  typename ADataType,
127  typename AScaleDataType,
128  typename BDataType,
129  typename BScaleDataType,
130  typename AccDataType,
131  typename CShuffleDataType,
132  typename DsDataType,
133  typename CDataType,
134  typename AElementwiseOperation,
135  typename BElementwiseOperation,
136  typename CElementwiseOperation,
138  index_t ScaleBlockSize,
139  index_t BlockSize,
140  index_t MPerBlock,
141  index_t NPerBlock,
142  index_t KPerBlock,
143  index_t AK1Value,
144  index_t BK1Value,
145  index_t MPerXdl,
146  index_t NPerXdl,
147  index_t MXdlPerWave,
148  index_t NXdlPerWave,
149  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
150  typename ABlockTransferThreadClusterArrangeOrder,
151  typename ABlockTransferSrcAccessOrder,
152  index_t ABlockTransferSrcVectorDim,
153  index_t ABlockTransferSrcScalarPerVector,
154  index_t ABlockTransferDstScalarPerVector_AK1,
155  bool AThreadTransferSrcResetCoordinateAfterRun,
156  index_t ABlockLdsExtraM,
157  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
158  typename BBlockTransferThreadClusterArrangeOrder,
159  typename BBlockTransferSrcAccessOrder,
160  index_t BBlockTransferSrcVectorDim,
161  index_t BBlockTransferSrcScalarPerVector,
162  index_t BBlockTransferDstScalarPerVector_BK1,
163  bool BThreadTransferSrcResetCoordinateAfterRun,
164  index_t BBlockLdsExtraN,
165  index_t CShuffleMXdlPerWavePerShuffle,
166  index_t CShuffleNXdlPerWavePerShuffle,
167  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
168  typename CDEShuffleBlockTransferScalarPerVectors,
171  index_t ActivationOperation = 0,
172  bool NSwizzle = false,
173  bool IsInputGemm = true,
174  bool MulRoutedWeight = true,
175  typename IndexType = index_t,
176  typename ComputeTypeA = ADataType,
177  typename ComputeTypeB = BDataType>
179 {
180  using LDSTypeA = ADataType;
181  using LDSTypeB = BDataType;
182 
183  static constexpr auto I0 = Number<0>{};
184  static constexpr auto I1 = Number<1>{};
185  static constexpr auto I2 = Number<2>{};
186  static constexpr auto I3 = Number<3>{};
187  static constexpr auto I4 = Number<4>{};
188  static constexpr auto I5 = Number<5>{};
189  static constexpr auto I6 = Number<6>{};
190  static constexpr auto I7 = Number<7>{};
191  static constexpr auto I8 = Number<8>{};
192  static constexpr auto I9 = Number<9>{};
193 
194  static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
195  CDEShuffleBlockTransferScalarPerVectors{}[I0];
196  // K1 should be Number<...>
197  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
198  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
199  static constexpr auto AK1Number = Number<AK1Value>{};
200  static constexpr auto BK1Number = Number<BK1Value>{};
201 
202  static constexpr index_t NumDTensor = DsDataType::Size();
203 
204  static constexpr auto MXdlPack = 2;
205  static constexpr auto NXdlPack = 2;
206  static constexpr auto KXdlPack = 2;
207 
208  static constexpr index_t APackedSize = packed_size_v<ADataType>;
209  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
210 
211  static constexpr bool is_single_rate_mfma = false;
212  static constexpr auto is_scale_mfma = true;
213  using mfma_selector = MfmaSelector<ComputeTypeA,
214  MPerXdl,
215  NPerXdl,
216  ComputeTypeB,
218  is_scale_mfma>;
219  static constexpr index_t KPack = math::max(
221 
222  // static constexpr index_t NumTokens = 1;
223  static constexpr index_t SortedTileSize = MPerBlock;
224 
225  static constexpr auto MakeDsGridPointer()
226  {
227  return generate_tuple(
228  [&](auto i) {
229  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
230 
231  return static_cast<const DDataType*>(nullptr);
232  },
233  Number<NumDTensor>{});
234  }
235 
236  using DsGridPointer = decltype(MakeDsGridPointer());
237 
239 
240  __host__ static auto CalculateGridSize(index_t M, index_t N)
241  {
242  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
243  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
244  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
245  const index_t gridy = NSwizzle ? 1 : mblock;
246 
247  return std::make_tuple(gridx, gridy, 1);
248  }
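// Worked example (illustrative): with MPerBlock = 128, NPerBlock = 128, M = 512, N = 1024:
//   nblock = ceil(1024 / 128) = 8, mblock = ceil(512 / 128) = 4
//   NSwizzle == false -> grid = (8, 4, 1)   (x walks N tiles, y walks M tiles)
//   NSwizzle == true  -> grid = (32, 1, 1)  (both tile dimensions folded into x)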
249 
250  __host__ static auto CalculateMPadded(index_t M)
251  {
252  return math::integer_least_multiple(M, MPerBlock);
253  }
254 
255  __host__ static auto CalculateNPadded(index_t N)
256  {
257  return math::integer_least_multiple(N, NPerBlock);
258  }
259 
260  __host__ static auto CalculateKPadded(index_t K)
261  {
262  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
263  }
264 
265  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
266  {
267  auto K_t = K_Batch * KPerBlock;
268  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
269  }
270 
271  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
272  {
273  auto K_t = K_Batch * KPerBlock;
274  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
275  }
276 
277  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
278  {
279  auto K_t = K_Batch * KPerBlock;
280  return (K + K_t - 1) / K_t * KPerBlock;
281  }
282 
283  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
284  {
285  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
286  auto K_t = K_Batch * KReadVec;
287  return (K + K_t - 1) / K_t * KReadVec;
288  }
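// Worked example (illustrative) of the split-K helpers, assuming KPerBlock = 256,
// AK1Value = BK1Value = 16, K = 1000 and K_Batch = 2:
//   CalculateKPadded(1000, 2)   = ceil(1000 / 512) * 256        = 512  (per-batch padded K)
//   CalculateAK0Padded(1000, 2) = ceil(1000 / 512) * (256 / 16) = 32
//   CalculateKRead(1000, 2)     : KReadVec = lcm(16, 16) = 16 -> ceil(1000 / 32) * 16 = 512
// i.e. each of the two K-batches reads 512 K elements; the last batch only computes the
// remainder (see SplitKBatchOffset below).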
289 
290  __host__ static auto CalculateMBlock(index_t M)
291  {
292  return math::integer_divide_ceil(M, MPerBlock);
293  }
294 
295  __host__ static auto CalculateNBlock(index_t N)
296  {
297  return math::integer_divide_ceil(N, NPerBlock);
298  }
299 
300  template <index_t MNXdlPerWave,
301  index_t MNWaves,
302  index_t MNXdlPack,
303  index_t MNPerXdl,
304  typename TileDesc_K0_MN_K1>
305  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
306  {
307  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
308  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
309 
311  TileDesc_K0_MN_K1{},
314  Number<MNWaves>{},
316  Number<MNPerXdl>{}))),
319  }
320 
321  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
322  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
323  {
324  const auto a_grid_desc_mraw_kraw = [&]() {
325  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
326  {
327  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
328  }
329  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
330  {
331  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
332  }
333  }();
334 
336 
337  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
338  GemmSpec == GemmSpecialization::MNKPadding)
339  {
340  // pad both M and K
341  const auto a_grid_desc_m_k =
342  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
344  make_right_pad_transform(K, KPad - K)),
347 
348  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
349  a_grid_desc_m_k,
354 
355  return a_grid_desc_ak0_m_ak1;
356  }
357  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
358  GemmSpec == GemmSpecialization::MNPadding)
359  {
360  // pad M, but not K
361  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
362  a_grid_desc_mraw_kraw,
364  make_right_pad_transform(M, MPad - M)),
367 
368  return a_grid_desc_ak0_m_ak1;
369  }
370  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
371  GemmSpec == GemmSpecialization::NKPadding)
372  {
373  // pad K, but not M
374  const auto a_grid_desc_m_k = transform_tensor_descriptor(
375  a_grid_desc_mraw_kraw,
379 
380  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
381  a_grid_desc_m_k,
386 
387  return a_grid_desc_ak0_m_ak1;
388  }
389  else
390  {
391  // not pad M or K
392  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
393  a_grid_desc_mraw_kraw,
398 
399  return a_grid_desc_ak0_m_ak1;
400  }
401  }
402 
403  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
404  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
405  {
406  const auto b_grid_desc_nraw_kraw = [&]() {
408  {
409  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
410  }
412  {
413  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
414  }
415  }();
416 
418 
419  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
420  GemmSpec != GemmSpecialization::Default),
421  "pk_i4_t does not support padding");
422  static_assert(!(is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t> &&
423  GemmSpec != GemmSpecialization::Default),
424  "f4x2_pk_t does not support padding");
425 
426  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
427  GemmSpec == GemmSpecialization::MNKPadding)
428  {
429  // pad both N and K
430  const auto b_grid_desc_n_k =
431  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
433  make_right_pad_transform(K, KPad - K)),
436 
437  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
438  b_grid_desc_n_k,
443 
444  return b_grid_desc_bk0_n_bk1;
445  }
446  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
447  GemmSpec == GemmSpecialization::MNPadding)
448  {
449  // pad N, but not K
450  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
451  b_grid_desc_nraw_kraw,
453  make_right_pad_transform(N, NPad - N)),
456 
457  return b_grid_desc_bk0_n_bk1;
458  }
459  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
460  GemmSpec == GemmSpecialization::MKPadding)
461  {
462  // pad K, but not N
463  const auto b_grid_desc_n_k = transform_tensor_descriptor(
464  b_grid_desc_nraw_kraw,
468 
469  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
470  b_grid_desc_n_k,
475 
476  return b_grid_desc_bk0_n_bk1;
477  }
478  else
479  {
480  // not pad N or K
481  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
482  b_grid_desc_nraw_kraw,
487 
488  return b_grid_desc_bk0_n_bk1;
489  }
490  }
491 
492  template <typename ABlockDesc_AK0_M_AK1>
493  __host__ __device__ static constexpr auto
494  MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
495  {
496  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
497 
498  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl>(
499  ABlockDesc_AK0_M_AK1{});
500  }
501 
502  template <typename BBlockDesc_BK0_N_BK1>
503  __host__ __device__ static constexpr auto
504  MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
505  {
506  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
507 
508  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl>(
509  BBlockDesc_BK0_N_BK1{});
510  }
511 
512  template <typename ELayout>
513  __host__ __device__ static auto MakeCGridDescriptor_M_N(
514  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
515  {
516  const auto c_grid_desc_mraw_nraw = [&]() {
518  {
519  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
520  }
522  {
523  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
524  }
525  }();
526 
527  // pad M and N
528  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
530  make_right_pad_transform(N, NPad - N)),
533  }
534 
535  template <typename DLayout>
536  __host__ __device__ static auto
537  MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
538  {
539  const auto c_grid_desc_mraw_nraw = [&]() {
541  {
542  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
543  }
545  {
546  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
547  }
548  }();
549 
550  // pad M and N
551  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
553  make_right_pad_transform(N, NPad - N)),
556  }
557 
558  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
559  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
560  {
561  return generate_tuple(
562  [&](auto i) {
563  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
564  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
565  },
567  }
568 
569  template <typename DsGridDesc>
570  __host__ __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
571  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
572  {
573  return generate_tuple(
574  [&](auto i) {
576  ds_grid_desc_m_n[i], MBlock, NBlock);
577  },
579  }
580 
581  struct Problem
582  {
583  __host__ Problem(index_t NumTokens_,
584  index_t TopK_,
585  index_t M_,
586  index_t N_,
587  index_t K_,
588  index_t StrideA_,
589  index_t StrideScaleA_,
590  index_t StrideB_,
591  index_t StrideScaleB_,
592  std::array<index_t, NumDTensor> StrideDs_,
593  index_t StrideC_,
594  index_t KBatch_)
595  : NumTokens{NumTokens_},
596  TopK{TopK_},
597  M{M_},
598  N{N_},
599  K{K_},
600  StrideA{StrideA_},
601  StrideScaleA{StrideScaleA_},
602  StrideB{StrideB_},
603  StrideScaleB{StrideScaleB_},
604  StrideDs{StrideDs_},
605  StrideC{StrideC_},
606  KBatch{KBatch_},
607  MPadded{CalculateMPadded(M_)},
608  NPadded{CalculateNPadded(N_)},
609  KRead{CalculateKRead(K_, KBatch_)},
610  KPadded{CalculateKPadded(K_, KBatch_)},
611  AK0{CalculateAK0Padded(K_, KBatch_)},
612  BK0{CalculateBK0Padded(K_, KBatch_)},
613  MBlock{CalculateMBlock(M_)},
614  NBlock{CalculateNBlock(N_)}
615  {
616  }
617 
618  __host__ void Print() const
619  {
620  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
621  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
622  << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
623  << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
624  << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
625  << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
626  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
627  << ", " << "NBlock: " << NBlock << "}" << std::endl;
628  }
629 
630  index_t NumTokens;
631  index_t TopK;
632  index_t M;
633  index_t N;
634  index_t K;
635  index_t StrideA;
636  index_t StrideScaleA;
637  index_t StrideB;
638  index_t StrideScaleB;
639  std::array<index_t, NumDTensor> StrideDs;
640  index_t StrideC;
641  index_t KBatch;
642  index_t MPadded;
643  index_t NPadded;
644  index_t KRead;
645  index_t KPadded;
646  index_t AK0;
647  index_t BK0;
648  index_t MBlock;
649  index_t NBlock;
650  };
651 
652  // Argument
653  struct Argument : public Problem
654  {
655  __host__ Argument(const index_t* p_sorted_token_ids_,
656  const index_t* p_sorted_expert_ids_,
657  const index_t* p_max_token_id_,
658  const ADataType* p_a_grid_,
659  const AScaleDataType* p_a_scale_grid_,
660  const BDataType* p_b_grid_,
661  const BScaleDataType* p_b_scale_grid_,
662  std::array<const void*, NumDTensor> p_ds_grid_,
663  CDataType* p_c_grid_,
664  index_t NumTokens_,
665  index_t TopK_,
666  index_t M_,
667  index_t N_,
668  index_t K_,
669  index_t StrideA_,
670  index_t StrideScaleA_,
671  index_t StrideB_,
672  index_t StrideScaleB_,
673  std::array<index_t, NumDTensor> StrideDs_,
674  index_t StrideC_,
675  index_t k_batch_,
676  AElementwiseOperation a_element_op_,
677  BElementwiseOperation b_element_op_,
678  CElementwiseOperation c_element_op_)
679  : Problem{NumTokens_,
680  TopK_,
681  M_,
682  N_,
683  K_ / APackedSize,
684  StrideA_ / APackedSize,
685  StrideScaleA_,
686  StrideB_ / BPackedSize,
687  StrideScaleB_,
688  StrideDs_,
689  StrideC_,
690  k_batch_},
691  p_sorted_token_ids{p_sorted_token_ids_},
692  p_sorted_expert_ids{p_sorted_expert_ids_},
693  p_max_token_id{p_max_token_id_},
694  p_a_grid{p_a_grid_},
695  p_a_scale_grid{p_a_scale_grid_},
696  p_b_grid{p_b_grid_},
697  p_b_scale_grid{p_b_scale_grid_},
698  p_ds_grid{},
699  p_c_grid{p_c_grid_},
700  a_element_op{a_element_op_},
701  b_element_op{b_element_op_},
702  c_element_op{c_element_op_}
703  {
704 
705  // populate pointer, desc for Ds
706  static_for<0, NumDTensor, 1>{}([&](auto i) {
707  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
708 
709  // D pointer
710  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
711  });
712  }
713 
714  const index_t* p_sorted_token_ids;
715  const index_t* p_sorted_expert_ids;
716  const index_t* p_max_token_id;
717  const ADataType* p_a_grid;
718  const AScaleDataType* p_a_scale_grid;
719  const BDataType* p_b_grid;
720  const BScaleDataType* p_b_scale_grid;
721  DsGridPointer p_ds_grid;
722  CDataType* p_c_grid;
723 
724  const AElementwiseOperation a_element_op;
725  const BElementwiseOperation b_element_op;
726  const CElementwiseOperation c_element_op;
727  };
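// Note (illustrative): the constructor forwards K, StrideA and StrideB to Problem divided by the
// packed sizes, so Problem extents are expressed in packed storage elements. For example, with a
// packed A type holding two values per element (APackedSize == 2), passing K = 4096 and
// StrideA = 4096 yields Problem::K = 2048 and Problem::StrideA = 2048.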
728 
729  struct SplitKBatchOffset
730  {
731  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
732  {
733  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
734  {
735  a_k_split_offset = k_id * karg.KRead;
736  }
737  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
738  {
739  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
740  }
741 
742  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
743  {
744  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
745  }
746  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
747  {
748  // KPack * NLane * KLane * K0 * N0
749  b_k_split_offset = k_id * karg.KRead;
750  }
751 
752  // Calculate A scale offset
753  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
754  {
755  a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
756  }
757  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
758  {
759  a_scale_k_split_offset =
760  k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
761  }
762 
763  // Calculate B scale offset
764  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
765  {
766  b_scale_k_split_offset =
767  k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
768  }
769  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
770  {
771  b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
772  }
773 
774  if(k_id < karg.KBatch - 1)
775  {
776  karg.K = karg.KRead;
777  }
778  else
779  {
780  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
781  }
782  }
783 
784  index_t a_k_split_offset;
785  index_t b_k_split_offset;
786  index_t a_scale_k_split_offset;
787  index_t b_scale_k_split_offset;
788  };
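// Worked example (illustrative): row-major A, column-major B, KRead = 512, ScaleBlockSize = 32,
// APackedSize = BPackedSize = 1, k_id = 1:
//   a_k_split_offset       = 1 * 512      = 512  (elements along K)
//   b_k_split_offset       = 1 * 512      = 512
//   a_scale_k_split_offset = 1 * 512 / 32 = 16   (one scale per 32 K elements)
//   b_scale_k_split_offset = 1 * 512 / 32 = 16
// For the last batch (k_id == KBatch - 1), karg.K is shrunk to the remaining K so the tail batch
// only processes the leftover elements.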
789 
790  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
791  {
792  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
793  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
794  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
795 
796  // A matrix in LDS memory, dst of blockwise copy
797  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
798  {
802  }
803  // The xor tensor transformation requires more (unnecessary) VGPR usage and can cause register
804  // spills in some cases.
806  {
807  constexpr auto a_lds_block_desc =
810 
811  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
812  a_lds_block_desc,
818 
819  return a_lds_block_desc_permuted;
820  }
821  else // ColumnMajor A
822  {
823  // The kfold and mpair dimensions are not always required.
824  // More dimensions in merge_transform increase the difficulty of generating immediate-argument
825  // offsets for the compiler.
826  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
827  constexpr auto M1 = MPerBlock / M0;
828 
829  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
830  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
831  constexpr auto KThreadRead = WaveSize / MPerXdl;
832  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
833 
834  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
835  ? 1
836  : 128 / (AK1Number * M0 * sizeof(ADataType));
837  constexpr auto KThreadReadPerm =
838  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
839  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
840  : KThreadRead;
841 
842  // 1 <= mpair <= M0
843  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
844  ? 1
845  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
846  ? M0
847  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
848 
849  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
853  Number<kfold * M0 / mpair>{},
854  Number<mpair>{},
855  AK1Number));
856 
857  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
858  a_lds_block_desc,
859  make_tuple(
863  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
866  make_tuple(
868  make_tuple(
870 
871  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
872  a_lds_block_desc_permuted,
873  make_tuple(
881  Sequence<1>{},
882  Sequence<2>{},
883  Sequence<3>{},
884  Sequence<4>{},
885  Sequence<5>{}),
887  Sequence<2>{},
888  Sequence<0, 3>{},
889  Sequence<4, 5>{},
890  Sequence<6>{},
891  Sequence<7>{}));
892 
893  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
894  a_lds_block_desc_unmerged,
897  Number<KThreadWrite / kfold / KThreadReadPerm>{},
898  Number<kfold>{},
905 
906  return a_lds_block_desc_ak0_m_ak1;
907  }
908  }
909 
910  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
911  {
912  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
913  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
914  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
915 
916  // B matrix in LDS memory, dst of blockwise copy
917  if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
918  {
922  }
924  {
925  // NLdsLayer * K0 as logical Bank
926  constexpr auto b_lds_block_desc =
929 
930  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
931  b_lds_block_desc,
937 
938  return b_lds_block_desc_permuted;
939  }
940  else // RowMajor B
941  {
942  constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
943  constexpr auto N1 = NPerBlock / N0;
944 
945  constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
946  constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
947  constexpr auto KThreadRead = WaveSize / NPerXdl;
948  constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
949 
950  constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
951  ? 1
952  : 128 / (BK1Number * N0 * sizeof(BDataType));
953  constexpr auto KThreadReadPerm =
954  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
955  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
956  : KThreadRead;
957 
958  // 1<=npair<=n0
959  constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
960  ? 1
961  : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
962  ? N0
963  : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
964 
965  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
969  Number<kfold * N0 / npair>{},
970  Number<npair>{},
971  BK1Number));
972 
973  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
974  b_lds_block_desc,
975  make_tuple(
979  make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
982  make_tuple(
984  make_tuple(
986 
987  constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
988  b_lds_block_desc_permuted,
989  make_tuple(
997  Sequence<1>{},
998  Sequence<2>{},
999  Sequence<3>{},
1000  Sequence<4>{},
1001  Sequence<5>{}),
1003  Sequence<2>{},
1004  Sequence<0, 3>{},
1005  Sequence<4, 5>{},
1006  Sequence<6>{},
1007  Sequence<7>{}));
1008 
1009  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1010  b_lds_block_desc_unmerged,
1013  Number<KThreadWrite / kfold / KThreadReadPerm>{},
1014  Number<kfold>{},
1021 
1022  return b_lds_block_desc_bk0_n_bk1;
1023  }
1024  }
1025 
1026  __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
1027  {
1028  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1029  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1030 
1031  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1033  make_tuple(I1,
1035  I1,
1037 
1038  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1039  }
1040 
1043  BlkGemmPipelineVer,
1044  BlkGemmPipeSched,
1045  BlockSize,
1046  ScaleBlockSize,
1047  ADataType,
1048  AScaleDataType,
1049  BDataType,
1050  BScaleDataType,
1051  ComputeTypeA,
1052  AccDataType,
1059  ABlockTransferSrcScalarPerVector,
1060  BBlockTransferSrcScalarPerVector,
1061  MPerBlock,
1062  NPerBlock,
1063  KPerBlock,
1064  MPerXdl,
1065  NPerXdl,
1066  MXdlPerWave,
1067  NXdlPerWave,
1068  KPack,
1069  IsInputGemm>())>;
1070 
1071  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1072  {
1073  // LDS allocation for A and B: be careful of alignment
1074  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1075  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1076 
1077  // lds max alignment
1078  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1079 
1080  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1081  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1082 
1083  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1084  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1085 
1086  // LDS allocation for C shuffle in LDS
1087  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1088  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
1089 
1090  constexpr auto c_block_size =
1091  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1092 
1093  if constexpr(IsInputGemm)
1094  {
1095  return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1096  b_block_space_size_aligned * sizeof(BDataType)) *
1097  2,
1098  c_block_size * sizeof(CShuffleDataType));
1099  }
1100  else
1101  {
1102  return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1103  b_block_space_size_aligned * sizeof(BDataType)),
1104  c_block_size * sizeof(CShuffleDataType));
1105  }
1106  }
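// Worked example (hedged, tile sizes are assumptions): if the aligned A and B tiles each hold
// 16 KiB and the C-shuffle tile needs 32 KiB, then
//   IsInputGemm == true  (A/B staging doubled for the extra up-projection B tile):
//     max((16 + 16) KiB * 2, 32 KiB) = 64 KiB
//   IsInputGemm == false:
//     max((16 + 16) KiB,     32 KiB) = 32 KiB
// This is the byte count used to size the __shared__ buffer in kernel_moe_mxgemm.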
1107 
1109 
1110  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1111  __host__ static constexpr bool CheckValidity(const Argument& karg)
1112  {
1113  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1114  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1115  "Invalid tuning param!");
1116 
1117  static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1118  "KPerBlock should be multiple of ScaleBlockSize");
1119 
1120  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1125  {
1126  if(!(karg.M % MPerBlock == 0))
1127  {
1128  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1129  {
1130  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1131  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1132  << std::endl;
1133  }
1134  return false;
1135  }
1136  }
1137 
1138  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1143  {
1144  if(!(karg.N % NPerBlock == 0))
1145  {
1146  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1147  {
1148  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1149  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1150  << std::endl;
1151  }
1152  return false;
1153  }
1154  }
1155 
1156  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1160  {
1161  auto K_t = karg.KBatch * KPerBlock;
1162  if(!(karg.K % K_t == 0))
1163  {
1164  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1165  {
1166  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1167  << karg.K << " " << __FILE__ << ":" << __LINE__
1168  << ", in function: " << __func__ << std::endl;
1169  }
1170  return false;
1171  }
1172  }
1173  else
1174  {
1175  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1176  auto K_t = karg.KBatch * KReadVec;
1177  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1178  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1179  {
1180  return false;
1181  }
1182  }
1183 
1184  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
1185  {
1186  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1187  {
1188  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1189  {
1190  std::cout << "Arg K (" << karg.K
1191  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1192  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1193  << __LINE__ << ", in function: " << __func__ << std::endl;
1194  }
1195  return false;
1196  }
1197  }
1198  else
1199  {
1200  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1201  {
1202  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1203  {
1204  std::cout << "Arg M (" << karg.M
1205  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1206  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1207  << __LINE__ << ", in function: " << __func__ << std::endl;
1208  }
1209  return false;
1210  }
1211  }
1212 
1213  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
1214  {
1215  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1216  {
1217  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1218  {
1219  std::cout << "Arg N (" << karg.N
1220  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1221  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1222  << __LINE__ << ", in function: " << __func__ << std::endl;
1223  }
1224  return false;
1225  }
1226  }
1227  else
1228  {
1229  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1230  {
1231  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1232  {
1233  std::cout << "Arg K (" << karg.K
1234  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1235  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1236  << __LINE__ << ", in function: " << __func__ << std::endl;
1237  }
1238  return false;
1239  }
1240  }
1241 
1242  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
1243  {
1244  if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1245  {
1246  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1247  {
1248  std::cout << "Arg N (" << karg.N
1249  << ") value is not a multiple of "
1250  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1252  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1253  << std::endl;
1254  }
1255  return false;
1256  }
1257  }
1258  else
1259  {
1260  if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1261  {
1262  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1263  {
1264  std::cout << "Arg M (" << karg.M
1265  << ") value is not a multiple of "
1266  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1268  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1269  << std::endl;
1270  }
1271  return false;
1272  }
1274  }
1275 
1276  // check gridwise gemm pipeline
1277 #if 0
1278  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1279 
1280  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1281  {
1282  return false;
1283  }
1284 #endif
1285  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1286  return true;
1287  }
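// Usage sketch (hedged): a device-level wrapper would typically call this on the host before
// launching, e.g.
//   if(!GridwiseGemm::CheckValidity(karg)) { return false; } // reject unsupported arguments
// so that M/N/K, vector widths and the K-batch split are validated up front.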
1288 
1289  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1290  {
1291  const index_t num_loop = K / KPerBlock;
1292 
1293  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1294  }
1295 
1296  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1297  {
1298  const index_t num_loop = K / KPerBlock;
1299 
1300  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1301  }
1302 
1303  template <typename CGridDesc>
1304  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1305  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1306  {
1307  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1308  c_grid_desc_m_n,
1313 
1314  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1315  }
1316 
1317  // return block_id to C matrix tile idx (m0, n0) mapping
1318  // if arch = gfx942
1319  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1320  // NPerBlock>;
1321 
1323  static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
1324  static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
1325  static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
1326  "A scale pack data type too large!");
1327  static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
1328  "B scale pack data type too large!");
1329 
1330  template <bool HasMainKBlockLoop,
1331  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1332  TailNumber TailNum = TailNumber::Odd>
1333  __device__ static void Run(const index_t* p_sorted_token_ids,
1334  const index_t* p_sorted_expert_ids,
1335  const index_t* p_max_token_id,
1336  const ADataType* p_a_grid,
1337  const AScaleDataType* p_a_scale_grid,
1338  const BDataType* p_b_grid,
1339  const BScaleDataType* p_b_scale_grid,
1340  DsGridPointer& p_ds_grid,
1341  CDataType* p_c_grid,
1342  void* p_shared,
1343  const Problem& problem,
1344  AElementwiseOperation a_element_op,
1345  BElementwiseOperation b_element_op,
1346  CElementwiseOperation c_element_op)
1347  {
1348  ignore = b_element_op;
1349  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1350  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1351  problem.MPadded,
1352  problem.K,
1353  problem.KPadded,
1354  problem.StrideA,
1355  problem.AK0);
1356  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1357  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1358  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1359  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1360  problem.MPadded,
1361  problem.N,
1362  problem.NPadded,
1363  problem.StrideC);
1364 
1365  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
1366  make_tuple(problem.M / (MXdlPack * MPerXdl),
1367  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1368  (KXdlPack * 64 / MPerXdl),
1369  64 * KXdlPack * MXdlPack / scale_pack_size_a));
1370 
1371  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
1372  make_tuple(problem.N / (NXdlPack * NPerXdl),
1373  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1374  (KXdlPack * 64 / NPerXdl),
1375  64 * KXdlPack * NXdlPack / scale_pack_size_b));
1376 
1377  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1378  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1379  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1380 
1381  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1382  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1383  if(expert_block_id * MPerBlock >= max_token_id)
1384  return;
1385  const index_t expert_id =
1386  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1387 
1388  const auto block_mn = [&]() -> std::pair<int, int> {
1389  if constexpr(NSwizzle)
1390  {
1391  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1392  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1393  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1394  const index_t expert_swizzle =
1395  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1396  const index_t bid_new = blockIdx.x - prefix_block;
1397  const index_t nid = __builtin_amdgcn_readfirstlane(
1398  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1399  const index_t mid =
1400  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1401  return {nid, mid};
1402  }
1403  else
1404  {
1405  return {blockIdx.x, blockIdx.y};
1406  }
1407  }();
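// Worked example (illustrative) of the NSwizzle mapping above: suppose this expert owns
// ecnt = expert_swizzle = 3 M-tiles, ecnt_prefix = 5 and problem.NBlock = 16, so prefix_block = 80.
//   blockIdx.x =  80..87  -> bid_new =  0..7  -> nid = 0..7,  mid = 5
//   blockIdx.x =  88..95  -> bid_new =  8..15 -> nid = 0..7,  mid = 6
//   blockIdx.x = 104..111 -> bid_new = 24..31 -> nid = 8..15, mid = 5
// i.e. the same group of 8 N-tiles is revisited for each of the expert's M-tiles before moving
// on, which improves cache reuse of the expert's B tiles.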
1408 
1409  const index_t block_n_id = block_mn.first;
1410  const index_t block_m_id = block_mn.second;
1411  const index_t token0 =
1412  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1413 
1414  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1415  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1416  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1417  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1418  constexpr auto AKThreads = AK0Threads * AK1Threads;
1419  constexpr auto AMRepeats = MPerBlock / AMThreads;
1420  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1421 
1422  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1423  return;
1425  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1426  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1427  index_t token_offset = fused_token & 0xffffff;
1428  if constexpr(!IsInputGemm)
1429  {
1430  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1431  }
1432  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1433  });
1434 
1435  const index_t expert_stride =
1436  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1437  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1438  problem.N * (IsInputGemm ? 2 : 1) *
1439  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
1440 
1441  // N0, K0, Blocksize*KPack
1442  const index_t n_block_data_idx_on_grid =
1443  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1444 
1445  // Grid buffer creation
1446  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1447  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1448  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1449  p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1450 
1451  // A, B scale buffer
1452  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1453  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1454  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1455  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
1456  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1457 
1458  // lds max alignment
1459  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1460 
1461  // A matrix in LDS memory, dst of blockwise copy
1462  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1463 
1464  // B matrix in LDS memory, dst of blockwise copy
1465  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1466 
1467  // A matrix blockwise copy
1468  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1470  AElementwiseOperation,
1474  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1475  ABlockTransferThreadClusterArrangeOrder,
1476  ADataType,
1477  ADataType,
1478  decltype(a_grid_desc_ak0_m_ak1),
1479  decltype(a_block_desc_ak0_m_ak1),
1480  ABlockTransferSrcAccessOrder,
1482  ABlockTransferSrcVectorDim,
1483  2,
1484  ABlockTransferSrcScalarPerVector,
1485  ABlockTransferDstScalarPerVector_AK1,
1486  1,
1487  1,
1488  AThreadTransferSrcResetCoordinateAfterRun,
1489  true,
1490  IndexType,
1491  1,
1492  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1493  make_multi_index(0, 0, 0),
1494  a_element_op,
1495  a_block_desc_ak0_m_ak1,
1496  make_multi_index(0, 0, 0),
1498  gather_offsets);
1499 
1500  // B matrix blockwise copy
1501  auto b_blockwise_copy =
1503  BElementwiseOperation,
1507  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1508  BBlockTransferThreadClusterArrangeOrder,
1509  BDataType,
1510  BDataType,
1511  decltype(b_grid_desc_bk0_n_bk1),
1512  decltype(b_block_desc_bk0_n_bk1),
1513  BBlockTransferSrcAccessOrder,
1515  BBlockTransferSrcVectorDim,
1516  2,
1517  BBlockTransferSrcScalarPerVector,
1518  BBlockTransferDstScalarPerVector_BK1,
1519  1,
1520  1,
1521  BThreadTransferSrcResetCoordinateAfterRun,
1522  true,
1523  BlockwiseGemmPipe::GlobalBufferNum>(
1524  b_grid_desc_bk0_n_bk1,
1525  make_multi_index(0, n_block_data_idx_on_grid, 0),
1526  b_element_op,
1527  b_block_desc_bk0_n_bk1,
1528  make_multi_index(0, 0, 0),
1530 
1531  // LDS allocation for A and B: be careful of alignment
1532  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1533  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1534 
1535  // Cast after lds
1536  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1537  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1538 
1539  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1540  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1541  a_block_space_size_aligned * sizeof(ADataType)),
1542  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1543 
1544  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1545  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1546 
1547  // Blockwise GEMM pipeline
1548  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1549  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1550  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1551  decltype(c_thread_buf) c_thread_buf_up;
1552 
1554  float,
1555  c_thread_buf.num_of_v_,
1556  c_thread_buf.s_per_v,
1557  true>
1558  c_thread_buf_fp32;
1559 
1560  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1561  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1562  KPerBlock);
1563 
1564  // a and b scale processing
1565  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1566  const auto waveId_m = wave_idx[I0];
1567  const auto waveId_n = wave_idx[I1];
1568 
1569  auto thread_offset_shuffled =
1570  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1571 
1572  auto a_thread_offset_m = waveId_m;
1573 
1574  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1575  AScaleDataType,
1576  AScaleDataType,
1577  decltype(a_scale_grid_desc_am_ak),
1578  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1579  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1580  Sequence<0, 1, 2>, // DimAccessOrder
1581  2, // SrcVectorDim
1582  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1583  1, // SrcScalarStrideInVector
1584  true>(a_scale_grid_desc_am_ak,
1585  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1586  0,
1587  thread_offset_shuffled / scale_pack_size_a));
1588 
1589  // B scale load
1590  auto b_thread_offset_n = waveId_n;
1591 
1592  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1593  BScaleDataType,
1594  BScaleDataType,
1595  decltype(b_scale_grid_desc_bn_ak),
1596  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1597  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1598  Sequence<0, 1, 2>, // DimAccessOrder
1599  2, // SrcVectorDim
1600  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1601  1, // SrcScalarStrideInVector
1602  true>(b_scale_grid_desc_bn_ak,
1603  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1604  0,
1605  thread_offset_shuffled / scale_pack_size_b));
1606 
1607  if constexpr(IsInputGemm)
1608  {
1609  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1610  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1611  auto b_block_buf_up = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1612  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1613  a_block_space_size_aligned * sizeof(ADataType) +
1614  b_block_space_size_aligned * sizeof(BDataType)),
1615  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1616 
1617  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
1618  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1619  p_b_grid_up + expert_id * expert_stride,
1620  b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1621 
1622  auto b_blockwise_copy_up =
1624  BElementwiseOperation,
1628  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1629  BBlockTransferThreadClusterArrangeOrder,
1630  BDataType,
1631  BDataType,
1632  decltype(b_grid_desc_bk0_n_bk1),
1633  decltype(b_block_desc_bk0_n_bk1),
1634  BBlockTransferSrcAccessOrder,
1636  BBlockTransferSrcVectorDim,
1637  2,
1638  BBlockTransferSrcScalarPerVector,
1639  BBlockTransferDstScalarPerVector_BK1,
1640  1,
1641  1,
1642  BThreadTransferSrcResetCoordinateAfterRun,
1643  true,
1644  BlockwiseGemmPipe::GlobalBufferNum>(
1645  b_grid_desc_bk0_n_bk1,
1646  make_multi_index(0, n_block_data_idx_on_grid, 0),
1647  b_element_op,
1648  b_block_desc_bk0_n_bk1,
1649  make_multi_index(0, 0, 0),
1651 
1652  const BScaleDataType* p_b_scale_grid_up =
1653  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
1654  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1655  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
1656  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1657 
1658  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1659  BScaleDataType,
1660  BScaleDataType,
1661  decltype(b_scale_grid_desc_bn_ak),
1662  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1663  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1664  Sequence<0, 1, 2>, // DimAccessOrder
1665  2, // SrcVectorDim
1666  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1667  1, // SrcScalarStrideInVector
1668  true>(
1669  b_scale_grid_desc_bn_ak,
1670  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1671  0,
1672  thread_offset_shuffled / scale_pack_size_b));
1673 
1674  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1675  // A
1676  a_grid_desc_ak0_m_ak1,
1677  a_block_desc_ak0_m_ak1,
1678  a_blockwise_copy,
1679  a_grid_buf,
1680  a_block_buf,
1681  a_block_slice_copy_step,
1682  // Gate and Up
1683  b_grid_desc_bk0_n_bk1,
1684  b_block_desc_bk0_n_bk1,
1685  b_blockwise_copy,
1686  b_blockwise_copy_up,
1687  b_grid_buf,
1688  b_grid_buf_up,
1689  b_block_buf,
1690  b_block_buf_up,
1691  b_block_slice_copy_step,
1692  // C
1693  c_thread_buf,
1694  c_thread_buf_up,
1695  // A scale
1696  a_scale_grid_desc_am_ak,
1697  a_scale_thread_copy,
1698  a_scale_grid_buf,
1699  // Gate and Up scale
1700  b_scale_grid_desc_bn_ak,
1701  b_scale_thread_copy,
1702  b_scale_thread_copy_up,
1703  b_scale_grid_buf,
1704  b_scale_grid_buf_up,
1705  num_k_block_main_loop);
1706  }
1707  else
1708  {
1709  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1710  a_grid_desc_ak0_m_ak1, // A
1711  a_block_desc_ak0_m_ak1,
1712  a_blockwise_copy,
1713  a_grid_buf,
1714  a_block_buf,
1715  a_block_slice_copy_step,
1716  b_grid_desc_bk0_n_bk1, // B
1717  b_block_desc_bk0_n_bk1,
1718  b_blockwise_copy,
1719  b_grid_buf,
1720  b_block_buf,
1721  b_block_slice_copy_step,
1722  c_thread_buf, // C
1723  a_scale_grid_desc_am_ak, // A scale
1724  a_scale_thread_copy,
1725  a_scale_grid_buf,
1726  b_scale_grid_desc_bn_ak, // B scale
1727  b_scale_thread_copy,
1728  b_scale_grid_buf,
1729  num_k_block_main_loop);
1730  }
1731 
1732  // shuffle C and write out
1733  {
1734  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1735  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1736  "wrong!");
1737  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
1738  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
1739  "wrong!");
1740 
1741  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1742  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1743 
1744  // TODO: hacky, fix it!
1745  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1746  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1747 
1748  // TODO: hacky, fix it!
1749  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1750  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1751  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1752 
1753  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1754  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1755  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1756  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1757  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1758  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1759  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1760  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1761  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
1762  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
1763 
1764  // mul scales
1765  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
1766  static_assert(M5 == 4);
1767  const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id
1768  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
1769 
1770  vector_type<float, 4> topk_weights; // for gemm2 only
1771  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
1772  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
1773  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
1774  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
1775  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
1776  const index_t m_pos = block_m_id * MPerBlock +
1777  m0 * M2 * M1 * M3 * M4 * M5 +
1778  m1 * M2 * M3 * M4 * M5 +
1779  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
1780 
1781  if constexpr(MulRoutedWeight)
1782  {
1783  topk_weights =
1784  *c_style_pointer_cast<const vector_type<float, M5>*>(
1785  p_ds_grid[I2] + m_pos);
1786  }
1787  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
1788  constexpr index_t c_offset =
1789  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1790  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
1791  constexpr auto cidx = Number<c_offset>{};
1792 
1793  if constexpr(IsInputGemm) // gu fusion
1794  {
1795  if constexpr(ActivationOperation ==
1796  Activation::silu_and_mul)
1797  {
1798  float gate = c_thread_buf[cidx];
1799  float up = c_thread_buf_up[cidx];
1800  if constexpr(MulRoutedWeight)
1801  {
1802  gate = gate * topk_weights.AsType<float>()[m5];
1803  up = up * topk_weights.AsType<float>()[m5];
1804  }
1805  tensor_operation::element_wise::Silu{}(gate, gate);
1806  c_thread_buf_fp32(cidx) = gate * up;
1807  }
1808  else if(ActivationOperation == Activation::gelu_and_mul)
1809  {
1810  float gate = c_thread_buf[cidx];
1811  float up = c_thread_buf_up[cidx];
1812  if constexpr(MulRoutedWeight)
1813  {
1814  gate = gate * topk_weights.AsType<float>()[m5];
1815  up = up * topk_weights.AsType<float>()[m5];
1816  }
1817  tensor_operation::element_wise::Gelu{}(gate, gate);
1818  c_thread_buf_fp32(cidx) = gate * up;
1819 
1820  /*float gate = c_thread_buf[cidx];
1821  float up = c_thread_buf_up[cidx];
1822  if constexpr(MulRoutedWeight)
1823  {
1824  gate = gate * topk_weights.AsType<float>()[m5];
1825  //up = up * topk_weights.AsType<float>()[m5];
1826  }
1827  tensor_operation::element_wise::Gelu{}(gate, gate);
1828  c_thread_buf_fp32(cidx) = up;*/
1829  }
1830  }
1831  else
1832  {
1833  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1834  if constexpr(MulRoutedWeight)
1835  {
1836  c_thread_buf_fp32(cidx) =
1837  topk_weights.AsType<float>()[m5] *
1838  c_thread_buf_fp32[cidx];
1839  }
1840  }
1841  });
1842  });
1843  });
1844  });
1845  });
1846  });
1847 
1848  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1849  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
1850 
1851  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1852  static_cast<CShuffleDataType*>(p_shared),
1853  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1854 
1855  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1856  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1857  make_tuple(
1861  // per shuffle
1862  M1, // M1 = MWave
1863  M2, // M2 = MXdlPack
1864  M3, // M3 * M4 * M5 = MPerXdl
1865  M4,
1866  M5)),
1870  // per shuffle
1871  N1, // N1 = NWave
1872  N2, // N2 = NXdlPack
1873  N3))), // N3 = NPerXdl
1877  Sequence<>{},
1879 
1880  // calculate origin of thread output tensor on global memory
1881  // blockwise GEMM c matrix starting index
1882  const auto c_thread_mtx_on_block =
1883  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1884 
1885  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1886  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1887 
1888  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1890  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
1892  make_tuple(Sequence<0>{}));
1893 
1894  const auto m_thread_data_on_block_idx =
1895  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1896  make_multi_index(m_thread_data_on_block));
1897 
 1898  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 1899  make_single_stage_tensor_adaptor(
 1900  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
 1901  make_tuple(Sequence<0, 1, 2, 3>{}),
 1902  make_tuple(Sequence<0>{}));
1903 
1904  const auto n_thread_data_on_block_idx =
1905  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1906  make_multi_index(n_thread_data_on_block));
1907 
1908  // shuffle: threadwise copy C from VGPR to LDS
1909  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1910  AccDataType,
1911  CShuffleDataType,
1912  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 1913  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 1914  tensor_operation::element_wise::PassThrough,
 1915  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1916  CShuffleNXdlPerWavePerShuffle / NXdlPack,
1917  I1,
1918  I1,
1919  M2,
1920  N2,
1921  M3,
1922  I1,
1923  M5,
1924  I1>,
 1925  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
 1926  9,
 1927  1,
 1928  InMemoryDataOperationEnum::Set,
 1929  1,
1930  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1931  make_multi_index(0,
1932  0,
1933  m_thread_data_on_block_idx[I1],
1934  n_thread_data_on_block_idx[I1],
1935  m_thread_data_on_block_idx[I2],
1936  n_thread_data_on_block_idx[I2],
1937  m_thread_data_on_block_idx[I3],
1938  m_thread_data_on_block_idx[I4],
1939  m_thread_data_on_block_idx[I5],
1940  n_thread_data_on_block_idx[I3]),
 1941  tensor_operation::element_wise::PassThrough{}};
 1942 
1943  using EDataType = CDataType;
1944 
1945  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1946  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1947 
 1948  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
 1949  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
 1950  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1951 
1952  const auto ds_grid_buf = generate_tuple(
1953  [&](auto i) {
1954  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1955  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1956  },
1957  Number<NumDTensor>{});
1958 
1959  // tuple of reference to C/Ds tensor descriptors
1960  const auto c_ds_desc_refs = concat_tuple_of_reference(
1961  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1962  generate_tie([&](auto i) -> const auto& // return type should be reference
1963  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1964  Number<NumDTensor>{}));
1965 
1966  // tuple of reference to C/Ds tensor descriptors
1967  const auto c_ds_buf_refs = concat_tuple_of_reference(
1968  tie(c_shuffle_block_buf),
1969  generate_tie([&](auto i) -> const auto& // return type should be reference
1970  { return ds_grid_buf[i]; },
1971  Number<NumDTensor>{}));
1972 
1973  // tuple of starting index of C/Ds blockwise copy
 1974  const auto idx_c_ds_block_begin =
 1975  container_concat(make_tuple(make_multi_index(0, 0, 0, 0)),
 1976  generate_tuple(
 1977  [&](auto) {
1978  return make_multi_index(block_m_id, 0, block_n_id, 0);
1979  // return make_multi_index(block_work_idx[I0], 0,
1980  // block_work_idx[I1], 0);
1981  },
1982  Number<NumDTensor>{}));
1983 
1984  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1985  c_grid_desc_mblock_mperblock_nblock_nperblock;
1986 
1987  using CDEBlockTransferCluster =
1988  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1989  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1990  constexpr index_t scatter_weight_idx = 3; // hack fix felix
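 // The LDS->global transfer below applies CElementwiseOperation to the
 // shuffled C tile together with the D tensors and scatters each row of the
 // result into the E grid: ScatterDim = 1 (the MPerBlock dimension) with the
 // row offsets supplied per access via scatter_offsets.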
 1991  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
 1992  ThisThreadBlock,
 1993  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
 1994  Tuple<EDataType>,
 1995  decltype(c_ds_desc_refs),
1996  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1997  CElementwiseOperation,
1998  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
1999  // Sequence support
 2000  // arbitrary type
2001  Sequence<1,
2002  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2003  1,
2004  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2005  CDEBlockTransferCluster,
2006  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2007  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2008  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2009  3, // index_t SrcVectorDim,
2010  3, // index_t DstVectorDim,
2011  CDEShuffleBlockTransferScalarPerVectors,
 2012  CShuffleBlockTransferScalarPerVector_NPerBlock,
 2013  sequence_merge_t<
 2014  Sequence<true>,
 2015  uniform_sequence_gen_t<NumDTensor,
 2016  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2017  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2018  IndexType,
2019  1, // ScatterDim
2020  true, // OutputScatter: false, only use scatter weights
2021  scatter_weight_idx // ScatterWeightIdx: ascale
2022  >{c_ds_desc_refs,
2023  idx_c_ds_block_begin,
2024  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2025  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2026  c_element_op};
2027 
2028  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2029  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2030 
2031  constexpr auto sfc_c_vgpr =
2032  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2033  NXdlPerWave / NXdlPack,
2034  1,
2035  1,
2036  MXdlPack,
2037  NXdlPack,
2038  M2,
2039  1,
2040  M4,
2041  1>,
 2042  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
 2043  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2044  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2045  1,
2046  1,
2047  MXdlPack,
2048  NXdlPack,
2049  M2,
2050  1,
2051  M4,
2052  1>>{};
2053 
2054  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2055 
2056  // space filling curve for shuffled blockwise C/D/E
 2057  constexpr auto sfc_cde_block =
 2058  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
 2059  Sequence<0, 2, 1, 3>,
 2060  Sequence<1,
2061  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2062  1,
2063  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2064 
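 // The two space-filling curves must agree on the number of shuffle steps:
 // sfc_c_vgpr walks the accumulator tile held in VGPRs while sfc_cde_block
 // walks the corresponding MPerBlock x NPerBlock footprint staged through LDS.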
2065  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2066  constexpr auto EMThreads =
2067  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2068  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2069  constexpr auto ENThreads =
2070  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2071  static_for<0, num_access, 1>{}([&](auto access_id) {
2072  // make sure it's safe to write to LDS
 2073  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
 2074 
2075  auto dstidx = sfc_cde_block.GetIndex(access_id);
2076  const index_t c_token_pos =
2077  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2078  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2079  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2080  IndexType token_offset = fused_token & 0xffffff;
2081  if constexpr(IsInputGemm)
2082  {
2083  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2084  }
2085  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2086  });
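 // p_sorted_token_ids packs the token index in the low 24 bits and the top-k
 // slot in the high 8 bits. As a hypothetical example, an entry 0x02000015
 // decodes to token 21, slot 2; for the input GEMM the destination row is
 // 21 * TopK + 2, and the scatter offset is that row times problem.N.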
2087 
2088  block_sync_lds();
2089 
2090  // each thread write its data from VGPR to LDS
2091  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2092  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2093  c_thread_buf_fp32,
2094  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2095  c_shuffle_block_buf);
2096 
2097  // make sure it's safe to read from LDS
2098  block_sync_lds();
2099 
2100  // each block copy its data from LDS to global
2101  cde_block_copy_lds_and_global.Run(
2102  c_ds_desc_refs,
2103  c_ds_buf_refs,
2104  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2105  tie(c_grid_buf),
2106  scatter_offsets);
2107 
2108  if constexpr(access_id < num_access - 1)
2109  {
2110  constexpr auto cde_lds_and_global_step =
2111  sfc_cde_block.GetForwardStep(access_id);
2112 
2113  // move on Ds
2114  static_for<0, NumDTensor, 1>{}([&](auto i) {
2115  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2116  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2117  });
2118 
2119  // move on E
2120  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2121  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2122  I0,
2123  cde_lds_and_global_step);
2124  }
2125  });
2126  }
2127  }
2128 
2129 #if 0
2130  template <bool HasMainKBlockLoop,
2131  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2132  TailNumber TailNum = TailNumber::Odd>
2133  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
2134  const index_t* p_sorted_expert_ids,
2135  const index_t* p_max_token_id,
2136  const ADataType* p_a_grid,
2137  const AScaleDataType* p_a_scale_grid,
2138  const BDataType* p_b_grid,
2139  const BScaleDataType* p_b_scale_grid,
2140  DsGridPointer& p_ds_grid,
2141  CDataType* p_c_grid,
2142  void* p_shared,
2143  void* p_shared1,
2144  const Problem& problem,
2145  AElementwiseOperation a_element_op,
2146  BElementwiseOperation b_element_op,
2147  CElementwiseOperation c_element_op)
2148  {
2149  ignore = b_element_op;
2150  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2151  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2152  problem.MPadded,
2153  problem.K,
2154  problem.KPadded,
2155  problem.StrideA,
2156  problem.AK0);
2157  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2158  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2159  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2160  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2161  problem.MPadded,
2162  problem.N,
2163  problem.NPadded,
2164  problem.StrideC);
2165 
2166  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
2167  make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerXdl),
2168  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2169  (KXdlPack * 64 / MPerXdl),
2170  64 * KXdlPack * MXdlPack / scale_pack_size_a));
2171 
2172  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
2173  make_tuple(problem.N / (NXdlPack * NPerXdl),
2174  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2175  (KXdlPack * 64 / NPerXdl),
2176  64 * KXdlPack * NXdlPack / scale_pack_size_b));
2177 
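 // MX scale layout: one shared-exponent scale (e8m0-style) covers
 // ScaleBlockSize elements along K. The packed 3-D descriptors above group
 // the scales by XDL tile (MXdlPack/NXdlPack x KXdlPack) and by the 64 lanes
 // of a wavefront so each thread can fetch its scales with one vector load.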
 2178  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 2179  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
 2180  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2181  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2182  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
2183  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2184  if(expert_block_id * MPerBlock >= max_token_id)
2185  return;
2186  const index_t expert_id =
2187  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2188  const auto block_mn = [&]() -> std::pair<int, int> {
2189  if constexpr(NSwizzle)
2190  {
2191  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2192  const index_t prefix_block = ecnt_prefix * problem.NBlock;
2193  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2194  const index_t expert_swizzle =
2195  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2196  const index_t bid_new = blockIdx.x - prefix_block;
2197  const index_t nid = __builtin_amdgcn_readfirstlane(
2198  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2199  const index_t mid =
2200  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2201  return {nid, mid};
2202  }
2203  else
2204  {
2205  return {blockIdx.x, blockIdx.y};
2206  }
2207  }();
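 // With NSwizzle enabled, the flat blockIdx.x is remapped to an
 // (N-block, M-block) pair within the current expert's range of sorted
 // tiles, interleaving groups of 8 N-blocks across the expert's M-blocks
 // (presumably for better cache locality); otherwise blockIdx.x/y map
 // directly to the N/M block ids.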
2208 
2209  const index_t block_n_id = block_mn.first;
2210  const index_t block_m_id = block_mn.second;
2211  const index_t token0 =
2212  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2213 
2214  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2215  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2216  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2217  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2218  constexpr auto AKThreads = AK0Threads * AK1Threads;
2219  constexpr auto AMRepeats = MPerBlock / AMThreads;
2220  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
2221 
2222  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2223  return;
2224  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
2225  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2226  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
2227  index_t token_offset = fused_token & 0xffffff;
2228  if constexpr(!IsInputGemm)
2229  {
2230  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2231  }
2232  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2233  });
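 // gather_offsets translates each A row handled by this block back to its
 // source token: the low 24 bits of the sorted id give the token and, for the
 // down GEMM (!IsInputGemm), the high 8 bits give the top-k slot, so the row
 // becomes token * TopK + slot; the offset is then scaled by K elements.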
2234 
2235  const index_t expert_stride =
2236  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2237  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2238  problem.N * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
2239 
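 // Each expert owns a contiguous slice of the B weights and B scales; the
 // factor of 2 for the input GEMM accounts for the gate and up weight
 // matrices of an expert being stored back to back.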
2240  // N0, K0, Blocksize*KPack
2241  const index_t n_block_data_idx_on_grid =
2242  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
2243 
2244  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2245  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2246 
2247  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2248  p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2249 
2250  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2251  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2252  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2253  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2254  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2255 
2256  // A matrix in LDS memory, dst of blockwise copy
2257  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2258 
2259  // B matrix in LDS memory, dst of blockwise copy
2260  // dummy
2261  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2262  // A matrix blockwise copy
 2263  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
 2264  ThisThreadBlock,
 2265  AElementwiseOperation,
 2266  tensor_operation::element_wise::PassThrough,
 2267  InMemoryDataOperationEnum::Set,
 2268  Sequence<AK0Number, MPerBlock, AK1Number>,
2269  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2270  ABlockTransferThreadClusterArrangeOrder,
2271  ADataType,
2272  LDSTypeA,
2273  decltype(a_grid_desc_ak0_m_ak1),
2274  decltype(a_block_desc_ak0_m_ak1),
2275  ABlockTransferSrcAccessOrder,
2276  Sequence<0, 1, 2>,
2277  ABlockTransferSrcVectorDim,
2278  2,
2279  ABlockTransferSrcScalarPerVector,
2280  ABlockTransferDstScalarPerVector_AK1,
2281  1,
2282  1,
2283  AThreadTransferSrcResetCoordinateAfterRun,
2284  true,
2285  IndexType,
2286  1,
2287  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
2288  make_multi_index(0, 0, 0),
2289  a_element_op,
2290  a_block_desc_ak0_m_ak1,
2291  make_multi_index(0, 0, 0),
 2292  tensor_operation::element_wise::PassThrough{},
 2293  gather_offsets);
2294 
2295  // Thread-wise copy
2296  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2297  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2298  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2299  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2300  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2301  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
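 // B is pre-shuffled in memory and streamed straight into ping/pong VGPR
 // buffers by a threadwise copy; only the A tile is staged through LDS in
 // this pipeline.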
2302 
2303  auto b_blockwise_copy =
2304  ThreadwiseTensorSliceTransfer_v2<BDataType,
2305  BDataType,
2306  decltype(b_grid_desc_bpreshuffled),
2307  decltype(b_block_desc_bk0_n_bk1),
2308  Sequence<Number<NXdlPerWave / NXdlPack>{},
2309  I1,
2310  Number<NXdlPack>{},
2311  Number<KRepeat>{},
2312  Number<BK1Value>{}>,
2313  Sequence<1, 2, 0, 3, 4>,
2314  4,
2315  BBlockTransferSrcScalarPerVector,
2316  BThreadTransferSrcResetCoordinateAfterRun,
2317  true>(
2318  b_grid_desc_bpreshuffled,
2319  make_multi_index(n_block_data_idx_on_grid,
2320  get_warp_local_1d_id() % NWave,
2321  0,
2322  0,
2323  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2324 
2325  // LDS allocation for A and B: be careful of alignment
2326  // Cast after lds
2327  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2328  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2329  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2330  static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2331  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
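 // The 2-LDS variant double-buffers the A tile: p_shared and p_shared1
 // alternate between the global->LDS copy and the MFMA consumer so the load
 // of the next K slice can overlap the math on the current one.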
2332 
2333  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2334  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
2335 
2336  // Blockwise GEMM pipeline
2337  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2338  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2339  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2340  decltype(c_thread_buf) c_thread_buf_up;
2341 
2342  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
2343  float,
2344  c_thread_buf.num_of_v_,
2345  c_thread_buf.s_per_v,
2346  true>
2347  c_thread_buf_fp32;
2348 
2349  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2350  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2351  KPerBlock);
2352 
2353  // a and b scale processing
2354  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2355  const auto waveId_m = wave_idx[I0];
2356  const auto waveId_n = wave_idx[I1];
2357 
2358  auto thread_offset_shuffled =
2359  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2360 
2361  auto a_thread_offset_m = waveId_m;
2362 
 2363  // get each thread's offset into the scale tensor
2364  const index_t token_scale_pos = block_m_id * MPerBlock;
2365  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2366  return;
2367 
2368  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2369  AScaleDataType,
2370  AScaleDataType,
2371  decltype(a_scale_grid_desc_am_ak),
2372  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2373  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2374  Sequence<0, 1, 2>, // DimAccessOrder
2375  2, // SrcVectorDim
2376  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2377  1, // SrcScalarStrideInVector
2378  true>(a_scale_grid_desc_am_ak,
2379  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2380  0,
2381  thread_offset_shuffled / scale_pack_size_a));
2382 
2383  // B scale load
2384  auto b_thread_offset_n = waveId_n;
2385 
2386  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2387  BScaleDataType,
2388  BScaleDataType,
2389  decltype(b_scale_grid_desc_bn_ak),
2390  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2391  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2392  Sequence<0, 1, 2>, // DimAccessOrder
2393  2, // SrcVectorDim
2394  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2395  1, // SrcScalarStrideInVector
2396  true>(b_scale_grid_desc_bn_ak,
2397  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2398  0,
2399  thread_offset_shuffled / scale_pack_size_b));
2400 
2401  if constexpr(IsInputGemm)
2402  {
2403  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2404  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2405  p_b_grid_up + expert_id * expert_stride / BPackedSize,
2406  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2407  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2408  BDataType,
2409  BDataType,
2410  decltype(b_grid_desc_bpreshuffled),
2411  decltype(b_block_desc_bk0_n_bk1),
2412  Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
2413  Sequence<1, 2, 0, 3>,
2414  3,
2415  BBlockTransferSrcScalarPerVector,
2416  BThreadTransferSrcResetCoordinateAfterRun,
2417  true>(b_grid_desc_bpreshuffled,
2418  make_multi_index(n_block_data_idx_on_grid,
2419  get_warp_local_1d_id() % NWave,
2420  0,
2421  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2422  const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
2423  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2424  p_b_scale_grid_up + expert_id * expert_scale_stride,
2425  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2426  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2427  BScaleDataType,
2428  BScaleDataType,
2429  decltype(b_scale_grid_desc_bn_ak),
2430  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2431  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2432  Sequence<0, 1, 2>, // DimAccessOrder
2433  2, // SrcVectorDim
 2434  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2435  1, // SrcScalarStrideInVector
2436  true>(
2437  b_scale_grid_desc_bn_ak,
2438  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2439  0,
2440  thread_offset_shuffled / scale_pack_size_b));
2441 
2442  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2443  a_grid_desc_ak0_m_ak1,
2444  a_block_desc_ak0_m_ak1,
2445  a_blockwise_copy,
2446  a_grid_buf,
2447  a_block_bufs,
2448  a_block_slice_copy_step,
2449  b_grid_desc_bpreshuffled,
2450  b_block_desc_bk0_n_bk1,
2451  b_blockwise_copy,
2452  b_blockwise_copy_up,
2453  b_grid_buf,
2454  b_grid_buf_up,
2455  b_block_bufs,
2456  b_block_slice_copy_step,
2457  c_thread_buf,
2458  c_thread_buf_up,
2459  a_scale_grid_desc_am_ak,
2460  a_scale_thread_copy,
2461  a_scale_grid_buf,
2462  b_scale_grid_desc_bn_ak,
2463  b_scale_thread_copy,
2464  b_scale_thread_copy_up,
2465  b_scale_grid_buf,
2466  b_scale_grid_buf_up,
2467  num_k_block_main_loop);
2468  }
2469  else
2470  {
2471  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2472  a_grid_desc_ak0_m_ak1,
2473  a_block_desc_ak0_m_ak1,
2474  a_blockwise_copy,
2475  a_grid_buf,
2476  a_block_bufs,
2477  a_block_slice_copy_step,
2478  b_grid_desc_bpreshuffled,
2479  b_block_desc_bk0_n_bk1,
2480  b_blockwise_copy,
2481  b_grid_buf,
2482  b_block_bufs,
2483  b_block_slice_copy_step,
2484  c_thread_buf,
2485  a_scale_grid_desc_am_ak,
2486  a_scale_thread_copy,
2487  a_scale_grid_buf,
2488  b_scale_grid_desc_bn_ak,
2489  b_scale_thread_copy,
2490  b_scale_grid_buf,
2491  num_k_block_main_loop);
2492  }
2493 
2494  // shuffle C and write out
2495  {
2496  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2497  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2498  "wrong!");
2499 
2500  // TODO: hacky, fix it!
2501  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2502  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2503 
2504  // TODO: hacky, fix it!
2505  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2506  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2507  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2508 
2509  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2510  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2511  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2512  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2513  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2514  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2515  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2516  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2517 
2518  // mul scales
2519 
2520  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
2521  static_assert(M4 == 4);
2522  const index_t m1 = get_warp_local_1d_id() / NWave;
2523  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
2524 
2525  vector_type<float, 4> topk_weights; // for gemm2 only
2526  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2527  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2528  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
2529  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2530  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2531  if constexpr(MulRoutedWeight)
2532  {
2533  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
2534  p_ds_grid[I2] + m_pos);
2535  }
2536  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
2537  constexpr index_t c_offset =
2538  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2539  make_tuple(m0 / MXdlPack,
2540  n0 / NXdlPack,
2541  m0 % MXdlPack,
2542  n0 % NXdlPack,
2543  m2 * M4 + m4));
2544  constexpr auto cidx = Number<c_offset>{};
2545 
2546  if constexpr(IsInputGemm) // gu fusion
2547  {
2548  if constexpr(ActivationOperation == Activation::silu_and_mul)
2549  {
2550  float gate = c_thread_buf[cidx];
2551  float up = c_thread_buf_up[cidx];
2552  if constexpr(MulRoutedWeight)
2553  {
2554  gate = gate * topk_weights.AsType<float>()[m4];
2555  up = up * topk_weights.AsType<float>()[m4];
2556  }
2557  tensor_operation::element_wise::Silu{}(gate, gate);
2558  c_thread_buf_fp32(cidx) = gate * up;
2559  }
2560  else if(ActivationOperation == Activation::gelu_and_mul)
2561  {
2562  float gate = c_thread_buf[cidx];
2563  float up = c_thread_buf_up[cidx];
2564  if constexpr(MulRoutedWeight)
2565  {
2566  gate = gate * topk_weights.AsType<float>()[m4];
2567  up = up * topk_weights.AsType<float>()[m4];
2568  }
2569  tensor_operation::element_wise::Gelu{}(gate, gate);
2570  c_thread_buf_fp32(cidx) = gate * up;
2571  }
2572  }
2573  else
2574  {
2575  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2576  if constexpr(MulRoutedWeight)
2577  {
2578  c_thread_buf_fp32(cidx) =
2579  topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
2580  }
2581  }
2582  });
2583  });
2584  });
2585  });
2586 
2587  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 2588  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
 2589 
2590  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2591  static_cast<CShuffleDataType*>(p_shared),
2592  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2593 
 2594  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
 2595  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 2596  make_tuple(make_freeze_transform(make_multi_index(0)),
 2597  make_unmerge_transform(make_tuple(
 2598  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per
 2599  // shuffle
 2600  M1, // M1 = MWave
 2601  M2, // M2 * M3 * M4 = MPerXdl
 2602  M3,
 2603  M4)),
 2604  make_freeze_transform(make_multi_index(0)),
 2605  make_unmerge_transform(make_tuple(
 2606  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per
 2607  // shuffle
 2608  N1, // N1 = NWave
 2609  N2))), // N2 = NPerXdl
2610  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2611  make_tuple(
2612  Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
2613 
2614  // calculate origin of thread output tensor on global memory
2615  // blockwise GEMM c matrix starting index
2616  const auto c_thread_mtx_on_block =
2617  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2618 
2619  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2620  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2621 
 2622  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
 2623  make_single_stage_tensor_adaptor(
 2624  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2625  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
2626  make_tuple(Sequence<0>{}));
2627 
2628  const auto m_thread_data_on_block_idx =
2629  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2630  make_multi_index(m_thread_data_on_block));
2631 
 2632  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 2633  make_single_stage_tensor_adaptor(
 2634  make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
 2635  make_tuple(Sequence<0, 1, 2>{}),
2636  make_tuple(Sequence<0>{}));
2637 
2638  const auto n_thread_data_on_block_idx =
2639  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2640  make_multi_index(n_thread_data_on_block));
2641 
2642  // shuffle: threadwise copy C from VGPR to LDS
2643  auto c_thread_copy_vgpr_to_lds =
2644  ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
2645  CShuffleDataType,
2646  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 2647  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 2648  tensor_operation::element_wise::PassThrough,
 2649  Sequence<CShuffleMXdlPerWavePerShuffle,
2650  CShuffleNXdlPerWavePerShuffle,
2651  I1,
2652  I1,
2653  M2,
2654  I1,
2655  M4,
2656  I1>,
2657  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2658  7,
 2659  1,
 2660  InMemoryDataOperationEnum::Set,
 2661  1,
2662  true>{
2663  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2664  make_multi_index(0,
2665  0,
2666  m_thread_data_on_block_idx[I1],
2667  n_thread_data_on_block_idx[I1],
2668  m_thread_data_on_block_idx[I2],
2669  m_thread_data_on_block_idx[I3],
2670  m_thread_data_on_block_idx[I4],
2671  n_thread_data_on_block_idx[I2]),
 2672  tensor_operation::element_wise::PassThrough{}};
 2673 
2674  using EDataType = CDataType;
2675 
2676  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2677  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2678 
 2679  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
 2680  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
 2681  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2682 
2683  const auto ds_grid_buf = generate_tuple(
2684  [&](auto i) {
2685  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2686  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2687  },
2688  Number<NumDTensor>{});
2689 
2690  // tuple of reference to C/Ds tensor descriptors
2691  const auto c_ds_desc_refs = concat_tuple_of_reference(
2692  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2693  generate_tie([&](auto i) -> const auto& // return type should be reference
2694  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2695  Number<NumDTensor>{}));
2696 
2697  // tuple of reference to C/Ds tensor descriptors
2698  const auto c_ds_buf_refs = concat_tuple_of_reference(
2699  tie(c_shuffle_block_buf),
2700  generate_tie([&](auto i) -> const auto& // return type should be reference
2701  { return ds_grid_buf[i]; },
2702  Number<NumDTensor>{}));
2703 
2704  // tuple of starting index of C/Ds blockwise copy
 2705  const auto idx_c_ds_block_begin =
 2706  container_concat(make_tuple(make_multi_index(0, 0, 0, 0)),
 2707  generate_tuple(
 2708  [&](auto) {
2709  return make_multi_index(block_m_id, 0, block_n_id, 0);
2710  // return make_multi_index(block_work_idx[I0], 0,
2711  // block_work_idx[I1], 0);
2712  },
2713  Number<NumDTensor>{}));
2714 
2715  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2716  c_grid_desc_mblock_mperblock_nblock_nperblock;
2717 
2718  using CDEBlockTransferCluster =
2719  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2720  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2721  constexpr index_t scatter_weight_idx = 3; // hack fix felix
 2722  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
 2723  ThisThreadBlock,
 2724  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2725  Tuple<EDataType>,
2726  decltype(c_ds_desc_refs),
2727  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2728  CElementwiseOperation,
2729  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2730  // Sequence support
 2731  // arbitrary type
2732  Sequence<1,
2733  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2734  1,
2735  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2736  CDEBlockTransferCluster,
2737  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2738  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2739  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2740  3, // index_t SrcVectorDim,
2741  3, // index_t DstVectorDim,
2742  CDEShuffleBlockTransferScalarPerVectors,
 2743  CShuffleBlockTransferScalarPerVector_NPerBlock,
 2744  sequence_merge_t<
 2745  Sequence<true>,
 2746  uniform_sequence_gen_t<NumDTensor,
 2747  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2748  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2749  IndexType,
2750  1, // ScatterDim
2751  true, // OutputScatter: false, only use scatter weights
2752  scatter_weight_idx // ScatterWeightIdx: ascale
2753  >{c_ds_desc_refs,
2754  idx_c_ds_block_begin,
2755  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2756  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2757  c_element_op};
2758 
2759  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2760  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2761  constexpr auto sfc_c_vgpr =
2762  SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
2763  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2764  Sequence<CShuffleMXdlPerWavePerShuffle,
2765  CShuffleNXdlPerWavePerShuffle,
2766  1,
2767  1,
2768  M2,
2769  1,
2770  M4,
2771  1>>{};
2772 
2773  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2774 
2775  // space filling curve for shuffled blockwise C/D/E
2776  constexpr auto sfc_cde_block =
2777  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2778  Sequence<0, 2, 1, 3>,
2779  Sequence<1,
2780  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2781  1,
2782  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2783 
2784  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2785  constexpr auto EMThreads =
2786  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2787  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2788  constexpr auto ENThreads =
2789  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2790  static_for<0, num_access, 1>{}([&](auto access_id) {
2791  // make sure it's safe to write to LDS
2792  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2793 
2794  auto dstidx = sfc_cde_block.GetIndex(access_id);
2795  const index_t c_token_pos =
2796  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2797  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2798  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2799  IndexType token_offset = fused_token & 0xffffff;
2800  if constexpr(IsInputGemm)
2801  {
2802  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2803  }
2804  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2805  });
2806 
2807  block_sync_lds();
2808 
2809  // each thread write its data from VGPR to LDS
2810  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2811  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2812  c_thread_buf_fp32,
2813  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2814  c_shuffle_block_buf);
2815 
2816  // make sure it's safe to read from LDS
2817  block_sync_lds();
2818 
2819  // each block copy its data from LDS to global
2820  cde_block_copy_lds_and_global.Run(
2821  c_ds_desc_refs,
2822  c_ds_buf_refs,
2823  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2824  tie(c_grid_buf),
2825  scatter_offsets);
2826 
2827  if constexpr(access_id < num_access - 1)
2828  {
2829  constexpr auto cde_lds_and_global_step =
2830  sfc_cde_block.GetForwardStep(access_id);
2831 
2832  // move on Ds
2833  static_for<0, NumDTensor, 1>{}([&](auto i) {
2834  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2835  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2836  });
2837 
2838  // move on E
2839  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2840  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2841  I0,
2842  cde_lds_and_global_step);
2843  }
2844  });
2845  }
2846  }
2847 #endif
2848 };
2849 
2850 } // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition: device_base.hpp:178
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:268
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__device__ index_t get_warp_local_1d_id()
Definition: get_id.hpp:45
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:34
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm_bns.hpp:48
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:928
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:277
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm.hpp:90
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:185
Activation
Definition: gridwise_moe_gemm.hpp:31
@ silu_and_mul
Definition: gridwise_moe_gemm.hpp:33
@ gelu_and_mul
Definition: gridwise_moe_gemm.hpp:32
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
__host__ constexpr __device__ auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:320
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition: tuple_helper.hpp:42
constexpr bool is_same_v
Definition: type.hpp:283
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition: sequence.hpp:925
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:299
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
constexpr auto BlockGemmMXNBSPipeline_Selector()
Definition: blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp:37
integral_constant< index_t, N > Number
Definition: number.hpp:12
Definition: gridwise_moe_mx_gemm_bns.hpp:654
const ADataType * p_a_grid
Definition: gridwise_moe_mx_gemm_bns.hpp:717
const index_t * p_sorted_token_ids
Definition: gridwise_moe_mx_gemm_bns.hpp:714
const index_t * p_sorted_expert_ids
Definition: gridwise_moe_mx_gemm_bns.hpp:715
const index_t * p_max_token_id
Definition: gridwise_moe_mx_gemm_bns.hpp:716
DsGridPointer p_ds_grid
Definition: gridwise_moe_mx_gemm_bns.hpp:721
const CElementwiseOperation c_element_op
Definition: gridwise_moe_mx_gemm_bns.hpp:726
CDataType * p_c_grid
Definition: gridwise_moe_mx_gemm_bns.hpp:722
const AElementwiseOperation a_element_op
Definition: gridwise_moe_mx_gemm_bns.hpp:724
const BScaleDataType * p_b_scale_grid
Definition: gridwise_moe_mx_gemm_bns.hpp:720
const BDataType * p_b_grid
Definition: gridwise_moe_mx_gemm_bns.hpp:719
const AScaleDataType * p_a_scale_grid
Definition: gridwise_moe_mx_gemm_bns.hpp:718
__host__ Argument(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *p_a_grid_, const AScaleDataType *p_a_scale_grid_, const BDataType *p_b_grid_, const BScaleDataType *p_b_scale_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition: gridwise_moe_mx_gemm_bns.hpp:655
const BElementwiseOperation b_element_op
Definition: gridwise_moe_mx_gemm_bns.hpp:725
Definition: gridwise_moe_mx_gemm_bns.hpp:582
index_t M
Definition: gridwise_moe_mx_gemm_bns.hpp:632
index_t TopK
Definition: gridwise_moe_mx_gemm_bns.hpp:631
index_t NPadded
Definition: gridwise_moe_mx_gemm_bns.hpp:643
index_t MPadded
Definition: gridwise_moe_mx_gemm_bns.hpp:642
index_t StrideScaleB
Definition: gridwise_moe_mx_gemm_bns.hpp:638
index_t StrideScaleA
Definition: gridwise_moe_mx_gemm_bns.hpp:636
index_t MBlock
Definition: gridwise_moe_mx_gemm_bns.hpp:648
index_t StrideC
Definition: gridwise_moe_mx_gemm_bns.hpp:640
index_t AK0
Definition: gridwise_moe_mx_gemm_bns.hpp:646
index_t KPadded
Definition: gridwise_moe_mx_gemm_bns.hpp:645
index_t NBlock
Definition: gridwise_moe_mx_gemm_bns.hpp:649
__host__ Problem(index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition: gridwise_moe_mx_gemm_bns.hpp:583
index_t StrideA
Definition: gridwise_moe_mx_gemm_bns.hpp:635
index_t StrideB
Definition: gridwise_moe_mx_gemm_bns.hpp:637
index_t KBatch
Definition: gridwise_moe_mx_gemm_bns.hpp:641
index_t BK0
Definition: gridwise_moe_mx_gemm_bns.hpp:647
index_t KRead
Definition: gridwise_moe_mx_gemm_bns.hpp:644
__host__ void Print() const
Definition: gridwise_moe_mx_gemm_bns.hpp:618
index_t K
Definition: gridwise_moe_mx_gemm_bns.hpp:634
index_t N
Definition: gridwise_moe_mx_gemm_bns.hpp:633
index_t NumTokens
Definition: gridwise_moe_mx_gemm_bns.hpp:630
std::array< index_t, NumDTensor > StrideDs
Definition: gridwise_moe_mx_gemm_bns.hpp:639
Definition: gridwise_moe_mx_gemm_bns.hpp:730
index_t a_k_split_offset
Definition: gridwise_moe_mx_gemm_bns.hpp:784
index_t b_k_split_offset
Definition: gridwise_moe_mx_gemm_bns.hpp:785
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition: gridwise_moe_mx_gemm_bns.hpp:731
index_t b_scale_k_split_offset
Definition: gridwise_moe_mx_gemm_bns.hpp:787
index_t a_scale_k_split_offset
Definition: gridwise_moe_mx_gemm_bns.hpp:786
Definition: gridwise_moe_mx_gemm_bns.hpp:179
static constexpr auto is_scale_mfma
Definition: gridwise_moe_mx_gemm_bns.hpp:212
static constexpr auto I1
Definition: gridwise_moe_mx_gemm_bns.hpp:184
static constexpr index_t scale_pack_size_a
Definition: gridwise_moe_mx_gemm_bns.hpp:1323
static constexpr index_t NumDTensor
Definition: gridwise_moe_mx_gemm_bns.hpp:202
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_moe_mx_gemm_bns.hpp:250
static constexpr auto MXdlPack
Definition: gridwise_moe_mx_gemm_bns.hpp:204
static constexpr auto BK0Number
Definition: gridwise_moe_mx_gemm_bns.hpp:198
static constexpr auto BK1Number
Definition: gridwise_moe_mx_gemm_bns.hpp:200
ADataType LDSTypeA
Definition: gridwise_moe_mx_gemm_bns.hpp:180
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_moe_mx_gemm_bns.hpp:1289
static constexpr bool is_single_rate_mfma
Definition: gridwise_moe_mx_gemm_bns.hpp:211
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_moe_mx_gemm_bns.hpp:238
static constexpr auto I5
Definition: gridwise_moe_mx_gemm_bns.hpp:188
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_moe_mx_gemm_bns.hpp:403
BDataType LDSTypeB
Definition: gridwise_moe_mx_gemm_bns.hpp:181
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_moe_mx_gemm_bns.hpp:1296
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_moe_mx_gemm_bns.hpp:236
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
Definition: gridwise_moe_mx_gemm_bns.hpp:321
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_mx_gemm_bns.hpp:1111
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_moe_mx_gemm_bns.hpp:558
static __device__ void Run(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: gridwise_moe_mx_gemm_bns.hpp:1333
static constexpr auto I4
Definition: gridwise_moe_mx_gemm_bns.hpp:187
static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_moe_mx_gemm_bns.hpp:1071
static constexpr index_t APackedSize
Definition: gridwise_moe_mx_gemm_bns.hpp:208
static constexpr auto MakeDsGridPointer()
Definition: gridwise_moe_mx_gemm_bns.hpp:225
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_moe_mx_gemm_bns.hpp:570
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_moe_mx_gemm_bns.hpp:255
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm_bns.hpp:265
__host__ static constexpr __device__ auto MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1 &)
Definition: gridwise_moe_mx_gemm_bns.hpp:504
static constexpr auto I8
Definition: gridwise_moe_mx_gemm_bns.hpp:191
static constexpr auto I7
Definition: gridwise_moe_mx_gemm_bns.hpp:190
static constexpr auto KXdlPack
Definition: gridwise_moe_mx_gemm_bns.hpp:206
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_moe_mx_gemm_bns.hpp:1304
static constexpr auto I2
Definition: gridwise_moe_mx_gemm_bns.hpp:185
__host__ static constexpr __device__ auto MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1 &)
Definition: gridwise_moe_mx_gemm_bns.hpp:494
static constexpr index_t scale_pack_size_b
Definition: gridwise_moe_mx_gemm_bns.hpp:1324
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm_bns.hpp:283
__host__ static constexpr __device__ auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition: gridwise_moe_mx_gemm_bns.hpp:305
remove_cvref_t< decltype(BlockGemmMXNBSPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ScaleBlockSize, ADataType, AScaleDataType, BDataType, BScaleDataType, ComputeTypeA, AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, IsInputGemm >())> BlockwiseGemmPipe
Definition: gridwise_moe_mx_gemm_bns.hpp:1069
static constexpr auto I6
Definition: gridwise_moe_mx_gemm_bns.hpp:189
static constexpr index_t KPack
Definition: gridwise_moe_mx_gemm_bns.hpp:219
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm_bns.hpp:271
static constexpr auto I3
Definition: gridwise_moe_mx_gemm_bns.hpp:186
static constexpr auto AK1Number
Definition: gridwise_moe_mx_gemm_bns.hpp:199
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_moe_mx_gemm_bns.hpp:910
static constexpr index_t SortedTileSize
Definition: gridwise_moe_mx_gemm_bns.hpp:223
static constexpr auto I9
Definition: gridwise_moe_mx_gemm_bns.hpp:192
__host__ static __device__ auto MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_moe_mx_gemm_bns.hpp:537
static constexpr auto AK0Number
Definition: gridwise_moe_mx_gemm_bns.hpp:197
static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_moe_mx_gemm_bns.hpp:1026
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_moe_mx_gemm_bns.hpp:290
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock
Definition: gridwise_moe_mx_gemm_bns.hpp:194
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_moe_mx_gemm_bns.hpp:260
static constexpr auto NXdlPack
Definition: gridwise_moe_mx_gemm_bns.hpp:205
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_moe_mx_gemm_bns.hpp:295
static constexpr index_t BPackedSize
Definition: gridwise_moe_mx_gemm_bns.hpp:209
__host__ static __device__ auto MakeCGridDescriptor_M_N(IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
Definition: gridwise_moe_mx_gemm_bns.hpp:513
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm_bns.hpp:277
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_moe_mx_gemm_bns.hpp:790
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_moe_mx_gemm_bns.hpp:240
static constexpr auto I0
Definition: gridwise_moe_mx_gemm_bns.hpp:183
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition: xdlops_gemm.hpp:1208
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1757
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Definition: static_buffer.hpp:75
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1_gather.hpp:48
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v7r3_scatter.hpp:51
Definition: threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition: threadwise_tensor_slice_transfer.hpp:234
Definition: tuple.hpp:117
Unsigned representation of a conventional biased Float32 exponent.
Definition: e8m0.hpp:26
Definition: data_type.hpp:42
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: data_type.hpp:187
Definition: functional2.hpp:33
Definition: device_base.hpp:197
Definition: unary_element_wise_operation.hpp:1007
Definition: unary_element_wise_operation.hpp:334
Definition: unary_element_wise_operation.hpp:1049
Definition: dtype_vector.hpp:10
#define CK_ENV(name)
Definition: env.hpp:129