Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp Source File
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
16 
20 
21 #define DEBUG_LOG 0
22 
23 namespace ck {
24 
25 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
26 // pipelines in the same kernel function. Blockers:
27 // 1. Two separate declarations of the __shared__ pointer are required to make sure data accesses
28 // operate on two distinct LDS chunks.
29 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may not
30 // reuse the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
31 
32 enum Activation
33 {
34  gelu_and_mul = 0,
35  silu_and_mul = 1
36 };
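// Illustration only (not part of this header): a scalar reference of the fused epilogue that the
// Activation values select further below, assuming plain float gate/up accumulators.
//
//   inline float silu_and_mul_ref(float gate, float up)
//   {
//       const float silu = gate / (1.0f + expf(-gate)); // SiLU(x) = x * sigmoid(x)
//       return silu * up;                               // gelu_and_mul applies GELU instead
//   }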
37 
38 #if 0
39 template <typename GridwiseGemm,
40  bool HasMainKBlockLoop,
41  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
42  index_t MinimumOccupancy = 1,
43  TailNumber TailNum = TailNumber::Even>
44 __global__ void
45 #if CK_USE_LAUNCH_BOUNDS
46 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
47 #endif
48  // __attribute__((amdgpu_waves_per_eu(1, 1)))
49  kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
50 {
51 #if defined(__gfx9__)
52  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
53 
54  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
55 
56  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
57  karg.p_sorted_token_ids,
58  karg.p_sorted_expert_ids,
59  karg.p_max_token_id,
60  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
61  karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
62  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
63  karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
64  karg.p_ds_grid,
65  karg.p_c_grid,
66  p_shared,
67  karg,
68  karg.a_element_op,
69  karg.b_element_op,
70  karg.c_element_op);
71 #else
72  ignore = karg;
73 #endif // end of if (defined(__gfx9__))
74 }
75 #endif
76 
77 template <typename GridwiseGemm,
78  bool HasMainKBlockLoop,
79  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
80  index_t MinimumOccupancy = 1,
81  TailNumber TailNum = TailNumber::Even>
82 __global__ void
83 #if CK_USE_LAUNCH_BOUNDS
84 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
85 #endif
86  // __attribute__((amdgpu_waves_per_eu(1, 1)))
87  kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
88 {
89 #if defined(__gfx9__)
90  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
91  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
92 
93  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
94 
95  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
96  karg.p_sorted_token_ids,
97  karg.p_sorted_expert_ids,
98  karg.p_max_token_id,
99  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
100  karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
101  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
102  karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
103  karg.p_ds_grid,
104  karg.p_c_grid,
105  p_shared_0,
106  p_shared_1,
107  karg,
108  karg.a_element_op,
109  karg.b_element_op,
110  karg.c_element_op);
111 #else
112  ignore = karg;
113 #endif // end of if (defined(__gfx9__))
114 }
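// Host-side launch sketch (illustration only): `Gridwise` stands for a full specialization of the
// gridwise GEMM class defined below, and `stream` / `BlockSize` are assumed to come from the
// enclosing device operation; gridDim.z carries the split-K batch consumed via blockIdx.z.
//
//   const auto [gdx, gdy, gdz] = Gridwise::CalculateGridSize(karg.M, karg.N);
//   kernel_moe_mxgemm_2lds<Gridwise, true, InMemoryDataOperationEnum::Set>
//       <<<dim3(gdx, gdy, karg.KBatch), dim3(BlockSize), 0, stream>>>(karg);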
115 
116 template <typename ALayout,
117  typename BLayout,
118  typename DsLayout,
119  typename CLayout,
120  typename ADataType,
121  typename AScaleDataType,
122  typename BDataType,
123  typename BScaleDataType,
124  typename AccDataType,
125  typename CShuffleDataType,
126  typename DsDataType,
127  typename CDataType,
128  typename AElementwiseOperation,
129  typename BElementwiseOperation,
130  typename CElementwiseOperation,
132  index_t ScaleBlockSize, // Scaling block size
133  index_t BlockSize, // Thread block size
134  index_t MPerBlock,
135  index_t NPerBlock,
136  index_t KPerBlock,
137  index_t AK1Value,
138  index_t BK1Value,
139  index_t MPerXdl,
140  index_t NPerXdl,
141  index_t MXdlPerWave,
142  index_t NXdlPerWave,
143  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
144  typename ABlockTransferThreadClusterArrangeOrder,
145  typename ABlockTransferSrcAccessOrder,
146  index_t ABlockTransferSrcVectorDim,
147  index_t ABlockTransferSrcScalarPerVector,
148  index_t ABlockTransferDstScalarPerVector_AK1,
149  bool AThreadTransferSrcResetCoordinateAfterRun,
150  index_t ABlockLdsExtraM,
151  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
152  typename BBlockTransferThreadClusterArrangeOrder,
153  typename BBlockTransferSrcAccessOrder,
154  index_t BBlockTransferSrcVectorDim,
155  index_t BBlockTransferSrcScalarPerVector,
156  index_t BBlockTransferDstScalarPerVector_BK1,
157  bool BThreadTransferSrcResetCoordinateAfterRun,
158  index_t BBlockLdsExtraN,
159  index_t CShuffleMXdlPerWavePerShuffle,
160  index_t CShuffleNXdlPerWavePerShuffle,
161  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
162  typename CDEShuffleBlockTransferScalarPerVectors,
165  index_t ActivationOperation = 0,
166  bool NSwizzle = false,
167  bool IsInputGemm = true,
168  bool MulRoutedWeight = true,
169  typename IndexType = index_t,
170  typename ComputeTypeA = ADataType,
171  typename ComputeTypeB = BDataType>
173 {
174  using LDSTypeA = ADataType;
175  using LDSTypeB = BDataType;
176 
177  static constexpr auto I0 = Number<0>{};
178  static constexpr auto I1 = Number<1>{};
179  static constexpr auto I2 = Number<2>{};
180  static constexpr auto I3 = Number<3>{};
181  static constexpr auto I4 = Number<4>{};
182  static constexpr auto I5 = Number<5>{};
183  static constexpr auto I6 = Number<6>{};
184  static constexpr auto I7 = Number<7>{};
185  static constexpr auto I8 = Number<8>{};
186  static constexpr auto I9 = Number<9>{};
187 
189  CDEShuffleBlockTransferScalarPerVectors{}[I0];
190  // K1 should be Number<...>
191  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
192  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
193  static constexpr auto AK1Number = Number<AK1Value>{};
194  static constexpr auto BK1Number = Number<BK1Value>{};
195 
196  static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
197  static constexpr bool is_single_rate_mfma = false;
198  static constexpr auto is_scale_mfma = true;
199 
200  static constexpr index_t NumDTensor = DsDataType::Size();
201 
202  static constexpr auto MXdlPack = 2;
203  static constexpr auto NXdlPack = 2;
204  static constexpr auto KXdlPack = 2;
205 
206  //> KPack is at least the k_per_blk of the selected MFMA instruction
207  //
208  // and should be a multiple of k_per_blk.
209  // TODO: Move this to the blockwise pipeline base.
210  // KPack is expressed in packed data types for packed (pk) A/B.
211 
212  static constexpr index_t APackedSize = packed_size_v<ADataType>;
213  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
214 
215  using mfma_selector = MfmaSelector<ComputeTypeA,
216  MPerXdl,
217  NPerXdl,
218  ComputeTypeB,
220  is_scale_mfma>;
221  static constexpr index_t KPack =
223 
224  // static constexpr index_t NumTokens = 1;
225  static constexpr index_t SortedTileSize = MPerBlock;
226 
227  static constexpr auto MakeDsGridPointer()
228  {
229  return generate_tuple(
230  [&](auto i) {
231  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
232 
233  return static_cast<const DDataType*>(nullptr);
234  },
236  }
237 
238  using DsGridPointer = decltype(MakeDsGridPointer());
239 
241 
242  __host__ static auto CalculateGridSize(index_t M, index_t N)
243  {
244  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
245  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
246  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
247  const index_t gridy = NSwizzle ? 1 : mblock;
248 
249  return std::make_tuple(gridx, gridy, 1);
250  }
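// Worked example (illustration only, assuming MPerBlock = NPerBlock = 128): M = 512 and N = 1024
// give mblock = 4 and nblock = 8. Without NSwizzle the grid is (gridx, gridy) = (8, 4); with
// NSwizzle both tile dimensions fold into gridx, giving (32, 1).
static_assert((512 + 127) / 128 == 4 && (1024 + 127) / 128 == 8, "tile-count example");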
251 
252  __host__ static auto CalculateMPadded(index_t M)
253  {
254  return math::integer_least_multiple(M, MPerBlock);
255  }
256 
257  __host__ static auto CalculateNPadded(index_t N)
258  {
259  return math::integer_least_multiple(N, NPerBlock);
260  }
261 
262  __host__ static auto CalculateKPadded(index_t K)
263  {
264  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
265  }
266 
267  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
268  {
269  auto K_t = K_Batch * KPerBlock;
270  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
271  }
272 
273  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
274  {
275  auto K_t = K_Batch * KPerBlock;
276  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
277  }
278 
279  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
280  {
281  auto K_t = K_Batch * KPerBlock;
282  return (K + K_t - 1) / K_t * KPerBlock;
283  }
284 
285  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
286  {
287  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
288  auto K_t = K_Batch * KReadVec;
289  return (K + K_t - 1) / K_t * KReadVec;
290  }
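// Worked example (illustration only, assuming KPerBlock = 256, AK1 = BK1 = 16, K_Batch = 4,
// K = 4096): K_t = K_Batch * KPerBlock = 1024, so CalculateAK0Padded returns
// ceil(4096 / 1024) * (256 / 16) = 64, and with KReadVec = lcm(16, 16) = 16 each split-K batch
// reads CalculateKRead = ceil(4096 / (4 * 16)) * 16 = 1024 elements along K.
static_assert((4096 + 4 * 256 - 1) / (4 * 256) * (256 / 16) == 64, "CalculateAK0Padded example");
static_assert((4096 + 4 * 16 - 1) / (4 * 16) * 16 == 1024, "CalculateKRead example");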
291 
292  __host__ static auto CalculateMBlock(index_t M)
293  {
294  return math::integer_divide_ceil(M, MPerBlock);
295  }
296 
297  __host__ static auto CalculateNBlock(index_t N)
298  {
299  return math::integer_divide_ceil(N, NPerBlock);
300  }
301 
302  template <index_t MNXdlPerWave,
303  index_t MNWaves,
304  index_t MNXdlPack,
305  index_t MNPerXdl,
306  typename TileDesc_K0_MN_K1>
307  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
308  {
309  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
310  constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
311  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
312 
313  constexpr auto permuted_desc = transform_tensor_descriptor(
314  TileDesc_K0_MN_K1{},
319 
321  permuted_desc,
324  Number<MNWaves>{},
326  Number<MNPerXdl>{}))),
329  }
330 
331  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
332  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
333  {
334  const auto a_grid_desc_mraw_kraw = [&]() {
335  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
336  {
337  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
338  }
339  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
340  {
341  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
342  }
343  }();
344 
346 
347  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
348  GemmSpec == GemmSpecialization::MNKPadding)
349  {
350  // pad both M and K
351  const auto a_grid_desc_m_k =
352  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
354  make_right_pad_transform(K, KPad - K)),
357 
358  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
359  a_grid_desc_m_k,
364 
365  return a_grid_desc_ak0_m_ak1;
366  }
367  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
368  GemmSpec == GemmSpecialization::MNPadding)
369  {
370  // pad M, but not K
371  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
372  a_grid_desc_mraw_kraw,
373  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
374  make_right_pad_transform(M, MPad - M)),
377 
378  const auto a_grid_desc_permuted = transform_tensor_descriptor(
379  a_grid_desc_ak0_m_ak1,
382  make_pass_through_transform(AK1Value)),
385 
386  const auto a_grid_desc = transform_tensor_descriptor(
387  a_grid_desc_permuted,
388  make_tuple(
391  make_pass_through_transform(AK1Value)),
394  return a_grid_desc;
395  }
396  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
397  GemmSpec == GemmSpecialization::NKPadding)
398  {
399  // pad K, but not M
400  const auto a_grid_desc_m_k = transform_tensor_descriptor(
401  a_grid_desc_mraw_kraw,
405 
406  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
407  a_grid_desc_m_k,
412 
413  return a_grid_desc_ak0_m_ak1;
414  }
415  else
416  {
417  // not pad M or K
418  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
419  a_grid_desc_mraw_kraw,
420  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
424 
425  const auto a_grid_desc_permuted = transform_tensor_descriptor(
426  a_grid_desc_ak0_m_ak1,
429  make_pass_through_transform(AK1Value)),
432 
433  const auto a_grid_desc = transform_tensor_descriptor(
434  a_grid_desc_permuted,
435  make_tuple(
438  make_pass_through_transform(AK1Value)),
441 
442  return a_grid_desc;
443  }
444  }
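// Reading aid (illustration only): the transforms above re-express the raw (M, K) view of A as a
// 3-D (AK0, M, AK1) descriptor by unmerging K into (K / KPerBlock, AK0Number, AK1Value); e.g.
// with KPerBlock = 256 and AK1Value = 16, AK0Number = 256 / 16 = 16 and each per-block K tile
// holds AK0Number * AK1Value = KPerBlock elements.
static_assert(256 / 16 * 16 == 256, "KPerBlock = AK0Number * AK1Value example");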
445 
446  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
447  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
448  {
449  const auto b_grid_desc_nraw_kraw = [&]() {
451  {
452  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
453  }
455  {
456  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
457  }
458  }();
459 
461 
462  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
463  GemmSpec != GemmSpecialization::Default),
464  "pk_i4_t does not support padding");
465  static_assert(!(is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t> &&
466  (GemmSpec != GemmSpecialization::Default &&
467  GemmSpec != GemmSpecialization::MPadding)),
468  "f4x2_pk_t does not support K padding");
469 
470  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
471  GemmSpec == GemmSpecialization::MNKPadding)
472  {
473  // pad both N and K
474  const auto b_grid_desc_n_k =
475  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
477  make_right_pad_transform(K, KPad - K)),
480 
481  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
482  b_grid_desc_n_k,
487 
488  return b_grid_desc_bk0_n_bk1;
489  }
490  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
491  GemmSpec == GemmSpecialization::MNPadding)
492  {
493  // pad N, but not K
494  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
495  b_grid_desc_nraw_kraw,
497  make_right_pad_transform(N, NPad - N)),
500 
501  return b_grid_desc_bk0_n_bk1;
502  }
503  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
504  GemmSpec == GemmSpecialization::MKPadding)
505  {
506  // pad K, but not N
507  const auto b_grid_desc_n_k = transform_tensor_descriptor(
508  b_grid_desc_nraw_kraw,
512 
513  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
514  b_grid_desc_n_k,
519 
520  return b_grid_desc_bk0_n_bk1;
521  }
522  else
523  {
524  // not pad N or K
525  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
526  b_grid_desc_nraw_kraw,
527  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
531 
532  const auto b_grid_desc_permuted = transform_tensor_descriptor(
533  b_grid_desc_bk0_n_bk1,
536  make_pass_through_transform(BK1Value)),
539 
540  const auto b_grid_desc = transform_tensor_descriptor(
541  b_grid_desc_permuted,
542  make_tuple(
545  make_pass_through_transform(BK1Value)),
548 
549  return b_grid_desc;
550  }
551  }
552 
553  template <typename ABlockDesc_AK0_M_AK1>
554  __host__ __device__ static constexpr auto
555  MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
556  {
557  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
558 
559  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl>(
560  ABlockDesc_AK0_M_AK1{});
561  }
562 
563  template <typename BBlockDesc_BK0_N_BK1>
564  __host__ __device__ static constexpr auto
565  MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
566  {
567  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
568 
569  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl>(
570  BBlockDesc_BK0_N_BK1{});
571  }
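// Worked example (illustration only, assuming MPerBlock = NPerBlock = 128, MPerXdl = NPerXdl = 32
// and MXdlPerWave = NXdlPerWave = 2): MWaves = 128 / (2 * 32) = 2 and NWaves = 2, i.e. a 2x2 grid
// of waves covers one 128x128 output tile of the thread block.
static_assert(128 / (2 * 32) == 2, "MWaves/NWaves example");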
572 
573  template <typename ELayout>
574  __host__ __device__ static auto MakeCGridDescriptor_M_N(
575  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
576  {
577  const auto c_grid_desc_mraw_nraw = [&]() {
579  {
580  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
581  }
583  {
584  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
585  }
586  }();
587 
588  // pad M and N
589  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
591  make_right_pad_transform(N, NPad - N)),
594  }
595 
596  template <typename DLayout>
597  __host__ __device__ static auto
599  {
600  const auto c_grid_desc_mraw_nraw = [&]() {
602  {
603  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
604  }
606  {
607  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
608  }
609  }();
610 
611  // pad M and N
612  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
614  make_right_pad_transform(N, NPad - N)),
617  }
618 
619  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
620  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
621  {
622  return generate_tuple(
623  [&](auto i) {
624  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
625  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
626  },
628  }
629 
630  template <typename DsGridDesc>
632  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
633  {
634  return generate_tuple(
635  [&](auto i) {
637  ds_grid_desc_m_n[i], MBlock, NBlock);
638  },
640  }
641 
642  struct Problem
643  {
644  __host__ Problem(index_t NumTokens_,
645  index_t TopK_,
646  index_t M_,
647  index_t N_,
648  index_t K_,
649  index_t StrideA_,
650  index_t StrideScaleA_,
651  index_t StrideB_,
652  index_t StrideScaleB_,
653  std::array<index_t, NumDTensor> StrideDs_,
654  index_t StrideC_,
655  index_t KBatch_)
656  : NumTokens{NumTokens_},
657  TopK{TopK_},
658  M{M_},
659  N{N_},
660  K{K_},
661  StrideA{StrideA_},
662  StrideScaleA{StrideScaleA_},
663  StrideB{StrideB_},
664  StrideScaleB{StrideScaleB_},
665  StrideDs{StrideDs_},
666  StrideC{StrideC_},
667  KBatch{KBatch_},
670  KRead{CalculateKRead(K_, KBatch_)},
671  KPadded{CalculateKPadded(K_, KBatch_)},
672  AK0{CalculateAK0Padded(K_, KBatch_)},
673  BK0{CalculateBK0Padded(K_, KBatch_)},
674  MBlock{CalculateMBlock(M_)},
676  {
677  }
678 
679  __host__ void Print() const
680  {
681  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
682  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
683  << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
684  << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
685  << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
686  << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
687  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
688  << ", " << "NBlock: " << NBlock << "}" << std::endl;
689  }
690 
700  std::array<index_t, NumDTensor> StrideDs;
711  };
712 
713  // Argument
715  {
716  __host__ Argument(const index_t* p_sorted_token_ids_,
717  const index_t* p_sorted_expert_ids_,
718  const index_t* p_max_token_id_,
719  const ADataType* p_a_grid_,
720  const AScaleDataType* p_a_scale_grid_,
721  const BDataType* p_b_grid_,
722  const BScaleDataType* p_b_scale_grid_,
723  std::array<const void*, NumDTensor> p_ds_grid_,
724  CDataType* p_c_grid_,
725  index_t NumTokens_,
726  index_t TopK_,
727  index_t M_,
728  index_t N_,
729  index_t K_,
730  index_t StrideA_,
731  index_t StrideScaleA_,
732  index_t StrideB_,
733  index_t StrideScaleB_,
734  std::array<index_t, NumDTensor> StrideDs_,
735  index_t StrideC_,
736  index_t k_batch_,
737  AElementwiseOperation a_element_op_,
738  BElementwiseOperation b_element_op_,
739  CElementwiseOperation c_element_op_)
740  : Problem{NumTokens_,
741  TopK_,
742  M_,
743  N_,
744  K_ / APackedSize,
745  StrideA_ / APackedSize,
746  StrideScaleA_,
747  StrideB_ / BPackedSize,
748  StrideScaleB_,
749  StrideDs_,
750  StrideC_,
751  k_batch_},
752  p_sorted_token_ids{p_sorted_token_ids_},
753  p_sorted_expert_ids{p_sorted_expert_ids_},
754  p_max_token_id{p_max_token_id_},
755  p_a_grid{p_a_grid_},
756  p_a_scale_grid{p_a_scale_grid_},
757  p_b_grid{p_b_grid_},
758  p_b_scale_grid{p_b_scale_grid_},
759  p_ds_grid{},
760  p_c_grid{p_c_grid_},
761  a_element_op{a_element_op_},
762  b_element_op{b_element_op_},
763  c_element_op{c_element_op_}
764  {
765 
766  // populate pointer, desc for Ds
767  static_for<0, NumDTensor, 1>{}([&](auto i) {
768  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
769 
770  // D pointer
771  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
772  });
773  }
774 
778  const ADataType* p_a_grid;
779  const AScaleDataType* p_a_scale_grid;
780  const BDataType* p_b_grid;
781  const BScaleDataType* p_b_scale_grid;
783  CDataType* p_c_grid;
784 
785  const AElementwiseOperation a_element_op;
786  const BElementwiseOperation b_element_op;
787  const CElementwiseOperation c_element_op;
788  };
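// Note (illustration only): the constructor above forwards K and the A/B strides to Problem in
// storage units, dividing by APackedSize / BPackedSize. E.g. for a data type that packs two
// 4-bit values per storage element (packed size 2, an assumption for this example), a logical
// K of 4096 becomes 4096 / 2 = 2048 packed elements.
static_assert(4096 / 2 == 2048, "packed-K example");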
789 
791  {
792  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
793  {
794  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
795  {
796  a_k_split_offset = k_id * karg.KRead;
797  }
798  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
799  {
800  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
801  }
802 
803  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
804  {
805  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
806  }
807  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
808  {
809  // KPack * NLane * KLane * K0 * N0
810  b_k_split_offset = k_id * karg.KRead;
811  }
812 
813  // Calculate A scale offset
814  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
815  {
816  a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
817  }
818  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
819  {
821  k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
822  }
823 
824  // Calculate B scale offset
825  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
826  {
828  k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
829  }
830  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
831  {
832  b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
833  }
834 
835  if(k_id < karg.KBatch - 1)
836  {
837  karg.K = karg.KRead;
838  }
839  else
840  {
841  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
842  }
843  }
844 
849  };
850 
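// Worked split-K example (illustration only, assuming row-major A, column-major B, KRead = 1024,
// ScaleBlockSize = 32, APackedSize = BPackedSize = 1): the batch with k_id = 2 starts reading A
// and B at element offset 2 * 1024 along K, and its A/B scales at 2 * 1024 / 32 = 64 entries into
// the corresponding scale tensors.
static_assert(2 * 1024 / 32 == 64, "scale split-K offset example");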
851  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
852  {
853  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
854  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
855  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
856 
857  // A matrix in LDS memory, dst of blockwise copy
858  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
859  {
860  // contiguous in LDS
864  }
865  // The XOR tensor transformation requires additional, unnecessary VGPR usage and can cause
866  // register spills in some cases.
868  {
869  constexpr auto a_lds_block_desc =
872 
873  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
874  a_lds_block_desc,
880 
881  return a_lds_block_desc_permuted;
882  }
883  else // ColumnMajor A
884  {
885  // The kfold and mpair dimensions are not always required.
886  // More dimensions in merge_transform increase the difficulty of generating immediate-argument
887  // offsets for the compiler.
888  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
889  constexpr auto M1 = MPerBlock / M0;
890 
891  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
892  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
893  constexpr auto KThreadRead = WaveSize / MPerXdl;
894  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
895 
896  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
897  ? 1
898  : 128 / (AK1Number * M0 * sizeof(ADataType));
899  constexpr auto KThreadReadPerm =
900  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
901  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
902  : KThreadRead;
903 
904  // 1 <= mpair <= M0
905  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
906  ? 1
907  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
908  ? M0
909  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
910 
911  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
915  Number<kfold * M0 / mpair>{},
916  Number<mpair>{},
917  AK1Number));
918 
919  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
920  a_lds_block_desc,
921  make_tuple(
925  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
928  make_tuple(
930  make_tuple(
932 
933  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
934  a_lds_block_desc_permuted,
935  make_tuple(
943  Sequence<1>{},
944  Sequence<2>{},
945  Sequence<3>{},
946  Sequence<4>{},
947  Sequence<5>{}),
949  Sequence<2>{},
950  Sequence<0, 3>{},
951  Sequence<4, 5>{},
952  Sequence<6>{},
953  Sequence<7>{}));
954 
955  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
956  a_lds_block_desc_unmerged,
959  Number<KThreadWrite / kfold / KThreadReadPerm>{},
960  Number<kfold>{},
967 
968  return a_lds_block_desc_ak0_m_ak1;
969  }
970  }
971 
972  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
973  {
974  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
975  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
976  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
977 
978  // B matrix in LDS memory, dst of blockwise copy
979  if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
980  {
981  // contiguous in lds
985  }
987  {
988  // NLdsLayer * K0 as logical Bank
989  constexpr auto b_lds_block_desc =
992 
993  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
994  b_lds_block_desc,
1000 
1001  return b_lds_block_desc_permuted;
1002  }
1003  else // RowMajor B
1004  {
1005  constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
1006  constexpr auto N1 = NPerBlock / N0;
1007 
1008  constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
1009  constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
1010  constexpr auto KThreadRead = WaveSize / NPerXdl;
1011  constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
1012 
1013  constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
1014  ? 1
1015  : 128 / (BK1Number * N0 * sizeof(BDataType));
1016  constexpr auto KThreadReadPerm =
1017  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
1018  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
1019  : KThreadRead;
1020 
1021  // 1 <= npair <= N0
1022  constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
1023  ? 1
1024  : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
1025  ? N0
1026  : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
1027 
1028  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
1032  Number<kfold * N0 / npair>{},
1033  Number<npair>{},
1034  BK1Number));
1035 
1036  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
1037  b_lds_block_desc,
1038  make_tuple(
1042  make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
1045  make_tuple(
1047  make_tuple(
1049 
1050  constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
1051  b_lds_block_desc_permuted,
1052  make_tuple(
1060  Sequence<1>{},
1061  Sequence<2>{},
1062  Sequence<3>{},
1063  Sequence<4>{},
1064  Sequence<5>{}),
1066  Sequence<2>{},
1067  Sequence<0, 3>{},
1068  Sequence<4, 5>{},
1069  Sequence<6>{},
1070  Sequence<7>{}));
1071 
1072  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1073  b_lds_block_desc_unmerged,
1076  Number<KThreadWrite / kfold / KThreadReadPerm>{},
1077  Number<kfold>{},
1084 
1085  return b_lds_block_desc_bk0_n_bk1;
1086  }
1087  }
1088 
1090  {
1091  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1092  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1093 
1094  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1096  make_tuple(I1,
1098  I1,
1100 
1101  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1102  }
1103 
1106  BlkGemmPipelineVer,
1107  BlkGemmPipeSched,
1108  BlockSize,
1109  ScaleBlockSize,
1110  ADataType,
1111  AScaleDataType,
1112  BDataType,
1113  BScaleDataType,
1114  ComputeTypeA,
1115  AccDataType,
1122  ABlockTransferSrcScalarPerVector,
1123  BBlockTransferSrcScalarPerVector,
1124  MPerBlock,
1125  NPerBlock,
1126  KPerBlock,
1127  MPerXdl,
1128  NPerXdl,
1129  MXdlPerWave,
1130  NXdlPerWave,
1131  KPack,
1132  IsInputGemm>())>;
1133 
1134  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1135  {
1136  // LDS allocation for A and B: be careful of alignment
1137  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1138  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1139 
1140  // lds max alignment
1141  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1142 
1143  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1144  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1145 
1146  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1147  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1148 
1149  // LDS allocation for C shuffle in LDS
1150  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1152 
1153  constexpr auto c_block_size =
1154  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1155 
1156  if constexpr(IsInputGemm)
1157  {
1158  return math::max(a_block_space_size_aligned * sizeof(ADataType) +
1159  b_block_space_size_aligned * sizeof(BDataType) * 2,
1160  c_block_size * sizeof(CShuffleDataType));
1161  }
1162  else
1163  {
1164  return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1165  b_block_space_size_aligned * sizeof(BDataType)),
1166  c_block_size * sizeof(CShuffleDataType));
1167  }
1168  }
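// Worked example (illustration only, assuming 128x128x256 tiles of 1-byte A/B data with no extra
// LDS padding): A needs about 128 * 256 = 32 KiB and B about 128 * 256 = 32 KiB, so the gate+up
// (IsInputGemm) case reserves A + 2 * B = 96 KiB before taking the max against the C-shuffle
// staging buffer, while the single-B case reserves A + B = 64 KiB.
static_assert(128 * 256 + 2 * (128 * 256) == 96 * 1024, "IsInputGemm LDS-size example");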
1169 
1170  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1171  __host__ static constexpr bool CheckValidity(const Argument& karg)
1172  {
1173  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1174  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1175  "Invalid tuning param!");
1176 
1177  static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1178  "KPerBlock should be multiple of ScaleBlockSize");
1179 
1180  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1185  {
1186  if(!(karg.M % MPerBlock == 0))
1187  {
1188  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1189  {
1190  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1191  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1192  << std::endl;
1193  }
1194  return false;
1195  }
1196  }
1197 
1198  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1203  {
1204  if(!(karg.N % NPerBlock == 0))
1205  {
1206  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1207  {
1208  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1209  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1210  << std::endl;
1211  }
1212  return false;
1213  }
1214  }
1215 
1216  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1220  {
1221  auto K_t = karg.KBatch * KPerBlock;
1222  if(!(karg.K % K_t == 0))
1223  {
1224  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1225  {
1226  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1227  << karg.K << " " << __FILE__ << ":" << __LINE__
1228  << ", in function: " << __func__ << std::endl;
1229  }
1230  return false;
1231  }
1232  }
1233  else
1234  {
1235  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1236  auto K_t = karg.KBatch * KReadVec;
1237  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1238  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1239  {
1240  return false;
1241  }
1242  }
1243 
1245  {
1246  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1247  {
1248  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1249  {
1250  std::cout << "Arg K (" << karg.K
1251  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1252  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1253  << __LINE__ << ", in function: " << __func__ << std::endl;
1254  }
1255  return false;
1256  }
1257  }
1258  else
1259  {
1260  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1261  {
1262  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1263  {
1264  std::cout << "Arg M (" << karg.M
1265  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1266  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1267  << __LINE__ << ", in function: " << __func__ << std::endl;
1268  }
1269  return false;
1270  }
1271  }
1272 
1274  {
1275  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1276  {
1277  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1278  {
1279  std::cout << "Arg N (" << karg.N
1280  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1281  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1282  << __LINE__ << ", in function: " << __func__ << std::endl;
1283  }
1284  return false;
1285  }
1286  }
1287  else
1288  {
1289  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1290  {
1291  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1292  {
1293  std::cout << "Arg K (" << karg.K
1294  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1295  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1296  << __LINE__ << ", in function: " << __func__ << std::endl;
1297  }
1298  return false;
1299  }
1300  }
1301 
1303  {
1305  {
1306  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1307  {
1308  std::cout << "Arg N (" << karg.N
1309  << ") value is not a multiple of "
1310  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1312  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1313  << std::endl;
1314  }
1315  return false;
1316  }
1317  }
1318  else
1319  {
1321  {
1322  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1323  {
1324  std::cout << "Arg M (" << karg.M
1325  << ") value is not a multiple of "
1326  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1328  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1329  << std::endl;
1330 
1331  return false;
1332  }
1333  }
1334  }
1335 
1336  // check gridwise gemm pipeline
1337 #if 0
1338  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1339 
1340  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1341  {
1342  return false;
1343  }
1344 #endif
1345  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1346  return true;
1347  }
1348 
1349  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1350  {
1351  const index_t num_loop = K / KPerBlock;
1352 
1353  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1354  }
1355 
1356  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1357  {
1358  const index_t num_loop = K / KPerBlock;
1359 
1360  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1361  }
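// Host-side usage sketch (illustration only; `karg` is assumed to be a fully-populated Argument,
// and the per-split K adjustment done by SplitKBatchOffset is ignored here):
//
//   if(!CheckValidity(karg))
//       return; // unsupported problem / tuning-parameter combination
//   const bool has_main_loop = CalculateHasMainKBlockLoop(karg.K);
//   const auto tail_num      = CalculateKBlockLoopTailNum(karg.K);
//   // has_main_loop and tail_num select the HasMainKBlockLoop / TailNum template arguments of
//   // kernel_moe_mxgemm_2lds.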
1362 
1363  template <typename CGridDesc>
1364  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1365  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1366  {
1367  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1368  c_grid_desc_m_n,
1373 
1374  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1375  }
1376 
1377  // return block_id to C matrix tile idx (m0, n0) mapping
1378  // if arch = gfx942
1379  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1380  // NPerBlock>;
1381 
1383  static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
1384  static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
1385  static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
1386  "A scale pack data type too large!");
1387  static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
1388  "B scale pack data type too large!");
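// Worked example (illustration only): with KXdlPack = MXdlPack = NXdlPack = 2, each access covers
// KXdlPack * MXdlPack = 4 A-scale entries (and 4 B-scale entries), so any scale type packing 1, 2
// or 4 mx_scale_t values per element satisfies the asserts above.
static_assert(4 % 4 == 0 && 4 % 2 == 0 && 4 % 1 == 0, "scale-pack divisibility example");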
1389 
1390  static_assert(is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
1391  is_same_v<BElementwiseOperation, tensor_operation::element_wise::PassThrough>,
1392  "A/B ElementwiseOperation should be PassThrough as load_to_lds is used!");
1393 
1394 #if 0
1395  template <bool HasMainKBlockLoop,
1396  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1397  TailNumber TailNum = TailNumber::Odd>
1398  __device__ static void Run(const index_t* p_sorted_token_ids,
1399  const index_t* p_sorted_expert_ids,
1400  const index_t* p_max_token_id,
1401  const ADataType* p_a_grid,
1402  const AScaleDataType* p_a_scale_grid,
1403  const BDataType* p_b_grid,
1404  const BScaleDataType* p_b_scale_grid,
1405  DsGridPointer& p_ds_grid,
1406  CDataType* p_c_grid,
1407  void* p_shared,
1408  const Problem& problem,
1409  AElementwiseOperation a_element_op,
1410  BElementwiseOperation b_element_op,
1411  CElementwiseOperation c_element_op)
1412  {
1413  ignore = a_element_op;
1414  ignore = b_element_op;
1415  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1416  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1417  problem.MPadded,
1418  problem.K,
1419  problem.KPadded,
1420  problem.StrideA,
1421  problem.AK0);
1422  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1423  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1424  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1425  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1426  problem.MPadded,
1427  problem.N,
1428  problem.NPadded,
1429  problem.StrideC);
1430 
1431  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
1432  make_tuple(problem.M / (MXdlPack * MPerXdl),
1433  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1434  (KXdlPack * 64 / MPerXdl),
1435  64 * KXdlPack * MXdlPack / scale_pack_size_a));
1436 
1437  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
1438  make_tuple(problem.N / (NXdlPack * NPerXdl),
1439  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1440  (KXdlPack * 64 / NPerXdl),
1441  64 * KXdlPack * NXdlPack / scale_pack_size_b));
1442 
1443  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1445  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1446 
1447  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1448  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1449  if(expert_block_id * MPerBlock >= max_token_id)
1450  return;
1451  const index_t expert_id =
1452  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1453 
1454  const auto block_mn = [&]() -> std::pair<int, int> {
1455  if constexpr(NSwizzle)
1456  {
1457  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1458  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1459  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1460  const index_t expert_swizzle =
1461  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1462  const index_t bid_new = blockIdx.x - prefix_block;
1463  const index_t nid = __builtin_amdgcn_readfirstlane(
1464  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1465  const index_t mid =
1466  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1467  return {nid, mid};
1468  }
1469  else
1470  {
1471  return {blockIdx.x, blockIdx.y};
1472  }
1473  }();
1474 
1475  const index_t block_n_id = block_mn.first;
1476  const index_t block_m_id = block_mn.second;
1477  const index_t token0 =
1478  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1479 
1480  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1481  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1482  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1483  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1484  constexpr auto AKThreads = AK0Threads * AK1Threads;
1485  constexpr auto AMRepeats = MPerBlock / AMThreads;
1486  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1487 
1488  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1489  return;
1490  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
1491  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1492  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1493  index_t token_offset = fused_token & 0xffffff;
1494  if constexpr(!IsInputGemm)
1495  {
1496  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1497  }
1498  gather_offsets(m0) = static_cast<IndexType>(token_offset);
1499  });
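 // Worked example (illustration only): a fused sorted-token entry of 0x02000123 decodes to
 // token id 0x123 (low 24 bits) and top-k slot 2 (high 8 bits); for the second GEMM
 // (!IsInputGemm) the gathered row index becomes token_id * TopK + slot.
 static_assert((0x02000123 & 0xffffff) == 0x123 && (0x02000123 >> 24) == 2, "token decode example");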
1500 
1501  const index_t expert_stride =
1502  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1503  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1504  problem.N * (IsInputGemm ? 2 : 1) *
1505  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
1506 
1507  // N0, K0, Blocksize*KPack
1508  const index_t n_block_data_idx_on_grid =
1509  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1510 
1511  // Grid buffer creation
1512  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1513  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1514  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1515  p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1516 
1517  // A, B scale buffer
1518  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1519  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1520  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1521  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
1522  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1523 
1524  // lds max alignment
1525  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1526 
1527  // A matrix in LDS memory, dst of blockwise copy
1528  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1529 
1530  // B matrix in LDS memory, dst of blockwise copy
1531  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1532 
1533  // A matrix blockwise direct to LDS copy
1534  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
1536  Sequence<AK0Number, MPerBlock, AK1Number>,
1537  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1538  ABlockTransferThreadClusterArrangeOrder,
1539  ADataType,
1540  ADataType,
1541  decltype(a_grid_desc_ak0_m_ak1),
1542  decltype(a_block_desc_ak0_m_ak1),
1543  ABlockTransferSrcAccessOrder,
1544  ABlockTransferSrcVectorDim,
1545  2,
1546  ABlockTransferSrcScalarPerVector,
1547  IndexType,
1548  1>(a_grid_desc_ak0_m_ak1,
1549  make_multi_index(0, 0, 0),
1550  a_block_desc_ak0_m_ak1,
1551  make_multi_index(0, 0, 0),
1552  gather_offsets);
1553 
1554  // B matrix blockwise copy
1555  auto b_blockwise_copy =
1556  ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
1557  Sequence<BK0Number, NPerBlock, BK1Number>,
1558  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1559  BBlockTransferThreadClusterArrangeOrder,
1560  BDataType,
1561  BDataType,
1562  decltype(b_grid_desc_bk0_n_bk1),
1563  decltype(b_block_desc_bk0_n_bk1),
1564  BBlockTransferSrcAccessOrder,
1565  BBlockTransferSrcVectorDim,
1566  2,
1567  BBlockTransferSrcScalarPerVector>(
1568  b_grid_desc_bk0_n_bk1,
1569  make_multi_index(0, n_block_data_idx_on_grid, 0),
1570  b_block_desc_bk0_n_bk1,
1571  make_multi_index(0, 0, 0));
1572 
1573  // LDS allocation for A and B: be careful of alignment
1574  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1575  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1576 
1577  // Cast after lds
1578  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1579  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1580 
1581  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1582  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1583  a_block_space_size_aligned * sizeof(ADataType)),
1584  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1585 
1586  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1587  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1588 
1589  // Blockwise GEMM pipeline
1590  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1591  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1592  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1593  decltype(c_thread_buf) c_thread_buf_up;
1594 
1595  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
1596  float,
1597  c_thread_buf.num_of_v_,
1598  c_thread_buf.s_per_v,
1599  true>
1600  c_thread_buf_fp32;
1601 
1602  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1603  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1604  KPerBlock);
1605 
1606  // a and b scale processing
1607  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1608  const auto waveId_m = wave_idx[I0];
1609  const auto waveId_n = wave_idx[I1];
1610 
1611  auto thread_offset_shuffled =
1612  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1613 
1614  auto a_thread_offset_m = waveId_m;
1615 
1616  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1617  AScaleDataType,
1618  AScaleDataType,
1619  decltype(a_scale_grid_desc_am_ak),
1620  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1621  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1622  Sequence<0, 1, 2>, // DimAccessOrder
1623  2, // SrcVectorDim
1624  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1625  1, // SrcScalarStrideInVector
1626  true>(a_scale_grid_desc_am_ak,
1627  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1628  0,
1629  thread_offset_shuffled / scale_pack_size_a));
1630 
1631  // B scale load
1632  auto b_thread_offset_n = waveId_n;
1633 
1634  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1635  BScaleDataType,
1636  BScaleDataType,
1637  decltype(b_scale_grid_desc_bn_ak),
1638  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1639  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1640  Sequence<0, 1, 2>, // DimAccessOrder
1641  2, // SrcVectorDim
1642  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1643  1, // SrcScalarStrideInVector
1644  true>(b_scale_grid_desc_bn_ak,
1645  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1646  0,
1647  thread_offset_shuffled / scale_pack_size_b));
1648 
1649  if constexpr(IsInputGemm)
1650  {
1651  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1652  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1653  auto b_block_buf_up = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1654  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1655  a_block_space_size_aligned * sizeof(ADataType) +
1656  b_block_space_size_aligned * sizeof(BDataType)),
1657  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1658 
1659  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
1660  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1661  p_b_grid_up + expert_id * expert_stride,
1662  b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1663 
1664  auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
1666  Sequence<BK0Number, NPerBlock, BK1Number>,
1667  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1668  BBlockTransferThreadClusterArrangeOrder,
1669  BDataType,
1670  BDataType,
1671  decltype(b_grid_desc_bk0_n_bk1),
1672  decltype(b_block_desc_bk0_n_bk1),
1673  BBlockTransferSrcAccessOrder,
1674  BBlockTransferSrcVectorDim,
1675  2,
1676  BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
1677  make_multi_index(0, n_block_data_idx_on_grid, 0),
1678  b_block_desc_bk0_n_bk1,
1679  make_multi_index(0, 0, 0));
1680 
1681  const BScaleDataType* p_b_scale_grid_up =
1682  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
1683  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1684  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
1685  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1686 
1687  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1688  BScaleDataType,
1689  BScaleDataType,
1690  decltype(b_scale_grid_desc_bn_ak),
1691  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1692  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1693  Sequence<0, 1, 2>, // DimAccessOrder
1694  2, // SrcVectorDim
1695  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1696  1, // SrcScalarStrideInVector
1697  true>(
1698  b_scale_grid_desc_bn_ak,
1699  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1700  0,
1701  thread_offset_shuffled / scale_pack_size_b));
1702 
1703  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1704  // A
1705  a_grid_desc_ak0_m_ak1,
1706  a_block_desc_ak0_m_ak1,
1707  a_blockwise_copy,
1708  a_grid_buf,
1709  a_block_buf,
1710  a_block_slice_copy_step,
1711  // Gate and Up
1712  b_grid_desc_bk0_n_bk1,
1713  b_block_desc_bk0_n_bk1,
1714  b_blockwise_copy,
1715  b_blockwise_copy_up,
1716  b_grid_buf,
1717  b_grid_buf_up,
1718  b_block_buf,
1719  b_block_buf_up,
1720  b_block_slice_copy_step,
1721  // C
1722  c_thread_buf,
1723  c_thread_buf_up,
1724  // A scale
1725  a_scale_grid_desc_am_ak,
1726  a_scale_thread_copy,
1727  a_scale_grid_buf,
1728  // Gate and Up scale
1729  b_scale_grid_desc_bn_ak,
1730  b_scale_thread_copy,
1731  b_scale_thread_copy_up,
1732  b_scale_grid_buf,
1733  b_scale_grid_buf_up,
1734  num_k_block_main_loop);
1735  }
1736  else
1737  {
1738  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1739  a_grid_desc_ak0_m_ak1, // A
1740  a_block_desc_ak0_m_ak1,
1741  a_blockwise_copy,
1742  a_grid_buf,
1743  a_block_buf,
1744  a_block_slice_copy_step,
1745  b_grid_desc_bk0_n_bk1, // B
1746  b_block_desc_bk0_n_bk1,
1747  b_blockwise_copy,
1748  b_grid_buf,
1749  b_block_buf,
1750  b_block_slice_copy_step,
1751  c_thread_buf, // C
1752  a_scale_grid_desc_am_ak, // A scale
1753  a_scale_thread_copy,
1754  a_scale_grid_buf,
1755  b_scale_grid_desc_bn_ak, // B scale
1756  b_scale_thread_copy,
1757  b_scale_grid_buf,
1758  num_k_block_main_loop);
1759  }
1760 
1761  // shuffle C and write out
1762  {
1763  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1764  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1765  "wrong!");
1766  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
1767  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
1768  "wrong!");
1769 
1770  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1771  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1772 
1773  // TODO: hacky, fix it!
1774  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1775  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1776 
1777  // TODO: hacky, fix it!
1778  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1779  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1780  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1781 
1782  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1783  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1784  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1785  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1786  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1787  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1788  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1789  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1790  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
1791  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
1792 
1793  // mul scales
1794  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
1795  static_assert(M5 == 4);
1796  const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id
1797  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
1798 
1799  vector_type<float, 4> topk_weights; // for gemm2 only
1800  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
1801  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
1802  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
1803  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
1804  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
1805  const index_t m_pos = block_m_id * MPerBlock +
1806  m0 * M2 * M1 * M3 * M4 * M5 +
1807  m1 * M2 * M3 * M4 * M5 +
1808  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
1809  if constexpr(MulRoutedWeight)
1810  {
1811  topk_weights =
1812  *c_style_pointer_cast<const vector_type<float, M5>*>(
1813  p_ds_grid[I2] + m_pos);
1814  }
1815  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
1816  constexpr index_t c_offset =
1817  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1818  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
1819  constexpr auto cidx = Number<c_offset>{};
1820 
1821  if constexpr(IsInputGemm) // gate+up fusion
1822  {
1823  if constexpr(ActivationOperation ==
1825  {
1826  float gate = c_thread_buf[cidx];
1827  float up = c_thread_buf_up[cidx];
1828  if constexpr(MulRoutedWeight)
1829  {
1830  gate = gate * topk_weights.AsType<float>()[m5];
1831  up = up * topk_weights.AsType<float>()[m5];
1832  }
1833  tensor_operation::element_wise::Silu{}(gate, gate);
1834  c_thread_buf_fp32(cidx) = gate * up;
1835  }
1836  else if(ActivationOperation == Activation::gelu_and_mul)
1837  {
1838  float gate = c_thread_buf[cidx];
1839  float up = c_thread_buf_up[cidx];
1840  if constexpr(MulRoutedWeight)
1841  {
1842  gate = gate * topk_weights.AsType<float>()[m5];
1843  up = up * topk_weights.AsType<float>()[m5];
1844  }
1845  tensor_operation::element_wise::Gelu{}(gate, gate);
1846  c_thread_buf_fp32(cidx) = gate * up;
1847 
1848  /*float gate = c_thread_buf[cidx];
1849  float up = c_thread_buf_up[cidx];
1850  if constexpr(MulRoutedWeight)
1851  {
1852  gate = gate * topk_weights.AsType<float>()[m5];
1853  //up = up * topk_weights.AsType<float>()[m5];
1854  }
1855  tensor_operation::element_wise::Gelu{}(gate, gate);
1856  c_thread_buf_fp32(cidx) = up;*/
1857  }
1858  }
1859  else
1860  {
1861  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1862  if constexpr(MulRoutedWeight)
1863  {
1864  c_thread_buf_fp32(cidx) =
1865  topk_weights.AsType<float>()[m5] *
1866  c_thread_buf_fp32[cidx];
1867  }
1868  }
1869  });
1870  });
1871  });
1872  });
1873  });
1874  });
1875 
1876  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1877  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
1878 
1879  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1880  static_cast<CShuffleDataType*>(p_shared),
1881  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1882 
1883  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1884  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1885  make_tuple(
1888  Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave)
1889  // per shuffle
1890  M1, // M1 = MWave
1891  M2, // M2 = MXdlPack
1892  M3, // M3 * M4 * M5 = MPerXdl
1893  M4,
1894  M5)),
1897  Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave)
1898  // per shuffle
1899  N1, // N1 = NWave
1900  N2, // N2 = NXdlPack
1901  N3))), // N3 = NPerXdl
1902  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
1903  make_tuple(Sequence<>{},
1904  Sequence<0, 2, 4, 6, 7, 8>{},
1905  Sequence<>{},
1906  Sequence<1, 3, 5, 9>{}));
1907 
1908  // calculate origin of thread output tensor on global memory
1909  // blockwise GEMM c matrix starting index
1910  const auto c_thread_mtx_on_block =
1911  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1912 
1913  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1914  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1915 
1916  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1918  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
1919  make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
1920  make_tuple(Sequence<0>{}));
1921 
1922  const auto m_thread_data_on_block_idx =
1923  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1924  make_multi_index(m_thread_data_on_block));
1925 
1926  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1928  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
1929  make_tuple(Sequence<0, 1, 2, 3>{}),
1930  make_tuple(Sequence<0>{}));
1931 
1932  const auto n_thread_data_on_block_idx =
1933  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1934  make_multi_index(n_thread_data_on_block));
1935 
1936  // shuffle: threadwise copy C from VGPR to LDS
1937  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1938  AccDataType,
1939  CShuffleDataType,
1940  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1941  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1943  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1944  CShuffleNXdlPerWavePerShuffle / NXdlPack,
1945  I1,
1946  I1,
1947  M2,
1948  N2,
1949  M3,
1950  I1,
1951  M5,
1952  I1>,
1953  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
1954  9,
1955  1,
1957  1,
1958  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1959  make_multi_index(0,
1960  0,
1961  m_thread_data_on_block_idx[I1],
1962  n_thread_data_on_block_idx[I1],
1963  m_thread_data_on_block_idx[I2],
1964  n_thread_data_on_block_idx[I2],
1965  m_thread_data_on_block_idx[I3],
1966  m_thread_data_on_block_idx[I4],
1967  m_thread_data_on_block_idx[I5],
1968  n_thread_data_on_block_idx[I3]),
1970 
1971  using EDataType = CDataType;
1972 
1973  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1974  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1975 
1976  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1978  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1979 
1980  const auto ds_grid_buf = generate_tuple(
1981  [&](auto i) {
1982  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1983  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1984  },
1985  Number<NumDTensor>{});
1986 
1987  // tuple of reference to C/Ds tensor descriptors
1988  const auto c_ds_desc_refs = concat_tuple_of_reference(
1989  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1990  generate_tie([&](auto i) -> const auto& // return type should be reference
1991  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1992  Number<NumDTensor>{}));
1993 
1994  // tuple of reference to C/Ds tensor descriptors
1995  const auto c_ds_buf_refs = concat_tuple_of_reference(
1996  tie(c_shuffle_block_buf),
1997  generate_tie([&](auto i) -> const auto& // return type should be reference
1998  { return ds_grid_buf[i]; },
1999  Number<NumDTensor>{}));
2000 
2001  // tuple of starting index of C/Ds blockwise copy
2002  const auto idx_c_ds_block_begin =
2005  [&](auto) {
2006  return make_multi_index(block_m_id, 0, block_n_id, 0);
2007  // return make_multi_index(block_work_idx[I0], 0,
2008  // block_work_idx[I1], 0);
2009  },
2010  Number<NumDTensor>{}));
2011 
2012  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2013  c_grid_desc_mblock_mperblock_nblock_nperblock;
2014 
2015  using CDEBlockTransferCluster =
2016  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2017  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2018  constexpr index_t scatter_weight_idx = 3; // hack fix felix
2019  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2021  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2022  Tuple<EDataType>,
2023  decltype(c_ds_desc_refs),
2024  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2025  CElementwiseOperation,
2026  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2027  // Sequence support
2028  // arbitrary type
2029  Sequence<1,
2030  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2031  1,
2032  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2033  CDEBlockTransferCluster,
2034  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2035  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2036  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2037  3, // index_t SrcVectorDim,
2038  3, // index_t DstVectorDim,
2039  CDEShuffleBlockTransferScalarPerVectors,
2042  Sequence<true>,
2044  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2045  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2046  IndexType,
2047  1, // ScatterDim
2048  true, // OutputScatter: scatter output rows using scatter_offsets
2049  scatter_weight_idx // ScatterWeightIdx: ascale
2050  >{c_ds_desc_refs,
2051  idx_c_ds_block_begin,
2052  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2053  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2054  c_element_op};
2055 
2056  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2057  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2058 
2059  constexpr auto sfc_c_vgpr =
2060  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2061  NXdlPerWave / NXdlPack,
2062  1,
2063  1,
2064  MXdlPack,
2065  NXdlPack,
2066  M2,
2067  1,
2068  M4,
2069  1>,
2070  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2071  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2072  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2073  1,
2074  1,
2075  MXdlPack,
2076  NXdlPack,
2077  M2,
2078  1,
2079  M4,
2080  1>>{};
2081 
2082  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2083 
2084  // space filling curve for shuffled blockwise C/D/E
2085  constexpr auto sfc_cde_block =
2086  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2087  Sequence<0, 2, 1, 3>,
2088  Sequence<1,
2089  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2090  1,
2091  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2092 
2093  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2094  constexpr auto EMThreads =
2095  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2096  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2097  constexpr auto ENThreads =
2098  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2099  static_for<0, num_access, 1>{}([&](auto access_id) {
2100  // make sure it's safe to write to LDS
2101  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2102 
2103  auto dstidx = sfc_cde_block.GetIndex(access_id);
2104  const index_t c_token_pos =
2105  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2106  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2107  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2108  IndexType token_offset = fused_token & 0xffffff;
2109  if constexpr(IsInputGemm)
2110  {
2111  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2112  }
2113  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2114  });
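 // Each scatter offset is the destination row offset (row index * problem.N) in the C grid:
 // for the gate/up input GEMM rows land at token * TopK + slot, otherwise at the raw token id.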
2115 
2116  block_sync_lds();
2117 
2118  // each thread write its data from VGPR to LDS
2119  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2120  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2121  c_thread_buf_fp32,
2122  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2123  c_shuffle_block_buf);
2124 
2125  // make sure it's safe to read from LDS
2126  block_sync_lds();
2127 
2128  // each block copy its data from LDS to global
2129  cde_block_copy_lds_and_global.Run(
2130  c_ds_desc_refs,
2131  c_ds_buf_refs,
2132  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2133  tie(c_grid_buf),
2134  scatter_offsets);
2135 
2136  if constexpr(access_id < num_access - 1)
2137  {
2138  constexpr auto cde_lds_and_global_step =
2139  sfc_cde_block.GetForwardStep(access_id);
2140 
2141  // move on Ds
2142  static_for<0, NumDTensor, 1>{}([&](auto i) {
2143  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2144  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2145  });
2146 
2147  // move on E
2148  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2149  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2150  I0,
2151  cde_lds_and_global_step);
2152  }
2153  });
2154  }
2155  }
2156 #endif
2157 
2158  template <bool HasMainKBlockLoop,
2159  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2160  TailNumber TailNum = TailNumber::Odd>
2161  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
2162  const index_t* p_sorted_expert_ids,
2163  const index_t* p_max_token_id,
2164  const ADataType* p_a_grid,
2165  const AScaleDataType* p_a_scale_grid,
2166  const BDataType* p_b_grid,
2167  const BScaleDataType* p_b_scale_grid,
2168  DsGridPointer& p_ds_grid,
2169  CDataType* p_c_grid,
2170  void* p_shared_0,
2171  void* p_shared_1,
2172  const Problem& problem,
2173  AElementwiseOperation a_element_op,
2174  BElementwiseOperation b_element_op,
2175  CElementwiseOperation c_element_op)
2176  {
2177  ignore = a_element_op;
2178  ignore = b_element_op;
2179  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2180  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2181  problem.MPadded,
2182  problem.K,
2183  problem.KPadded,
2184  problem.StrideA,
2185  problem.AK0);
2186  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2187  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2188  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2189  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2190  problem.MPadded,
2191  problem.N,
2192  problem.NPadded,
2193  problem.StrideC);
2194 
2195  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
2196  make_tuple(problem.M / (MXdlPack * MPerXdl),
2197  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2198  (KXdlPack * 64 / MPerXdl),
2199  64 * KXdlPack * MXdlPack / scale_pack_size_a));
2200 
2201  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
2202  make_tuple(problem.N / (NXdlPack * NPerXdl),
2203  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2204  (KXdlPack * 64 / NPerXdl),
2205  64 * KXdlPack * NXdlPack / scale_pack_size_b));
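 // Both scale grids hold one block scale per ScaleBlockSize (packed) elements of K and appear
 // to be stored pre-shuffled, so a 64-lane wave reads its KXdlPack x MXdlPack (or NXdlPack)
 // scales as a single packed element of scale_pack_size_a/b scales.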
2206 
2207  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2209  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2210 
2211  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2212  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2213  if(expert_block_id * MPerBlock >= max_token_id)
2214  return;
2215  const index_t expert_id =
2216  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2217  const auto block_mn = [&]() -> std::pair<int, int> {
2218  if constexpr(NSwizzle)
2219  {
2220  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2221  const index_t prefix_block = ecnt_prefix * problem.NBlock;
2222  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2223  const index_t expert_swizzle =
2224  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2225  const index_t bid_new = blockIdx.x - prefix_block;
2226  const index_t nid = __builtin_amdgcn_readfirstlane(
2227  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2228  const index_t mid =
2229  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2230  return {nid, mid};
2231  }
2232  else
2233  {
2234  return {blockIdx.x, blockIdx.y};
2235  }
2236  }();
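 // With NSwizzle the flat blockIdx.x is remapped per expert: consecutive blocks sweep groups
 // of 8 N tiles before stepping in M (presumably to improve reuse of this expert's B tiles);
 // without NSwizzle the (n, m) block ids come straight from (blockIdx.x, blockIdx.y).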
2237 
2238  const index_t block_n_id = block_mn.first;
2239  const index_t block_m_id = block_mn.second;
2240  const index_t token0 =
2241  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2242 
2243  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2244  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2245  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2246  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2247  constexpr auto AKThreads = AK0Threads * AK1Threads;
2248  constexpr auto AMRepeats = MPerBlock / AMThreads;
2249  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
2250 
2251  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2252  return;
2253  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
2254  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2255  const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
2256  index_t token_offset = fused_token & 0xffffff;
2257  if constexpr(!IsInputGemm)
2258  {
2259  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2260  }
2261  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2262  });
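 // p_sorted_token_ids packs the token index in the low 24 bits and the top-k slot in the
 // high 8 bits; gather_offsets turns that into a row offset into A (for the second GEMM,
 // !IsInputGemm, the A row is token * TopK + slot since rows are replicated per slot).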
2263 
2264  const index_t expert_stride =
2265  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2266  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2267  problem.N * (IsInputGemm ? 2 : 1) *
2268  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
2269 
2270  // N0, K0, Blocksize*KPack
2271  const index_t n_block_data_idx_on_grid =
2272  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
2273 
2274  // Grid buffer creation
2275  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2276  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2277  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2278  p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
2279 
2280  // A, B scale buffer
2281  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2282  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2283  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2284  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2285  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2286 
2287  // lds max alignment
2288  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
2289 
2290  // A matrix in LDS memory, dst of blockwise copy
2291  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2292 
2293  // B matrix in LDS memory, dst of blockwise copy
2294  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2295 
2296  // A matrix blockwise direct to LDS copy
2297  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
2300  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2301  ABlockTransferThreadClusterArrangeOrder,
2302  ADataType,
2303  ADataType,
2304  decltype(a_grid_desc_ak0_m_ak1),
2305  decltype(a_block_desc_ak0_m_ak1),
2306  ABlockTransferSrcAccessOrder,
2307  ABlockTransferSrcVectorDim,
2308  2,
2309  ABlockTransferSrcScalarPerVector,
2310  IndexType,
2311  1>(a_grid_desc_ak0_m_ak1,
2312  make_multi_index(0, 0, 0),
2313  a_block_desc_ak0_m_ak1,
2314  make_multi_index(0, 0, 0),
2315  gather_offsets);
2316 
2317  // B matrix blockwise copy
2318  auto b_blockwise_copy =
2321  BBlockTransferThreadClusterLengths_BK0_N_BK1,
2322  BBlockTransferThreadClusterArrangeOrder,
2323  BDataType,
2324  BDataType,
2325  decltype(b_grid_desc_bk0_n_bk1),
2326  decltype(b_block_desc_bk0_n_bk1),
2327  BBlockTransferSrcAccessOrder,
2328  BBlockTransferSrcVectorDim,
2329  2,
2330  BBlockTransferSrcScalarPerVector>(
2331  b_grid_desc_bk0_n_bk1,
2332  make_multi_index(0, n_block_data_idx_on_grid, 0),
2333  b_block_desc_bk0_n_bk1,
2334  make_multi_index(0, 0, 0));
2335 
2336  // LDS allocation for A and B: be careful of alignment
2337  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
2338  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
2339 
2340  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2341  static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2342 
2343  auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2344  bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
2345  a_block_space_size_aligned * sizeof(ADataType)),
2346  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2347 
2348  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2349  static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2350 
2351  auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2352  bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
2353  a_block_space_size_aligned * sizeof(ADataType)),
2354  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2355 
2356  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2357  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
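 // Ping-pong LDS: p_shared_0 and p_shared_1 each hold an A tile followed (after alignment
 // padding) by a B tile, so the pipeline can prefetch the next K slice into one buffer while
 // the other feeds the MFMAs.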
2358 
2359  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2360  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
2361 
2362  // Blockwise GEMM pipeline
2363  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2364  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2365  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2366  decltype(c_thread_buf) c_thread_buf_up;
2367 
2368  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
2369  float,
2370  c_thread_buf.num_of_v_,
2371  c_thread_buf.s_per_v,
2372  true>
2373  c_thread_buf_fp32;
2374 
2375  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2376  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2377  KPerBlock);
2378 
2379  // a and b scale processing
2380  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2381  const auto waveId_m = wave_idx[I0];
2382  const auto waveId_n = wave_idx[I1];
2383 
2384  auto thread_offset_shuffled =
2385  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2386 
2387  auto a_thread_offset_m = waveId_m;
2388 
2389  // get each thread's offset in the scale tensor
2390  const index_t token_scale_pos = block_m_id * MPerBlock;
2391  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2392  return;
2393 
2394  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2395  AScaleDataType,
2396  AScaleDataType,
2397  decltype(a_scale_grid_desc_am_ak),
2398  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2399  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2400  Sequence<0, 1, 2>, // DimAccessOrder
2401  2, // SrcVectorDim
2402  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2403  1, // SrcScalarStrideInVector
2404  true>(a_scale_grid_desc_am_ak,
2405  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2406  0,
2407  thread_offset_shuffled / scale_pack_size_a));
2408 
2409  // B scale load
2410  auto b_thread_offset_n = waveId_n;
2411 
2412  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2413  BScaleDataType,
2414  BScaleDataType,
2415  decltype(b_scale_grid_desc_bn_ak),
2416  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2417  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2418  Sequence<0, 1, 2>, // DimAccessOrder
2419  2, // SrcVectorDim
2420  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2421  1, // SrcScalarStrideInVector
2422  true>(b_scale_grid_desc_bn_ak,
2423  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2424  0,
2425  thread_offset_shuffled / scale_pack_size_b));
2426 
2427  if constexpr(IsInputGemm)
2428  {
2429  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
2430  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2431  p_b_grid_up + expert_id * expert_stride,
2432  b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
2433 
2434  // lds ping pong buffers for up
2435  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
2436  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
2437  auto b_block_buf_up_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2438  bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
2439  a_block_space_size_aligned * sizeof(ADataType) +
2440  b_block_space_size_aligned * sizeof(BDataType)),
2441  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2442  auto b_block_buf_up_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2443  bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
2444  a_block_space_size_aligned * sizeof(ADataType) +
2445  b_block_space_size_aligned * sizeof(BDataType)),
2446  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2447 
2448  auto b_block_bufs_up = make_tuple(b_block_buf_up_ping, b_block_buf_up_pong);
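 // For the fused gate/up GEMM the "up" weight tile is staged right after the gate tile in each
 // LDS buffer, with its own direct-load copy and scale thread copy.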
2449 
2450  auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
2453  BBlockTransferThreadClusterLengths_BK0_N_BK1,
2454  BBlockTransferThreadClusterArrangeOrder,
2455  BDataType,
2456  BDataType,
2457  decltype(b_grid_desc_bk0_n_bk1),
2458  decltype(b_block_desc_bk0_n_bk1),
2459  BBlockTransferSrcAccessOrder,
2460  BBlockTransferSrcVectorDim,
2461  2,
2462  BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
2463  make_multi_index(0, n_block_data_idx_on_grid, 0),
2464  b_block_desc_bk0_n_bk1,
2465  make_multi_index(0, 0, 0));
2466 
2467  const BScaleDataType* p_b_scale_grid_up =
2468  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
2469  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2470  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
2471  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2472 
2473  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2474  BScaleDataType,
2475  BScaleDataType,
2476  decltype(b_scale_grid_desc_bn_ak),
2477  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2478  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2479  Sequence<0, 1, 2>, // DimAccessOrder
2480  2, // SrcVectorDim
2481  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2482  1, // SrcScalarStrideInVector
2483  true>(
2484  b_scale_grid_desc_bn_ak,
2485  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2486  0,
2487  thread_offset_shuffled / scale_pack_size_b));
2488 
2489  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2490  // A
2491  a_grid_desc_ak0_m_ak1,
2492  a_block_desc_ak0_m_ak1,
2493  a_blockwise_copy,
2494  a_grid_buf,
2495  a_block_bufs,
2496  a_block_slice_copy_step,
2497  // Gate and Up
2498  b_grid_desc_bk0_n_bk1,
2499  b_block_desc_bk0_n_bk1,
2500  b_blockwise_copy,
2501  b_blockwise_copy_up,
2502  b_grid_buf,
2503  b_grid_buf_up,
2504  b_block_bufs,
2505  b_block_bufs_up,
2506  b_block_slice_copy_step,
2507  // C
2508  c_thread_buf,
2509  c_thread_buf_up,
2510  // A scale
2511  a_scale_grid_desc_am_ak,
2512  a_scale_thread_copy,
2513  a_scale_grid_buf,
2514  // B scale
2515  b_scale_grid_desc_bn_ak,
2516  b_scale_thread_copy,
2517  b_scale_thread_copy_up,
2518  b_scale_grid_buf,
2519  b_scale_grid_buf_up,
2520  num_k_block_main_loop);
2521  }
2522  else
2523  {
2524  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2525  a_grid_desc_ak0_m_ak1, // A
2526  a_block_desc_ak0_m_ak1,
2527  a_blockwise_copy,
2528  a_grid_buf,
2529  a_block_bufs,
2530  a_block_slice_copy_step,
2531  b_grid_desc_bk0_n_bk1, // B
2532  b_block_desc_bk0_n_bk1,
2533  b_blockwise_copy,
2534  b_grid_buf,
2535  b_block_bufs,
2536  b_block_slice_copy_step,
2537  c_thread_buf, // C
2538  a_scale_grid_desc_am_ak, // A scale
2539  a_scale_thread_copy,
2540  a_scale_grid_buf,
2541  b_scale_grid_desc_bn_ak, // B scale
2542  b_scale_thread_copy,
2543  b_scale_grid_buf,
2544  num_k_block_main_loop);
2545  }
2546 
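 // The accumulators in c_thread_buf (and c_thread_buf_up for the input GEMM) are now complete;
 // the epilogue below applies the activation / routed-weight scaling, shuffles tiles through
 // LDS and scatters the rows to their token positions in the C grid.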
2547  // shuffle C and write out
2548  {
2549  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2550  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2551  "wrong!");
2552  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
2553  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
2554  "wrong!");
2555 
2556  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2557  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2558 
2559  // TODO: hacky, fix it!
2560  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2561  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2562 
2563  // TODO: hacky, fix it!
2564  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2565  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2566  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2567 
2568  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2569  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2570  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2571  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2572  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2573  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2574  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2575  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2576  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
2577  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
2578 
2579  // mul scales
2580 
2581  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
2582  static_assert(M5 == 4);
2583  const index_t m1 = get_warp_local_1d_id() / NWave;
2584  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
2585 
2586  vector_type<float, 4> topk_weights; // for gemm2 only
2587  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
2588  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
2589  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
2590  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
2591  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
2592  const index_t m_pos = block_m_id * MPerBlock +
2593  m0 * M2 * M1 * M3 * M4 * M5 +
2594  m1 * M2 * M3 * M4 * M5 +
2595  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
2596  if constexpr(MulRoutedWeight)
2597  {
2598  topk_weights =
2599  *c_style_pointer_cast<const vector_type<float, M5>*>(
2600  p_ds_grid[I2] + m_pos);
2601  }
2602  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
2603  constexpr index_t c_offset =
2604  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2605  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
2606  constexpr auto cidx = Number<c_offset>{};
2607 
2608  if constexpr(IsInputGemm) // gate-up (gu) fusion
2609  {
2610  if constexpr(ActivationOperation ==
2611  Activation::silu_and_mul)
2612  {
2613  float gate = c_thread_buf[cidx];
2614  float up = c_thread_buf_up[cidx];
2615  if constexpr(MulRoutedWeight)
2616  {
2617  gate = gate * topk_weights.AsType<float>()[m5];
2618  up = up * topk_weights.AsType<float>()[m5];
2619  }
2620  tensor_operation::element_wise::Silu{}(gate, gate);
2621  c_thread_buf_fp32(cidx) = gate * up;
2622  }
2623  else if(ActivationOperation == Activation::gelu_and_mul)
2624  {
2625  float gate = c_thread_buf[cidx];
2626  float up = c_thread_buf_up[cidx];
2627  if constexpr(MulRoutedWeight)
2628  {
2629  gate = gate * topk_weights.AsType<float>()[m5];
2630  up = up * topk_weights.AsType<float>()[m5];
2631  }
2632  tensor_operation::element_wise::Gelu{}(gate, gate);
2633  c_thread_buf_fp32(cidx) = gate * up;
2634  }
2635  }
2636  else
2637  {
2638  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2639  if constexpr(MulRoutedWeight)
2640  {
2641  c_thread_buf_fp32(cidx) =
2642  topk_weights.AsType<float>()[m5] *
2643  c_thread_buf_fp32[cidx];
2644  }
2645  }
2646  });
2647  });
2648  });
2649  });
2650  });
2651  });
2652 
2653  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2654  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2655 
2656  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2657  static_cast<CShuffleDataType*>(p_shared_0),
2658  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2659 
2660  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2661  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2662  make_tuple(
2665  Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
2666  // shuffle
2667  M1, // M1 = MWave
2668  M2, // M2 * M3 * M4 = MPerXdl
2669  M3,
2670  M4,
2671  M5)),
2674  Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave)
2675  // per shuffle
2676  N1, // N1 = NWave
2677  N2, // N2 = NXdlPack
2678  N3))), // N3 = NPerXdl
2679  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2680  make_tuple(Sequence<>{},
2681  Sequence<0, 2, 4, 6, 7, 8>{},
2682  Sequence<>{},
2683  Sequence<1, 3, 5, 9>{}));
2684 
2685  // calculate origin of thread output tensor on global memory
2686  // blockwise GEMM c matrix starting index
2687  const auto c_thread_mtx_on_block =
2688  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2689 
2690  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2691  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2692 
2693  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2695  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
2697  make_tuple(Sequence<0>{}));
2698 
2699  const auto m_thread_data_on_block_idx =
2700  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2701  make_multi_index(m_thread_data_on_block));
2702 
2703  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2705  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
2707  make_tuple(Sequence<0>{}));
2708 
2709  const auto n_thread_data_on_block_idx =
2710  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2711  make_multi_index(n_thread_data_on_block));
2712 
2713  // shuffle: threadwise copy C from VGPR to LDS
2714  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
2715  AccDataType,
2716  CShuffleDataType,
2717  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2718  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2720  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2721  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2722  I1,
2723  I1,
2724  M2,
2725  N2,
2726  M3,
2727  I1,
2728  M5,
2729  I1>,
2730  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2731  9,
2732  1,
2734  1,
2735  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2736  make_multi_index(0,
2737  0,
2738  m_thread_data_on_block_idx[I1],
2739  n_thread_data_on_block_idx[I1],
2740  m_thread_data_on_block_idx[I2],
2741  n_thread_data_on_block_idx[I2],
2742  m_thread_data_on_block_idx[I3],
2743  m_thread_data_on_block_idx[I4],
2744  m_thread_data_on_block_idx[I5],
2745  n_thread_data_on_block_idx[I3]),
2747 
2748  using EDataType = CDataType;
2749 
2750  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2751  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2752 
2753  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2755  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2756 
2757  const auto ds_grid_buf = generate_tuple(
2758  [&](auto i) {
2759  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2760  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2761  },
2762  Number<NumDTensor>{});
2763 
2764  // tuple of reference to C/Ds tensor descriptors
2765  const auto c_ds_desc_refs = concat_tuple_of_reference(
2766  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2767  generate_tie([&](auto i) -> const auto& // return type should be reference
2768  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2769  Number<NumDTensor>{}));
2770 
2771  // tuple of reference to C/Ds tensor descriptors
2772  const auto c_ds_buf_refs = concat_tuple_of_reference(
2773  tie(c_shuffle_block_buf),
2774  generate_tie([&](auto i) -> const auto& // return type should be reference
2775  { return ds_grid_buf[i]; },
2776  Number<NumDTensor>{}));
2777 
2778  // tuple of starting index of C/Ds blockwise copy
2779  const auto idx_c_ds_block_begin =
2782  [&](auto) {
2783  return make_multi_index(block_m_id, 0, block_n_id, 0);
2784  // return make_multi_index(block_work_idx[I0], 0,
2785  // block_work_idx[I1], 0);
2786  },
2787  Number<NumDTensor>{}));
2788 
2789  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2790  c_grid_desc_mblock_mperblock_nblock_nperblock;
2791 
2792  using CDEBlockTransferCluster =
2793  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2794  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2795  constexpr index_t scatter_weight_idx = 3; // hack fix felix
2796  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2798  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2799  Tuple<EDataType>,
2800  decltype(c_ds_desc_refs),
2801  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2802  CElementwiseOperation,
2803  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2804  // Sequence support
2805  // arbitrary type
2806  Sequence<1,
2807  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2808  1,
2809  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2810  CDEBlockTransferCluster,
2811  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2812  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2813  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2814  3, // index_t SrcVectorDim,
2815  3, // index_t DstVectorDim,
2816  CDEShuffleBlockTransferScalarPerVectors,
2821  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2822  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2823  IndexType,
2824  1, // ScatterDim
2825  true, // OutputScatter: scatter output rows using scatter_offsets
2826  scatter_weight_idx // ScatterWeightIdx: ascale
2827  >{c_ds_desc_refs,
2828  idx_c_ds_block_begin,
2829  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2830  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2831  c_element_op};
2832 
2833  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2834  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2835 
2836  constexpr auto sfc_c_vgpr =
2837  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2838  NXdlPerWave / NXdlPack,
2839  1,
2840  1,
2841  MXdlPack,
2842  NXdlPack,
2843  M2,
2844  1,
2845  M4,
2846  1>,
2847  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2848  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2849  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2850  1,
2851  1,
2852  MXdlPack,
2853  NXdlPack,
2854  M2,
2855  1,
2856  M4,
2857  1>>{};
2858 
2859  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2860 
2861  // space filling curve for shuffled blockwise C/D/E
2862  constexpr auto sfc_cde_block =
2863  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2864  Sequence<0, 2, 1, 3>,
2865  Sequence<1,
2866  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2867  1,
2868  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2869 
2870  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2871  constexpr auto EMThreads =
2872  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2873  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2874  constexpr auto ENThreads =
2875  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2876  static_for<0, num_access, 1>{}([&](auto access_id) {
2877  // make sure it's safe to write to LDS
2878  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2879 
2880  auto dstidx = sfc_cde_block.GetIndex(access_id);
2881  const index_t c_token_pos =
2882  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2883  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2884  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2885  IndexType token_offset = fused_token & 0xffffff;
2886  if constexpr(IsInputGemm)
2887  {
2888  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2889  }
2890  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2891  });
2892 
2893  block_sync_lds();
2894 
2895  // each thread write its data from VGPR to LDS
2896  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2897  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2898  c_thread_buf_fp32,
2899  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2900  c_shuffle_block_buf);
2901 
2902  // make sure it's safe to read from LDS
2903  block_sync_lds();
2904 
2905  // each block copy its data from LDS to global
2906  cde_block_copy_lds_and_global.Run(
2907  c_ds_desc_refs,
2908  c_ds_buf_refs,
2909  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2910  tie(c_grid_buf),
2911  scatter_offsets);
2912 
2913  if constexpr(access_id < num_access - 1)
2914  {
2915  constexpr auto cde_lds_and_global_step =
2916  sfc_cde_block.GetForwardStep(access_id);
2917 
2918  // move on Ds
2919  static_for<0, NumDTensor, 1>{}([&](auto i) {
2920  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2921  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2922  });
2923 
2924  // move on E
2925  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2926  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2927  I0,
2928  cde_lds_and_global_step);
2929  }
2930  });
2931  }
2932  }
2933 };
2934 
2935 } // namespace ck