gridwise_moe_mx_gemm_bpreshuffle.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
16 
20 
21 #define DEBUG_LOG 0
22 
23 namespace ck {
24 
25 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
26 // pipelines into the same kernel function. Blockers:
27 // 1. Two separate declarations of the __shared__ pointer are needed to make sure data accesses
28 // operate on two distinct LDS chunks.
29 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
30 // not reuse the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
31 
32 enum Activation
33 {
34  gelu_and_mul = 0,
35  silu_and_mul = 1
36 };
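// Illustrative note (not part of the original listing): the Activation value selects the fused
// gate/up epilogue used when this kernel computes the first (gate-up) GEMM of an MoE MLP.
// Assuming `gate` and `up` denote the two halves of the GEMM output for one element, the
// epilogue further below computes, per element:
//
//   silu_and_mul : out = Silu(gate) * up
//   gelu_and_mul : out = Gelu(gate) * up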
37 
38 template <typename GridwiseGemm,
39  bool HasMainKBlockLoop,
40  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
41  index_t MinimumOccupancy = 1,
42  TailNumber TailNum = TailNumber::Even>
43 __global__ void
44 #if CK_USE_LAUNCH_BOUNDS
45 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
46 #endif
47  // __attribute__((amdgpu_waves_per_eu(1, 1)))
48  kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
49 {
50 #if defined(__gfx9__)
51  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52 
53  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
54 
55  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
56  karg.p_sorted_token_ids,
57  karg.p_sorted_expert_ids,
58  karg.p_max_token_id,
59  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
 60  karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
 61  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
 62  karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
63  karg.p_ds_grid,
64  karg.p_c_grid,
65  p_shared,
66  karg,
67  karg.a_element_op,
68  karg.b_element_op,
69  karg.c_element_op);
70 #else
71  ignore = karg;
72 #endif // end of if (defined(__gfx9__))
73 }
74 
75 template <typename GridwiseGemm,
76  bool HasMainKBlockLoop,
77  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
78  index_t MinimumOccupancy = 1,
79  TailNumber TailNum = TailNumber::Even>
80 __global__ void
81 #if CK_USE_LAUNCH_BOUNDS
82 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
83 #endif
84  // __attribute__((amdgpu_waves_per_eu(1, 1)))
85  kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
86 {
87 #if defined(__gfx9__)
88  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
89  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
90 
91  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
92 
93  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
94  karg.p_sorted_token_ids,
95  karg.p_sorted_expert_ids,
96  karg.p_max_token_id,
97  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
98  karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
99  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
100  karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
101  karg.p_ds_grid,
102  karg.p_c_grid,
103  p_shared_0,
104  p_shared_1,
105  karg,
106  karg.a_element_op,
107  karg.b_element_op,
108  karg.c_element_op);
109 #else
110  ignore = karg;
111 #endif // end of if (defined(__gfx9__))
112 }
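// Hedged launch sketch (assumption, not part of the original listing): a device-level wrapper
// would typically pick the single- or double-LDS kernel based on the block GEMM pipeline version
// and launch it over the grid returned by CalculateGridSize. `stream`, `use_double_lds` and the
// explicit BlockSize value below are illustrative placeholders.
//
//   const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
//   // blockIdx.z indexes the split-K batch, so the z extent is KBatch rather than gdz.
//   if(use_double_lds)
//       kernel_moe_mxgemm_2lds<GridwiseGemm, /*HasMainKBlockLoop*/ true,
//                              InMemoryDataOperationEnum::Set>
//           <<<dim3(gdx, gdy, karg.KBatch), dim3(/*BlockSize*/ 256), 0, stream>>>(karg);
//   else
//       kernel_moe_mxgemm<GridwiseGemm, /*HasMainKBlockLoop*/ true,
//                         InMemoryDataOperationEnum::Set>
//           <<<dim3(gdx, gdy, karg.KBatch), dim3(/*BlockSize*/ 256), 0, stream>>>(karg);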
113 
114 template <typename ALayout,
115  typename BLayout,
116  typename DsLayout,
117  typename CLayout,
118  typename ADataType,
119  typename AScaleDataType,
120  typename BDataType,
121  typename BScaleDataType,
122  typename AccDataType,
123  typename CShuffleDataType,
124  typename DsDataType,
125  typename CDataType,
126  typename AElementwiseOperation,
127  typename BElementwiseOperation,
128  typename CElementwiseOperation,
130  index_t ScaleBlockSize, // Scaling block size
131  index_t BlockSize, // Thread block size
132  index_t MPerBlock,
133  index_t NPerBlock,
134  index_t KPerBlock,
135  index_t AK1Value,
136  index_t BK1Value,
137  index_t MPerXdl,
138  index_t NPerXdl,
139  index_t MXdlPerWave,
140  index_t NXdlPerWave,
141  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
142  typename ABlockTransferThreadClusterArrangeOrder,
143  typename ABlockTransferSrcAccessOrder,
144  index_t ABlockTransferSrcVectorDim,
145  index_t ABlockTransferSrcScalarPerVector,
146  index_t ABlockTransferDstScalarPerVector_AK1,
147  bool AThreadTransferSrcResetCoordinateAfterRun,
148  index_t ABlockLdsExtraM,
149  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
150  typename BBlockTransferThreadClusterArrangeOrder,
151  typename BBlockTransferSrcAccessOrder,
152  index_t BBlockTransferSrcVectorDim,
153  index_t BBlockTransferSrcScalarPerVector,
154  index_t BBlockTransferDstScalarPerVector_BK1,
155  bool BThreadTransferSrcResetCoordinateAfterRun,
156  index_t BBlockLdsExtraN,
157  index_t CShuffleMXdlPerWavePerShuffle,
158  index_t CShuffleNXdlPerWavePerShuffle,
159  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
160  typename CDEShuffleBlockTransferScalarPerVectors,
163  index_t ActivationOperation = 0,
164  bool NSwizzle = false,
165  bool IsInputGemm = true,
166  bool MulRoutedWeight = true,
167  typename IndexType = index_t,
168  typename ComputeTypeA = ADataType,
169  typename ComputeTypeB = BDataType>
171 {
172  using LDSTypeA = ADataType;
173  using LDSTypeB = BDataType;
174 
175  static constexpr auto I0 = Number<0>{};
176  static constexpr auto I1 = Number<1>{};
177  static constexpr auto I2 = Number<2>{};
178  static constexpr auto I3 = Number<3>{};
179  static constexpr auto I4 = Number<4>{};
180  static constexpr auto I5 = Number<5>{};
181  static constexpr auto I6 = Number<6>{};
182  static constexpr auto I7 = Number<7>{};
183  static constexpr auto I8 = Number<8>{};
184  static constexpr auto I9 = Number<9>{};
185 
186  static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
187  CDEShuffleBlockTransferScalarPerVectors{}[I0];
188  // K1 should be Number<...>
189  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
190  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
191  static constexpr auto AK1Number = Number<AK1Value>{};
192  static constexpr auto BK1Number = Number<BK1Value>{};
193 
194  static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
195  static constexpr bool is_single_rate_mfma = false;
196  static constexpr auto is_scale_mfma = true;
197 
198  static constexpr index_t NumDTensor = DsDataType::Size();
199 
200  static constexpr auto MXdlPack = 2;
201  static constexpr auto NXdlPack = 2;
202  static constexpr auto KXdlPack = 2;
203 
204  //> KPack is at least the k_per_blk of selected mfma
205  //
206  // Should be a multiple of k_per_blk.
207  // TODO: Move this to blockwise pipeline base
208  // KPack in packed data types for pk A/B
209 
210  static constexpr index_t APackedSize = packed_size_v<ADataType>;
211  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
212 
213  using mfma_selector = MfmaSelector<ComputeTypeA,
214  MPerXdl,
215  NPerXdl,
216  ComputeTypeB,
218  is_scale_mfma>;
219  static constexpr index_t KPack =
221 
222  static constexpr index_t NLane = NPerXdl;
223  static constexpr index_t KLane = 64 / NLane;
224  static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
225  static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
226 
227  // static constexpr index_t NumTokens = 1;
228  static constexpr index_t SortedTileSize = MPerBlock;
229 
231  static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
232  static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
233  static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
234  "A scale pack data type too large!");
235  static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
236  "B scale pack data type too large!");
237 
238  static constexpr auto MakeDsGridPointer()
239  {
240  return generate_tuple(
241  [&](auto i) {
242  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
243 
244  return static_cast<const DDataType*>(nullptr);
245  },
247  }
248 
249  using DsGridPointer = decltype(MakeDsGridPointer());
250 
252 
253  __host__ static auto CalculateGridSize(index_t M, index_t N)
254  {
255  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
256  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
257  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
258  const index_t gridy = NSwizzle ? 1 : mblock;
259 
260  return std::make_tuple(gridx, gridy, 1);
261  }
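// Worked example (illustrative values): with MPerBlock = 128, NPerBlock = 128, M = 512 and
// N = 1024, mblock = 4 and nblock = 8. Without N-swizzling the grid is (gridx, gridy, gridz)
// = (8, 4, 1); with NSwizzle the M and N tiles are flattened into gridx, giving (32, 1, 1).
// The z dimension is left to the caller for the split-K batch (blockIdx.z in the kernels above).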
262 
263  __host__ static auto CalculateMPadded(index_t M)
264  {
265  return math::integer_least_multiple(M, MPerBlock);
266  }
267 
268  __host__ static auto CalculateNPadded(index_t N)
269  {
270  return math::integer_least_multiple(N, NPerBlock);
271  }
272 
273  __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
274  {
275  return math::integer_divide_ceil(N, NLane);
276  }
277  __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
278  {
280  }
281 
282  __host__ static auto CalculateKPadded(index_t K)
283  {
284  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
285  }
286 
287  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
288  {
289  auto K_t = K_Batch * KPerBlock;
290  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
291  }
292 
293  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
294  {
295  auto K_t = K_Batch * KPerBlock;
296  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
297  }
298 
299  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
300  {
301  auto K_t = K_Batch * KPerBlock;
302  return (K + K_t - 1) / K_t * KPerBlock;
303  }
304 
305  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
306  {
307  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
308  auto K_t = K_Batch * KReadVec;
309  return (K + K_t - 1) / K_t * KReadVec;
310  }
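// Worked example (illustrative values): with KPerBlock = 256, AK1Value = BK1Value = 16,
// K = 1000 and K_Batch = 2:
//   CalculateKPadded(1000)       = ceil(1000 / 256) * 256        = 1024
//   CalculateKPadded(1000, 2)    = ceil(1000 / 512) * 256        = 512   (per split-K slice)
//   CalculateAK0Padded(1000, 2)  = ceil(1000 / 512) * (256 / 16) = 32
//   CalculateKRead(1000, 2)      = ceil(1000 / 32)  * 16         = 512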
311 
312  __host__ static auto CalculateMBlock(index_t M)
313  {
314  return math::integer_divide_ceil(M, MPerBlock);
315  }
316 
317  __host__ static auto CalculateNBlock(index_t N)
318  {
319  return math::integer_divide_ceil(N, NPerBlock);
320  }
321 
322  template <index_t MNXdlPerWave,
323  index_t MNWaves,
324  index_t MNXdlPack,
325  index_t MNPerXdl,
326  bool IsXor,
327  typename TileDesc_K0_MN_K1>
328  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
329  {
330  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
331  constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
332  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
333 
334  if constexpr(IsXor)
335  {
336  constexpr auto permuted_desc = transform_tensor_descriptor(
337  TileDesc_K0_MN_K1{},
342 
344  permuted_desc,
345  make_tuple(
348  Number<MNWaves>{},
350  Number<MNPerXdl>{}))),
353  }
354  else
355  {
357  TileDesc_K0_MN_K1{},
358  make_tuple(
361  Number<MNWaves>{},
363  Number<MNPerXdl>{}))),
366  }
367  }
368 
369  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
370  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
371  {
372  const auto a_grid_desc_mraw_kraw = [&]() {
373  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
374  {
375  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
376  }
377  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
378  {
379  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
380  }
381  }();
382 
384 
385  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
386  GemmSpec == GemmSpecialization::MNKPadding)
387  {
388  // pad both M and K
389  const auto a_grid_desc_m_k =
390  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
392  make_right_pad_transform(K, KPad - K)),
395 
396  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
397  a_grid_desc_m_k,
402 
403  return a_grid_desc_ak0_m_ak1;
404  }
405  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
406  GemmSpec == GemmSpecialization::MNPadding)
407  {
408  // pad M, but not K
409  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
410  a_grid_desc_mraw_kraw,
412  make_right_pad_transform(M, MPad - M)),
415 
416  return a_grid_desc_ak0_m_ak1;
417  }
418  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
419  GemmSpec == GemmSpecialization::NKPadding)
420  {
421  // pad K, but not M
422  const auto a_grid_desc_m_k = transform_tensor_descriptor(
423  a_grid_desc_mraw_kraw,
427 
428  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
429  a_grid_desc_m_k,
434 
435  return a_grid_desc_ak0_m_ak1;
436  }
437  else
438  {
439  // not pad M or K
440  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
441  a_grid_desc_mraw_kraw,
442  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
446 
447  const auto a_grid_desc_permuted = transform_tensor_descriptor(
448  a_grid_desc_ak0_m_ak1,
451  make_pass_through_transform(AK1Value)),
454 
455  const auto a_grid_desc = transform_tensor_descriptor(
456  a_grid_desc_permuted,
457  make_tuple(
460  make_pass_through_transform(AK1Value)),
463 
464  return a_grid_desc;
465  }
466  }
467 
468  __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
469  {
470  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
471  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
472  constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack>{};
474  make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
475  }
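// Illustrative note (assumed numbers): the pre-shuffled B tensor is indexed as
// (N0 / NWave / NXdlPack, NWave, NXdlPack, K0, WaveSize * KPack). With NPerXdl = 32 the derived
// constants above give NLane = 32 and KLane = 64 / 32 = 2, so each 64-lane wave covers 32 N
// values and 2 K groups per MFMA, and the fastest-varying dimension keeps one wave-wide chunk of
// WaveSize * KPack elements contiguous for vectorized global loads.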
476 
477  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
478  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
479  {
480  const auto b_grid_desc_nraw_kraw = [&]() {
482  {
483  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
484  }
486  {
487  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
488  }
489  }();
490 
492 
493  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
494  GemmSpec != GemmSpecialization::Default),
495  "pk_i4_t does not support padding");
496  static_assert(!(is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t> &&
497  GemmSpec != GemmSpecialization::Default),
498  "f4x2_pk_t does not support padding");
499 
500  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
501  GemmSpec == GemmSpecialization::MNKPadding)
502  {
503  // pad both N and K
504  const auto b_grid_desc_n_k =
505  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
507  make_right_pad_transform(K, KPad - K)),
510 
511  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
512  b_grid_desc_n_k,
517 
518  return b_grid_desc_bk0_n_bk1;
519  }
520  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
521  GemmSpec == GemmSpecialization::MNPadding)
522  {
523  // pad N, but not K
524  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
525  b_grid_desc_nraw_kraw,
527  make_right_pad_transform(N, NPad - N)),
530 
531  return b_grid_desc_bk0_n_bk1;
532  }
533  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
534  GemmSpec == GemmSpecialization::MKPadding)
535  {
536  // pad K, but not N
537  const auto b_grid_desc_n_k = transform_tensor_descriptor(
538  b_grid_desc_nraw_kraw,
542 
543  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
544  b_grid_desc_n_k,
549 
550  return b_grid_desc_bk0_n_bk1;
551  }
552  else
553  {
554  // not pad N or K
555  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
556  b_grid_desc_nraw_kraw,
557  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
561 
562  const auto b_grid_desc_permuted = transform_tensor_descriptor(
563  b_grid_desc_bk0_n_bk1,
566  make_pass_through_transform(BK1Value)),
569 
570  const auto b_grid_desc = transform_tensor_descriptor(
571  b_grid_desc_permuted,
572  make_tuple(
575  make_pass_through_transform(BK1Value)),
578 
579  return b_grid_desc;
580  }
581  }
582 
583  template <typename ABlockDesc_AK0_M_AK1>
584  __host__ __device__ static constexpr auto
585  MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
586  {
587  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
588 
589  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl, true>(
590  ABlockDesc_AK0_M_AK1{});
591  }
592 
593  template <typename BBlockDesc_BK0_N_BK1>
594  __host__ __device__ static constexpr auto
595  MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
596  {
597  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
598 
599  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl, false>(
600  BBlockDesc_BK0_N_BK1{});
601  }
602 
603  template <typename ELayout>
604  __host__ __device__ static auto MakeCGridDescriptor_M_N(
605  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
606  {
607  const auto c_grid_desc_mraw_nraw = [&]() {
609  {
610  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
611  }
613  {
614  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
615  }
616  }();
617 
618  // pad M and N
619  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
621  make_right_pad_transform(N, NPad - N)),
624  }
625 
626  template <typename DLayout>
627  __host__ __device__ static auto
629  {
630  const auto c_grid_desc_mraw_nraw = [&]() {
632  {
633  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
634  }
636  {
637  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
638  }
639  }();
640 
641  // pad M and N
642  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
644  make_right_pad_transform(N, NPad - N)),
647  }
648 
649  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
650  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
651  {
652  return generate_tuple(
653  [&](auto i) {
654  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
655  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
656  },
658  }
659 
660  template <typename DsGridDesc>
662  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
663  {
664  return generate_tuple(
665  [&](auto i) {
667  ds_grid_desc_m_n[i], MBlock, NBlock);
668  },
670  }
671 
672  struct Problem
673  {
674  __host__ Problem(index_t NumTokens_,
675  index_t TopK_,
676  index_t M_,
677  index_t N_,
678  index_t K_,
679  index_t StrideA_,
680  index_t StrideScaleA_,
681  index_t StrideB_,
682  index_t StrideScaleB_,
683  std::array<index_t, NumDTensor> StrideDs_,
684  index_t StrideC_,
685  index_t KBatch_)
686  : NumTokens{NumTokens_},
687  TopK{TopK_},
688  M{M_},
689  N{N_},
690  K{K_},
691  StrideA{StrideA_},
692  StrideScaleA{StrideScaleA_},
693  StrideB{StrideB_},
694  StrideScaleB{StrideScaleB_},
695  StrideDs{StrideDs_},
696  StrideC{StrideC_},
697  KBatch{KBatch_},
700  KRead{CalculateKRead(K_, KBatch_)},
701  KPadded{CalculateKPadded(K_, KBatch_)},
702  AK0{CalculateAK0Padded(K_, KBatch_)},
703  BK0{CalculateBK0Padded(K_, KBatch_)},
704  MBlock{CalculateMBlock(M_)},
706  {
707  }
708 
709  __host__ void Print() const
710  {
711  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
712  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
713  << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
714  << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
715  << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
716  << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
717  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
718  << ", " << "NBlock: " << NBlock << "}" << std::endl;
719  }
720 
730  std::array<index_t, NumDTensor> StrideDs;
741  };
742 
743  // Argument
745  {
746  __host__ Argument(const index_t* p_sorted_token_ids_,
747  const index_t* p_sorted_expert_ids_,
748  const index_t* p_max_token_id_,
749  const ADataType* p_a_grid_,
750  const AScaleDataType* p_a_scale_grid_,
751  const BDataType* p_b_grid_,
752  const BScaleDataType* p_b_scale_grid_,
753  std::array<const void*, NumDTensor> p_ds_grid_,
754  CDataType* p_c_grid_,
755  index_t NumTokens_,
756  index_t TopK_,
757  index_t M_,
758  index_t N_,
759  index_t K_,
760  index_t StrideA_,
761  index_t StrideScaleA_,
762  index_t StrideB_,
763  index_t StrideScaleB_,
764  std::array<index_t, NumDTensor> StrideDs_,
765  index_t StrideC_,
766  index_t k_batch_,
767  AElementwiseOperation a_element_op_,
768  BElementwiseOperation b_element_op_,
769  CElementwiseOperation c_element_op_)
770  : Problem{NumTokens_,
771  TopK_,
772  M_,
773  N_,
774  K_ / APackedSize,
775  StrideA_ / APackedSize,
776  StrideScaleA_,
777  StrideB_ / BPackedSize,
778  StrideScaleB_,
779  StrideDs_,
780  StrideC_,
781  k_batch_},
782  p_sorted_token_ids{p_sorted_token_ids_},
783  p_sorted_expert_ids{p_sorted_expert_ids_},
784  p_max_token_id{p_max_token_id_},
785  p_a_grid{p_a_grid_},
786  p_a_scale_grid{p_a_scale_grid_},
787  p_b_grid{p_b_grid_},
788  p_b_scale_grid{p_b_scale_grid_},
789  p_ds_grid{},
790  p_c_grid{p_c_grid_},
791  a_element_op{a_element_op_},
792  b_element_op{b_element_op_},
793  c_element_op{c_element_op_}
794  {
795 
796  // populate pointer, desc for Ds
797  static_for<0, NumDTensor, 1>{}([&](auto i) {
798  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
799 
800  // D pointer
801  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
802  });
803  }
804 
808  const ADataType* p_a_grid;
809  const AScaleDataType* p_a_scale_grid;
810  const BDataType* p_b_grid;
811  const BScaleDataType* p_b_scale_grid;
813  CDataType* p_c_grid;
814 
815  const AElementwiseOperation a_element_op;
816  const BElementwiseOperation b_element_op;
817  const CElementwiseOperation c_element_op;
818  };
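// Hedged construction sketch (not part of the original listing; all sizes and strides are
// illustrative, and the element-wise operations are assumed to be default-constructible):
//
//   using Arg = GridwiseGemm::Argument;
//   Arg karg{p_sorted_token_ids, p_sorted_expert_ids, p_max_token_id,
//            p_a, p_a_scale, p_b, p_b_scale, {/* Ds pointers */}, p_c,
//            /*NumTokens*/ 4096, /*TopK*/ 2,
//            /*M*/ 4096, /*N*/ 7168, /*K*/ 2048,
//            /*StrideA*/ 2048, /*StrideScaleA*/ 64, /*StrideB*/ 2048, /*StrideScaleB*/ 64,
//            {/* StrideDs */}, /*StrideC*/ 7168, /*KBatch*/ 1,
//            AElementwiseOperation{}, BElementwiseOperation{}, CElementwiseOperation{}};
//
// Note that the constructor divides K, StrideA and StrideB by APackedSize / BPackedSize, so the
// caller passes raw (unpacked) element counts.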
819 
821  {
822  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
823  {
824  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
825  {
826  a_k_split_offset = k_id * karg.KRead;
827  }
828  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
829  {
830  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
831  }
832 
833  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
834  {
835  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
836  }
837  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
838  {
839  // KPack * NLane * KLane * K0 * N0
840  b_k_split_offset = k_id * karg.KRead * NPerXdl;
841  }
842 
843  // Calculate A scale offset
844  a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack *
845  MPerXdl / scale_pack_size_a;
846 
847  // Calculate B scale offset
848  b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack *
849  NPerXdl / scale_pack_size_b;
850 
851  if(k_id < karg.KBatch - 1)
852  {
853  karg.K = karg.KRead;
854  }
855  else
856  {
857  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
858  }
859  }
860 
865  };
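// Worked example (illustrative values): for row-major A with KRead = 512, NPerXdl = 32,
// ScaleBlockSize = 32, APackedSize = 2, MXdlPack = 2, MPerXdl = 16, scale_pack_size_a = 4 and
// k_id = 1:
//   a_k_split_offset       = 1 * 512                              = 512 packed A elements
//   b_k_split_offset       = 1 * 512 * 32                         = 16384 packed B elements
//   a_scale_k_split_offset = 1 * 512 / (32 / 2) * 2 * 16 / 4      = 256 A-scale elements
// The last split-K batch additionally shrinks karg.K so that it only reads the remaining K.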
866 
867  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
868  {
869  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
870  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
871 
872  // A matrix in LDS memory, dst of blockwise copy
873  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
874  {
875  // contiguous in LDS
879  }
880  // The XOR tensor transformation requires more (unnecessary) VGPR usage and can cause register
881  // spills in some cases.
883  {
884  constexpr auto a_lds_block_desc =
887 
888  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
889  a_lds_block_desc,
895 
896  return a_lds_block_desc_permuted;
897  }
898  else // ColumnMajor A
899  {
900  // The kfold and mpair dimensions are not always required.
901  // More dimensions in the merge_transform increase the difficulty of generating
902  // immediate-argument (immarg) offsets for the compiler.
903  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
904  constexpr auto M1 = MPerBlock / M0;
905 
906  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
907  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
908  constexpr auto KThreadRead = WaveSize / MPerXdl;
909  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
910 
911  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
912  ? 1
913  : 128 / (AK1Number * M0 * sizeof(ADataType));
914  constexpr auto KThreadReadPerm =
915  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
916  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
917  : KThreadRead;
918 
919  // 1<=mpair<=n0
920  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
921  ? 1
922  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
923  ? M0
924  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
925 
926  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
930  Number<kfold * M0 / mpair>{},
931  Number<mpair>{},
932  AK1Number));
933 
934  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
935  a_lds_block_desc,
936  make_tuple(
940  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
943  make_tuple(
945  make_tuple(
947 
948  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
949  a_lds_block_desc_permuted,
950  make_tuple(
958  Sequence<1>{},
959  Sequence<2>{},
960  Sequence<3>{},
961  Sequence<4>{},
962  Sequence<5>{}),
964  Sequence<2>{},
965  Sequence<0, 3>{},
966  Sequence<4, 5>{},
967  Sequence<6>{},
968  Sequence<7>{}));
969 
970  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
971  a_lds_block_desc_unmerged,
974  Number<KThreadWrite / kfold / KThreadReadPerm>{},
975  Number<kfold>{},
982 
983  return a_lds_block_desc_ak0_m_ak1;
984  }
985  }
986 
987  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
988  {
989  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
991  I1,
993  Number<KRepeat>{},
994  Number<BK1Value>{}));
995  }
996 
998  {
999  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1000 
1001  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1003  make_tuple(I1,
1005  I1,
1007 
1008  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1009  }
1010 
1013  BlkGemmPipelineVer,
1014  BlkGemmPipeSched,
1015  BlockSize,
1016  ScaleBlockSize,
1017  ADataType,
1018  AScaleDataType,
1019  BDataType,
1020  BScaleDataType,
1021  ComputeTypeA,
1022  AccDataType,
1029  ABlockTransferSrcScalarPerVector,
1030  BBlockTransferSrcScalarPerVector,
1031  MPerBlock,
1032  NPerBlock,
1033  KPerBlock,
1034  MPerXdl,
1035  NPerXdl,
1036  MXdlPerWave,
1037  NXdlPerWave,
1038  KPack,
1039  IsInputGemm>())>;
1040 
1041  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1042  {
1043  // LDS allocation for A and B: be careful of alignment
1044  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1045  // lds max alignment
1046  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1047 
1048  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1049  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1050 
1051  // LDS allocation for C shuffle in LDS
1052  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1054 
1055  constexpr auto c_block_size =
1056  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1057 
1058  return math::max(a_block_space_size_aligned * sizeof(ADataType),
1059  c_block_size * sizeof(CShuffleDataType));
1060  }
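// Worked example (illustrative values): because B is pre-shuffled and streamed through registers,
// LDS only needs to hold the larger of the A tile and the C shuffle tile. For an A tile of
// AK0 x MPerBlock x AK1 = 16 x 128 x 16 one-byte (e.g. FP8) elements (32 KiB, already a multiple
// of the lcm(AK1, BK1) alignment) and a C shuffle tile of 1 x 64 x 1 x 128 FP32 elements
// (32 KiB), the function returns max(32768, 32768) = 32768 bytes.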
1061 
1062  // The block_id to matrix tile index (m0, n0) mapping is controlled by {M01, N01}.
1063  __host__ static constexpr bool CheckValidity(const Argument& karg)
1064  {
1065  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1066  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1067  "Invalid tuning param!");
1068 
1069  static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1070  "KPerBlock should be multiple of ScaleBlockSize");
1071 
1072  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1077  {
1078  if(!(karg.M % MPerBlock == 0))
1079  {
1080  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1081  {
1082  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1083  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1084  << std::endl;
1085  }
1086  return false;
1087  }
1088  }
1089 
1090  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1095  {
1096  if(!(karg.N % NPerBlock == 0))
1097  {
1098  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1099  {
1100  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1101  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1102  << std::endl;
1103  }
1104  return false;
1105  }
1106  }
1107 
1108  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1112  {
1113  auto K_t = karg.KBatch * KPerBlock;
1114  if(!(karg.K % K_t == 0))
1115  {
1116  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1117  {
1118  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1119  << karg.K << " " << __FILE__ << ":" << __LINE__
1120  << ", in function: " << __func__ << std::endl;
1121  }
1122  return false;
1123  }
1124  }
1125  else
1126  {
1127  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1128  auto K_t = karg.KBatch * KReadVec;
1129  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1130  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1131  {
1132  return false;
1133  }
1134  }
1135 
1137  {
1138  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1139  {
1140  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1141  {
1142  std::cout << "Arg K (" << karg.K
1143  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1144  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1145  << __LINE__ << ", in function: " << __func__ << std::endl;
1146  }
1147  return false;
1148  }
1149  }
1150  else
1151  {
1152  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1153  {
1154  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1155  {
1156  std::cout << "Arg M (" << karg.M
1157  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1158  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1159  << __LINE__ << ", in function: " << __func__ << std::endl;
1160  }
1161  return false;
1162  }
1163  }
1164 
1166  {
1167  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1168  {
1169  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1170  {
1171  std::cout << "Arg N (" << karg.N
1172  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1173  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1174  << __LINE__ << ", in function: " << __func__ << std::endl;
1175  }
1176  return false;
1177  }
1178  }
1179  else
1180  {
1181  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1182  {
1183  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1184  {
1185  std::cout << "Arg K (" << karg.K
1186  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1187  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1188  << __LINE__ << ", in function: " << __func__ << std::endl;
1189  }
1190  return false;
1191  }
1192  }
1193 
1195  {
1197  {
1198  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1199  {
1200  std::cout << "Arg N (" << karg.N
1201  << ") value is not a multiple of "
1202  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1204  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1205  << std::endl;
1206  }
1207  return false;
1208  }
1209  }
1210  else
1211  {
1213  {
1214  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1215  {
1216  std::cout << "Arg M (" << karg.M
1217  << ") value is not a multiple of "
1218  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1220  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1221  << std::endl;
1222  }
1223  return false;
1224  }
1225  }
1227 
1228  // check gridwise gemm pipeline
1229 #if 0
1230  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1231 
1232  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1233  {
1234  return false;
1235  }
1236 #endif
1237  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1238  return true;
1239  }
1240 
1241  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1242  {
1243  const index_t num_loop = K / KPerBlock;
1244 
1245  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1246  }
1247 
1248  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1249  {
1250  const index_t num_loop = K / KPerBlock;
1251 
1252  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1253  }
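// Hedged host-side usage sketch (not part of the original listing): a device operation wrapping
// this gridwise GEMM would typically validate the argument and derive the loop structure before
// instantiating a kernel, e.g.:
//
//   if(!GridwiseGemm::CheckValidity(karg))
//       throw std::runtime_error("wrong! unsupported argument");
//   // For split-K the per-batch K (karg.KRead) would be used instead of karg.K here.
//   const bool has_main_loop  = GridwiseGemm::CalculateHasMainKBlockLoop(karg.K);
//   const TailNumber tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(karg.K);
//   // has_main_loop and tail_num select the HasMainKBlockLoop / TailNum template arguments of
//   // kernel_moe_mxgemm / kernel_moe_mxgemm_2lds.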
1254 
1255  template <typename CGridDesc>
1256  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1257  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1258  {
1259  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1260  c_grid_desc_m_n,
1265 
1266  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1267  }
1268 
1269  // return block_id to C matrix tile idx (m0, n0) mapping
1270  // if arch = gfx942
1271  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1272  // NPerBlock>;
1273 
1274 #if 0
1275  template <bool HasMainKBlockLoop,
1276  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1277  TailNumber TailNum = TailNumber::Odd>
1278  __device__ static void Run(const index_t* p_sorted_token_ids,
1279  const index_t* p_sorted_expert_ids,
1280  const index_t* p_max_token_id,
1281  const ADataType* p_a_grid,
1282  const AScaleDataType* p_a_scale_grid,
1283  const BDataType* p_b_grid,
1284  const BScaleDataType* p_b_scale_grid,
1285  DsGridPointer& p_ds_grid,
1286  CDataType* p_c_grid,
1287  void* p_shared,
1288  const Problem& problem,
1289  AElementwiseOperation a_element_op,
1290  BElementwiseOperation b_element_op,
1291  CElementwiseOperation c_element_op)
1292  {
1293  ignore = b_element_op;
1294  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1295  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1296  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1297  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1298  problem.MPadded,
1299  problem.K,
1300  problem.KPadded,
1301  problem.StrideA,
1302  problem.AK0);
1303  const auto b_grid_desc_bpreshuffled =
1304  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1305  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1306  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1307  problem.MPadded,
1308  problem.N,
1309  problem.NPadded,
1310  problem.StrideC);
1311 
1312  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
1313  make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock),
1314  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1315  (KXdlPack * 64 / MPerXdl),
1316  64 * KXdlPack * MXdlPack / scale_pack_size_a));
1317 
1318  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
1319  make_tuple(problem.N / (NXdlPack * NPerXdl),
1320  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1321  (KXdlPack * 64 / NPerXdl),
1322  64 * KXdlPack * NXdlPack / scale_pack_size_b));
1323 
1324  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1326  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1327  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1328  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1329  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1330  if(expert_block_id * MPerBlock >= max_token_id)
1331  return;
1332  const index_t expert_id =
1333  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1334 
1335  const auto block_mn = [&]() -> std::pair<int, int> {
1336  if constexpr(NSwizzle)
1337  {
1338  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1339  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1340  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1341  const index_t expert_swizzle =
1342  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1343  const index_t bid_new = blockIdx.x - prefix_block;
1344  const index_t nid = __builtin_amdgcn_readfirstlane(
1345  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1346  const index_t mid =
1347  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1348  return {nid, mid};
1349  }
1350  else
1351  {
1352  return {blockIdx.x, blockIdx.y};
1353  }
1354  }();
1355 
1356  const index_t block_n_id = block_mn.first;
1357  const index_t block_m_id = block_mn.second;
1358  const index_t token0 =
1359  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1360 
1361  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1362  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1363  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1364  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1365  constexpr auto AKThreads = AK0Threads * AK1Threads;
1366  constexpr auto AMRepeats = MPerBlock / AMThreads;
1367  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1368 
1369  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1370  return;
1371  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
1372  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1373  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1374  index_t token_offset = fused_token & 0xffffff;
1375  if constexpr(!IsInputGemm)
1376  {
1377  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1378  }
1379  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K / APackedSize;
1380  });
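// Illustrative note (not part of the original listing): each entry of p_sorted_token_ids packs
// the token index in the low 24 bits and the top-k slot in the high 8 bits, e.g.
//   fused_token = 0x02000123  ->  token index 0x123, top-k slot 2.
// For the second (down) GEMM the gathered row becomes token * TopK + slot, while the first
// (gate-up) GEMM gathers directly by token index; the gather offset is that row index times the
// packed K extent of A (problem.K / APackedSize).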
1381  const index_t expert_stride =
1382  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1383  const index_t expert_scale_stride =
1384  __builtin_amdgcn_readfirstlane(problem.N * (IsInputGemm ? 2 : 1) *
1385  math::integer_divide_ceil(problem.K, ScaleBlockSize));
1386 
1387  // N0, K0, Blocksize*KPack
1388  const index_t n_block_data_idx_on_grid =
1389  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1390 
1391  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1392  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1393  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1394  p_b_grid + expert_id * expert_stride / BPackedSize,
1395  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1396 
1397  // A, B scale buffer
1398  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1399  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1400  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1401  p_b_scale_grid + expert_id * expert_scale_stride,
1402  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1403 
1404  // A matrix in LDS memory, dst of blockwise copy
1405  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1406 
1407  // B matrix in LDS memory, dst of blockwise copy
1408  // dummy
1409  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1410  // A matrix blockwise copy
1411  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1413  AElementwiseOperation,
1416  Sequence<AK0Number, MPerBlock, AK1Number>,
1417  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1418  ABlockTransferThreadClusterArrangeOrder,
1419  ADataType,
1420  LDSTypeA,
1421  decltype(a_grid_desc_ak0_m_ak1),
1422  decltype(a_block_desc_ak0_m_ak1),
1423  ABlockTransferSrcAccessOrder,
1424  Sequence<0, 1, 2>,
1425  ABlockTransferSrcVectorDim,
1426  2,
1427  ABlockTransferSrcScalarPerVector,
1428  ABlockTransferDstScalarPerVector_AK1,
1429  1,
1430  1,
1431  AThreadTransferSrcResetCoordinateAfterRun,
1432  true,
1433  IndexType,
1434  1,
1435  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1436  make_multi_index(0, 0, 0),
1437  a_element_op,
1438  a_block_desc_ak0_m_ak1,
1439  make_multi_index(0, 0, 0),
1441  gather_offsets);
1442 
1443  // Thread-wise copy
1444  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1445  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1446  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1447 
1448  auto b_blockwise_copy =
1449  ThreadwiseTensorSliceTransfer_v2<BDataType,
1450  BDataType,
1451  decltype(b_grid_desc_bpreshuffled),
1452  decltype(b_block_desc_bk0_n_bk1),
1453  Sequence<Number<NXdlPerWave / NXdlPack>{},
1454  I1,
1455  Number<NXdlPack>{},
1456  Number<KRepeat>{},
1457  Number<BK1Value>{}>,
1458  Sequence<1, 2, 0, 3>,
1459  4,
1460  BBlockTransferSrcScalarPerVector,
1461  BThreadTransferSrcResetCoordinateAfterRun,
1462  true>(
1463  b_grid_desc_bpreshuffled,
1464  make_multi_index(n_block_data_idx_on_grid,
1466  0,
1467  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1468 
1469  // LDS allocation for A and B: be careful of alignment
1470  // Cast after lds
1471  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1472  static_cast<LDSTypeA*>(p_shared),
1473  a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize);
1474 
1475  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1476  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1477 
1478  // Blockwise GEMM pipeline
1479  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1480  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1481  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1482  decltype(c_thread_buf) c_thread_buf_up;
1483 
1484  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
1485  float,
1486  c_thread_buf.num_of_v_,
1487  c_thread_buf.s_per_v,
1488  true>
1489  c_thread_buf_fp32;
1490 
1491  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1492  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1493  KPerBlock);
1494 
1495  // a and b scale processing
1496  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1497  const auto waveId_m = wave_idx[I0];
1498  const auto waveId_n = wave_idx[I1];
1499 
1500  static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
1501 
1502  auto thread_offset_shuffled =
1503  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1504 
1505  auto a_thread_offset_m = waveId_m;
1506 
1507  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1508  AScaleDataType,
1509  AScaleDataType,
1510  decltype(a_scale_grid_desc_am_ak),
1511  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1512  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1513  Sequence<0, 1, 2>, // DimAccessOrder
1514  2, // SrcVectorDim
1515  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1516  1, // SrcScalarStrideInVector
1517  true>(a_scale_grid_desc_am_ak,
1518  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1519  0,
1520  thread_offset_shuffled / scale_pack_size_a));
1521 
1522  // B scale load
1523  auto b_thread_offset_n = waveId_n;
1524 
1525  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1526  BScaleDataType,
1527  BScaleDataType,
1528  decltype(b_scale_grid_desc_bn_ak),
1529  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1530  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1531  Sequence<0, 1, 2>, // DimAccessOrder
1532  2, // SrcVectorDim
1533  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1534  1, // SrcScalarStrideInVector
1535  true>(b_scale_grid_desc_bn_ak,
1536  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1537  0,
1538  thread_offset_shuffled / scale_pack_size_b));
1539 
1540  if constexpr(IsInputGemm)
1541  {
1542  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1543  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1544  p_b_grid_up + expert_id * expert_stride / BPackedSize,
1545  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1546  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1547  BDataType,
1548  BDataType,
1549  decltype(b_grid_desc_bpreshuffled),
1550  decltype(b_block_desc_bk0_n_bk1),
1551  Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
1552  Sequence<1, 2, 0, 3>,
1553  3,
1554  BBlockTransferSrcScalarPerVector,
1555  BThreadTransferSrcResetCoordinateAfterRun,
1556  true>(b_grid_desc_bpreshuffled,
1557  make_multi_index(n_block_data_idx_on_grid,
1559  0,
1560  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1561  const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
1562  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1563  p_b_scale_grid_up + expert_id * expert_scale_stride,
1564  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1565  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1566  BScaleDataType,
1567  BScaleDataType,
1568  decltype(b_scale_grid_desc_bn_ak),
1569  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1570  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1571  Sequence<0, 1, 2>, // DimAccessOrder
1572  2, // SrcVectorDim
1573  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1574  1, // SrcScalarStrideInVector
1575  true>(
1576  b_scale_grid_desc_bn_ak,
1577  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1578  0,
1579  thread_offset_shuffled / scale_pack_size_b));
1580 
1581  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1582  a_grid_desc_ak0_m_ak1,
1583  a_block_desc_ak0_m_ak1,
1584  a_blockwise_copy,
1585  a_grid_buf,
1586  a_block_buf,
1587  a_block_slice_copy_step,
1588  b_grid_desc_bpreshuffled,
1589  b_block_desc_bk0_n_bk1,
1590  b_blockwise_copy,
1591  b_blockwise_copy_up,
1592  b_grid_buf,
1593  b_grid_buf_up,
1594  b_block_buf,
1595  b_block_slice_copy_step,
1596  c_thread_buf,
1597  c_thread_buf_up,
1598  a_scale_grid_desc_am_ak,
1599  a_scale_thread_copy,
1600  a_scale_grid_buf,
1601  b_scale_grid_desc_bn_ak,
1602  b_scale_thread_copy,
1603  b_scale_thread_copy_up,
1604  b_scale_grid_buf,
1605  b_scale_grid_buf_up,
1606  num_k_block_main_loop);
1607  }
1608  else
1609  {
1610  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1611  a_grid_desc_ak0_m_ak1,
1612  a_block_desc_ak0_m_ak1,
1613  a_blockwise_copy,
1614  a_grid_buf,
1615  a_block_buf,
1616  a_block_slice_copy_step,
1617  b_grid_desc_bpreshuffled,
1618  b_block_desc_bk0_n_bk1,
1619  b_blockwise_copy,
1620  b_grid_buf,
1621  b_block_buf,
1622  b_block_slice_copy_step,
1623  c_thread_buf,
1624  a_scale_grid_desc_am_ak,
1625  a_scale_thread_copy,
1626  a_scale_grid_buf,
1627  b_scale_grid_desc_bn_ak,
1628  b_scale_thread_copy,
1629  b_scale_grid_buf,
1630  num_k_block_main_loop);
1631  }
1632 
1633  // shuffle C and write out
1634  {
1635  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1636  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1637  "wrong!");
1638 
1639  // TODO: hacky, fix it!
1640  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1641  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1642 
1643  // TODO: hacky, fix it!
1644  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1645  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1646  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1647 
1648  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1649  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1650  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1651  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1652  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1653  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1654  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1655  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1656 
1657  // mul scales
1658  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
1659  static_assert(M4 == 4);
1660  const index_t m1 = get_warp_local_1d_id() / NWave;
1661  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
1662 
1663  vector_type<float, 4> topk_weights; // for gemm2 only
1664  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1665  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1666  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
1667  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1668  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1669  if constexpr(MulRoutedWeight)
1670  {
1671  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
1672  p_ds_grid[I2] + m_pos);
1673  }
1674  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
1675  constexpr index_t c_offset =
1676  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1677  make_tuple(m0, n0, m2 * M4 + m4));
1678  constexpr auto cidx = Number<c_offset>{};
1679 
1680  if constexpr(IsInputGemm) // gate & up (gu) fusion
1681  {
1682  if constexpr(ActivationOperation == Activation::silu_and_mul)
1683  {
1684  float gate = c_thread_buf[cidx];
1685  float up = c_thread_buf_up[cidx];
1686  if constexpr(MulRoutedWeight)
1687  {
1688  gate = gate * topk_weights.AsType<float>()[m4];
1689  up = up * topk_weights.AsType<float>()[m4];
1690  }
1691  tensor_operation::element_wise::Silu{}(gate, gate);
1692  c_thread_buf_fp32(cidx) = gate * up;
1693  }
1694  else if(ActivationOperation == Activation::gelu_and_mul)
1695  {
1696  float gate = c_thread_buf[cidx];
1697  float up = c_thread_buf_up[cidx];
1698  if constexpr(MulRoutedWeight)
1699  {
1700  gate = gate * topk_weights.AsType<float>()[m4];
1701  up = up * topk_weights.AsType<float>()[m4];
1702  }
1703  tensor_operation::element_wise::Gelu{}(gate, gate);
1704  c_thread_buf_fp32(cidx) = gate * up;
1705  }
1706  }
1707  else
1708  {
1709  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1710  if constexpr(MulRoutedWeight)
1711  {
1712  c_thread_buf_fp32(cidx) =
1713  topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
1714  }
1715  }
1716  });
1717  });
1718  });
1719  });
1720 
1721  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1723 
1724  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1725  static_cast<CShuffleDataType*>(p_shared),
1726  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1727 
1728  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1729  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1730  make_tuple(
1733  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1734  M1, // M1 = MWave
1735  M2, // M2 * M3 * M4 = MPerXdl
1736  M3,
1737  M4)),
1740  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1741  N1, // N1 = NWave
1742  N2))), // N2 = NPerXdl
1743  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
1744  make_tuple(
1745  Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
1746 
1747  // calculate origin of thread output tensor on global memory
1748  // blockwise GEMM c matrix starting index
1749  const auto c_thread_mtx_on_block =
1750  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1751 
1752  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1753  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1754 
1755  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1757  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1758  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
1759  make_tuple(Sequence<0>{}));
1760 
1761  const auto m_thread_data_on_block_idx =
1762  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1763  make_multi_index(m_thread_data_on_block));
1764 
1765  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1768  make_tuple(Sequence<0, 1, 2>{}),
1769  make_tuple(Sequence<0>{}));
1770 
1771  const auto n_thread_data_on_block_idx =
1772  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1773  make_multi_index(n_thread_data_on_block));
1774 
1775  // shuffle: threadwise copy C from VGPR to LDS
1776  auto c_thread_copy_vgpr_to_lds =
1777  ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
1778  CShuffleDataType,
1779  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1780  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1782  Sequence<CShuffleMXdlPerWavePerShuffle,
1783  CShuffleNXdlPerWavePerShuffle,
1784  I1,
1785  I1,
1786  M2,
1787  I1,
1788  M4,
1789  I1>,
1790  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
1791  7,
1792  1,
1794  1,
1795  true>{
1796  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1797  make_multi_index(0,
1798  0,
1799  m_thread_data_on_block_idx[I1],
1800  n_thread_data_on_block_idx[I1],
1801  m_thread_data_on_block_idx[I2],
1802  m_thread_data_on_block_idx[I3],
1803  m_thread_data_on_block_idx[I4],
1804  n_thread_data_on_block_idx[I2]),
1806 
1807  using EDataType = CDataType;
1808 
1809  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1810  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1811 
1812  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1814  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1815 
1816  const auto ds_grid_buf = generate_tuple(
1817  [&](auto i) {
1818  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1819  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1820  },
1821  Number<NumDTensor>{});
1822 
1823  // tuple of reference to C/Ds tensor descriptors
1824  const auto c_ds_desc_refs = concat_tuple_of_reference(
1825  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1826  generate_tie([&](auto i) -> const auto& // return type should be reference
1827  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1828  Number<NumDTensor>{}));
1829 
1830  // tuple of reference to C/Ds tensor descriptors
1831  const auto c_ds_buf_refs = concat_tuple_of_reference(
1832  tie(c_shuffle_block_buf),
1833  generate_tie([&](auto i) -> const auto& // return type should be reference
1834  { return ds_grid_buf[i]; },
1835  Number<NumDTensor>{}));
1836 
1837  // tuple of starting index of C/Ds blockwise copy
1838  const auto idx_c_ds_block_begin =
1841  [&](auto) {
1842  return make_multi_index(block_m_id, 0, block_n_id, 0);
1843  // return make_multi_index(block_work_idx[I0], 0,
1844  // block_work_idx[I1], 0);
1845  },
1846  Number<NumDTensor>{}));
1847 
1848  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1849  c_grid_desc_mblock_mperblock_nblock_nperblock;
1850 
1851  using CDEBlockTransferCluster =
1852  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1853  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1854  constexpr index_t scatter_weight_idx = 1; // FIXME: hard-coded scatter weight index
1855  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1856  ThisThreadBlock,
1857  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1858  Tuple<EDataType>,
1859  decltype(c_ds_desc_refs),
1860  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1861  CElementwiseOperation,
1862  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1863  // support arbitrary type
1864  Sequence<1,
1865  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1866  1,
1867  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1868  CDEBlockTransferCluster,
1869  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1870  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1871  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1872  3, // index_t SrcVectorDim,
1873  3, // index_t DstVectorDim,
1874  CDEShuffleBlockTransferScalarPerVectors,
1877  Sequence<true>,
1879  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1880  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1881  IndexType,
1882  1, // ScatterDim
1883  true, // OutputScatter: false, only use scatter weights
1884  scatter_weight_idx // ScatterWeightIdx: ascale
1885  >{c_ds_desc_refs,
1886  idx_c_ds_block_begin,
1887  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1888  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1889  c_element_op};
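// The scatter transfer above writes the shuffled C tile (plus any D tensors) from LDS to
// global memory. With OutputScatter enabled and ScatterDim = 1, the row (MPerBlock)
// coordinate of every store is redirected through the per-row scatter_offsets computed in
// the loop below, so each row lands at the output row of its own token rather than at the
// block's contiguous position; only the N coordinate comes from block_n_id.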
1890 
1891  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1892  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1893  constexpr auto sfc_c_vgpr =
1894  SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
1895  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
1896  Sequence<CShuffleMXdlPerWavePerShuffle,
1897  CShuffleNXdlPerWavePerShuffle,
1898  1,
1899  1,
1900  M2,
1901  1,
1902  M4,
1903  1>>{};
1904 
1905  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1906 
1907  // space filling curve for shuffled blockwise C/D/E
1908  constexpr auto sfc_cde_block =
1909  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
1910  Sequence<0, 2, 1, 3>,
1911  Sequence<1,
1912  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1913  1,
1914  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1915 
1916  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1917  constexpr auto EMThreads =
1918  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1919  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1920  constexpr auto ENThreads =
1921  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
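// EMThreads is how many threads of the CDE transfer cluster span the M extent of one
// shuffle tile, and EMRepeats is how many consecutive output rows each thread owns when
// filling scatter_offsets below. Worked example with hypothetical parameters: a
// 1x32x1x8 cluster, CShuffleMXdlPerWavePerShuffle = 1, MWave = 2, MPerXdl = 16 gives
// EMThreads = 1 * 32 = 32 and EMRepeats = 1 * 2 * 16 / 32 = 1, i.e. one scatter offset
// per thread per access step.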
1922  static_for<0, num_access, 1>{}([&](auto access_id) {
1923  // make sure it's safe to write to LDS
1924  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
1925 
1926  auto dstidx = sfc_cde_block.GetIndex(access_id);
1927  const index_t c_token_pos =
1928  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1929  static_for<0, EMRepeats, 1>{}([&](auto m0) {
1930  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1931  IndexType token_offset = fused_token & 0xffffff;
1932  if constexpr(IsInputGemm)
1933  {
1934  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1935  }
1936  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
1937  });
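// Each entry of p_sorted_token_ids packs the token index in its low 24 bits and the
// token's top-k slot in the high 8 bits. For the input GEMM the output row is expanded
// to token * TopK + slot so every (token, expert) pair owns a row, and the scatter
// offset is that row times problem.N elements. Hypothetical example: with TopK = 8,
// fused_token = 0x02000005 decodes to token 5, slot 2, so the row is 5 * 8 + 2 = 42 and
// the scatter offset is 42 * problem.N.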
1938 
1939  block_sync_lds();
1940 
1941  // each thread write its data from VGPR to LDS
1942  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1943  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1944  c_thread_buf_fp32,
1945  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1946  c_shuffle_block_buf);
1947 
1948  // make sure it's safe to read from LDS
1949  block_sync_lds();
1950 
1951  // each block copy its data from LDS to global
1952  cde_block_copy_lds_and_global.Run(
1953  c_ds_desc_refs,
1954  c_ds_buf_refs,
1955  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1956  tie(c_grid_buf),
1957  scatter_offsets);
1958 
1959  if constexpr(access_id < num_access - 1)
1960  {
1961  constexpr auto cde_lds_and_global_step =
1962  sfc_cde_block.GetForwardStep(access_id);
1963 
1964  // move on Ds
1965  static_for<0, NumDTensor, 1>{}([&](auto i) {
1966  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1967  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1968  });
1969 
1970  // move on E
1971  cde_block_copy_lds_and_global.MoveDstSliceWindow(
1972  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1973  I0,
1974  cde_lds_and_global_step);
1975  }
1976  });
1977  }
1978  }
1979 #endif
1980 
1981  template <bool HasMainKBlockLoop,
1982  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1983  TailNumber TailNum = TailNumber::Odd>
1984  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1985  const index_t* p_sorted_expert_ids,
1986  const index_t* p_max_token_id,
1987  const ADataType* p_a_grid,
1988  const AScaleDataType* p_a_scale_grid,
1989  const BDataType* p_b_grid,
1990  const BScaleDataType* p_b_scale_grid,
1991  DsGridPointer& p_ds_grid,
1992  CDataType* p_c_grid,
1993  void* p_shared_0,
1994  void* p_shared_1,
1995  const Problem& problem,
1996  AElementwiseOperation a_element_op,
1997  BElementwiseOperation b_element_op,
1998  CElementwiseOperation c_element_op)
1999  {
2000  ignore = a_element_op;
2001  ignore = b_element_op;
2002  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
2003  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
2004  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2005  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2006  problem.MPadded,
2007  problem.K,
2008  problem.KPadded,
2009  problem.StrideA,
2010  problem.AK0);
2011  const auto b_grid_desc_bpreshuffled =
2012  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
2013  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2014  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2015  problem.MPadded,
2016  problem.N,
2017  problem.NPadded,
2018  problem.StrideC);
2019 
2020  // We pad M unconditionally for the scale tensor
2021  const auto Padded_Scale_M =
2022  math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
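// Example of the padding arithmetic (hypothetical sizes): M = 1000 and ScaleBlockSize = 32
// give integer_divide_ceil(1000, 32) = 32, hence Padded_Scale_M = 32 * 32 = 1024.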
2023  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
2024  make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
2025  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2026  (KXdlPack * 64 / MPerXdl),
2028  make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2029  (ScaleBlockSize / APackedSize)) *
2030  MPerXdl * MXdlPack / scale_pack_size_a,
2032  1));
2033 
2034  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
2035  make_tuple(problem.N / (NXdlPack * NPerXdl),
2036  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2037  (KXdlPack * 64 / NPerXdl),
2039  make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2040  (ScaleBlockSize / BPackedSize)) *
2041  NPerXdl * NXdlPack / scale_pack_size_b,
2043  1));
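// Layout assumed by the two scale descriptors above: one e8m0 scale covers
// ScaleBlockSize / APackedSize (resp. BPackedSize) stored elements along K, so
// integer_divide_ceil(K, ScaleBlockSize / PackedSize) is the number of scale blocks per
// row/column. Dimension 0 steps over groups of MXdlPack * MPerXdl rows (NXdlPack * NPerXdl
// columns for B), dimension 1 steps over the K scale blocks consumed by KXdlPack XDL
// steps of a 64-lane wave, and the innermost dimension holds that wave's scales for one
// such step, KXdlPack * MXdlPack (resp. NXdlPack) per lane, packed scale_pack_size_a/_b
// values per stored element (hence the divisions by scale_pack_size below).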
2044 
2045  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2046  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2047  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2048 
2049  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2050  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2051  if(expert_block_id * MPerBlock >= max_token_id)
2052  return;
2053  const index_t expert_id =
2054  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2055  const auto block_mn = [&]() -> std::pair<int, int> {
2056  if constexpr(NSwizzle)
2057  {
2058  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2059  const index_t prefix_block = ecnt_prefix * problem.NBlock;
2060  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2061  const index_t expert_swizzle =
2062  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2063  const index_t bid_new = blockIdx.x - prefix_block;
2064  const index_t nid = __builtin_amdgcn_readfirstlane(
2065  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2066  const index_t mid =
2067  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2068  return {nid, mid};
2069  }
2070  else
2071  {
2072  return {blockIdx.x, blockIdx.y};
2073  }
2074  }();
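// With NSwizzle enabled, the linear block id within this expert's slice of the grid is
// remapped so that consecutive blocks first sweep a panel of 8 N tiles across all of the
// expert's M tiles (expert_swizzle is apparently the expert's M-tile count) before moving
// to the next panel, presumably to improve cache reuse of the B tiles. Hypothetical
// example with expert_swizzle = 2: bid_new 0..7 -> (nid 0..7, mid ecnt_prefix),
// bid_new 8..15 -> (nid 0..7, mid ecnt_prefix + 1), and bid_new 16 starts the next panel
// at nid = 8.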
2075 
2076  const index_t block_n_id = block_mn.first;
2077  const index_t block_m_id = block_mn.second;
2078  const index_t token0 =
2079  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2080 
2081  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2082  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2083  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2084  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2085  constexpr auto AKThreads = AK0Threads * AK1Threads;
2086  constexpr auto AMRepeats = MPerBlock / AMThreads;
2087  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
2088 
2089  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2090  return;
2091  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
2092  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2093  const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
2094  index_t token_offset = fused_token & 0xffffff;
2095  if constexpr(!IsInputGemm)
2096  {
2097  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2098  }
2099  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2100  });
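// gather_offsets selects which rows of the token-major A tensor this block loads: the low
// 24 bits of each sorted id give the token, and for the second (down) GEMM the row index
// is expanded to token * TopK + slot, matching the per-(token, expert) rows produced by
// the first GEMM's scattered output. Offsets are in elements of A, hence the factor
// problem.K.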
2101 
2102  const index_t expert_stride =
2103  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2104  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2105  problem.N * (IsInputGemm ? 2 : 1) *
2106  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
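// Per-expert strides: for the input GEMM the gate and up projection weights sit back to
// back, so each expert occupies 2 * N * K elements of B (and twice the scale blocks);
// otherwise one N x K weight matrix (and its scales) per expert.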
2107 
2108  // N0, K0, Blocksize*KPack
2109  const index_t n_block_data_idx_on_grid =
2110  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);
2111 
2112  // Grid buffer creation
2113  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2114  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2115  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2116  p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2117 
2118  // A, B scale buffer
2119  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2120  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2121  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2122  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2123  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2124 
2125  // A matrix in LDS memory, dst of blockwise copy
2126  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2127 
2128  // B matrix in LDS memory, dst of blockwise copy
2129  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2130 
2131  // A matrix blockwise direct to LDS copy
2132  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
2135  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2136  ABlockTransferThreadClusterArrangeOrder,
2137  ADataType,
2138  ADataType,
2139  decltype(a_grid_desc_ak0_m_ak1),
2140  decltype(a_block_desc_ak0_m_ak1),
2141  ABlockTransferSrcAccessOrder,
2142  ABlockTransferSrcVectorDim,
2143  2,
2144  ABlockTransferSrcScalarPerVector,
2145  IndexType,
2146  1>(a_grid_desc_ak0_m_ak1,
2147  make_multi_index(0, 0, 0),
2148  a_block_desc_ak0_m_ak1,
2149  make_multi_index(0, 0, 0),
2150  gather_offsets);
2151 
2152  // Thread-wise copy
2153  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2154  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2155  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2156  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2157  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2158  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2159 
2160  auto b_blockwise_copy =
2162  BDataType,
2163  decltype(b_grid_desc_bpreshuffled),
2164  decltype(b_block_desc_bk0_n_bk1),
2165  Sequence<Number<NXdlPerWave / NXdlPack>{},
2166  I1,
2167  Number<NXdlPack>{},
2168  Number<KRepeat>{},
2169  Number<BK1Value>{}>,
2171  4,
2172  BBlockTransferSrcScalarPerVector,
2173  BThreadTransferSrcResetCoordinateAfterRun,
2174  true>(
2175  b_grid_desc_bpreshuffled,
2176  make_multi_index(n_block_data_idx_on_grid,
2178  0,
2179  0,
2180  KPack * (get_thread_local_1d_id() % WarpSize)));
2181 
2182  // LDS allocation for A and B: be careful of alignment
2183  // Cast after lds
2184  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2185  static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2186  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2187  static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2188  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
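// The 2-LDS variant double-buffers the A tile: p_shared_0 and p_shared_1 back two
// independent LDS buffers so the pipeline can prefetch the next K tile of A into one
// buffer while the MFMAs consume the other. B does not go through LDS here; the
// preshuffled B tiles stream directly into the ping/pong VGPR buffers created above.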
2189 
2190  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2191  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
2192 
2193  // Blockwise GEMM pipeline
2194  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2195  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2196  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2197  decltype(c_thread_buf) c_thread_buf_up;
2198 
2200  float,
2201  c_thread_buf.num_of_v_,
2202  c_thread_buf.s_per_v,
2203  true>
2204  c_thread_buf_fp32;
2205 
2206  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2207  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2208  KPerBlock);
2209 
2210  // a and b scale processing
2211  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2212  const auto waveId_m = wave_idx[I0];
2213  const auto waveId_n = wave_idx[I1];
2214 
2215  auto thread_offset_shuffled =
2216  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2217 
2218  auto a_thread_offset_m = waveId_m;
2219 
2220  // get each thread's offset into the scale tensor
2221  const index_t token_scale_pos = block_m_id * MPerBlock;
2222  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2223  return;
2224 
2225  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2226  AScaleDataType,
2227  AScaleDataType,
2228  decltype(a_scale_grid_desc_am_ak),
2229  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2230  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2231  Sequence<0, 1, 2>, // DimAccessOrder
2232  2, // SrcVectorDim
2233  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2234  1, // SrcScalarStrideInVector
2235  true>(a_scale_grid_desc_am_ak,
2236  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2237  0,
2238  thread_offset_shuffled / scale_pack_size_a));
2239 
2240  // B scale load
2241  auto b_thread_offset_n = waveId_n;
2242 
2243  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2244  BScaleDataType,
2245  BScaleDataType,
2246  decltype(b_scale_grid_desc_bn_ak),
2247  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2248  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2249  Sequence<0, 1, 2>, // DimAccessOrder
2250  2, // SrcVectorDim
2251  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2252  1, // SrcScalarStrideInVector
2253  true>(b_scale_grid_desc_bn_ak,
2254  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2255  0,
2256  thread_offset_shuffled / scale_pack_size_b));
2257 
2258  if constexpr(IsInputGemm)
2259  {
2260  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
2261  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2262  p_b_grid_up + expert_id * expert_stride,
2263  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2264  auto b_blockwise_copy_up =
2266  BDataType,
2267  decltype(b_grid_desc_bpreshuffled),
2268  decltype(b_block_desc_bk0_n_bk1),
2269  Sequence<Number<NXdlPerWave / NXdlPack>{},
2270  I1,
2271  Number<NXdlPack>{},
2272  Number<KRepeat>{},
2273  Number<BK1Value>{}>,
2275  4,
2276  BBlockTransferSrcScalarPerVector,
2277  BThreadTransferSrcResetCoordinateAfterRun,
2278  true>(
2279  b_grid_desc_bpreshuffled,
2280  make_multi_index(n_block_data_idx_on_grid,
2282  0,
2283  0,
2284  KPack * (get_thread_local_1d_id() % WarpSize)));
2285  const BScaleDataType* p_b_scale_grid_up =
2286  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
2287  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2288  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
2289  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2290 
2291  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2292  BScaleDataType,
2293  BScaleDataType,
2294  decltype(b_scale_grid_desc_bn_ak),
2295  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2296  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2297  Sequence<0, 1, 2>, // DimAccessOrder
2298  2, // SrcVectorDim
2299  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
2300  1, // SrcScalarStrideInVector
2301  true>(
2302  b_scale_grid_desc_bn_ak,
2303  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2304  0,
2305  thread_offset_shuffled / scale_pack_size_b));
2306 
2307  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2308  // A
2309  a_grid_desc_ak0_m_ak1,
2310  a_block_desc_ak0_m_ak1,
2311  a_blockwise_copy,
2312  a_grid_buf,
2313  a_block_bufs,
2314  a_block_slice_copy_step,
2315  // Gate and Up
2316  b_grid_desc_bpreshuffled,
2317  b_block_desc_bk0_n_bk1,
2318  b_blockwise_copy,
2319  b_blockwise_copy_up,
2320  b_grid_buf,
2321  b_grid_buf_up,
2322  b_block_bufs,
2323  b_block_slice_copy_step,
2324  // C
2325  c_thread_buf,
2326  c_thread_buf_up,
2327  // A scale
2328  a_scale_grid_desc_am_ak,
2329  a_scale_thread_copy,
2330  a_scale_grid_buf,
2331  // B scale
2332  b_scale_grid_desc_bn_ak,
2333  b_scale_thread_copy,
2334  b_scale_thread_copy_up,
2335  b_scale_grid_buf,
2336  b_scale_grid_buf_up,
2337  num_k_block_main_loop);
2338  }
2339  else
2340  {
2341  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2342  a_grid_desc_ak0_m_ak1, // A
2343  a_block_desc_ak0_m_ak1,
2344  a_blockwise_copy,
2345  a_grid_buf,
2346  a_block_bufs,
2347  a_block_slice_copy_step,
2348  b_grid_desc_bpreshuffled, // B
2349  b_block_desc_bk0_n_bk1,
2350  b_blockwise_copy,
2351  b_grid_buf,
2352  b_block_bufs,
2353  b_block_slice_copy_step,
2354  c_thread_buf, // C
2355  a_scale_grid_desc_am_ak, // A scale
2356  a_scale_thread_copy,
2357  a_scale_grid_buf,
2358  b_scale_grid_desc_bn_ak, // B scale
2359  b_scale_thread_copy,
2360  b_scale_grid_buf,
2361  num_k_block_main_loop);
2362  }
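// Summary of the two paths above: when IsInputGemm is set, one pass of the pipeline feeds
// two preshuffled B streams (gate and up weights, the latter offset by expert_stride / 2)
// with the same gathered A tile and A scales, accumulating into c_thread_buf and
// c_thread_buf_up; the down-projection GEMM takes the ordinary single-B pipeline and
// fills c_thread_buf only.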
2363 
2364  // shuffle C and write out
2365  {
2366  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2367  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2368  "wrong!");
2369  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
2370  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
2371  "wrong!");
2372 
2373  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2374 
2375  // TODO: hacky, fix it!
2376  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2377  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2378 
2379  // TODO: hacky, fix it!
2380  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2381  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2382  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2383 
2384  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2385  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2386  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2387  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2388  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2389  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2390  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2391  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2392  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
2393  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
2394 
2395  // mul scales
2396 
2397  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
2398  static_assert(M5 == 4);
2399  const index_t m1 = get_warp_local_1d_id() / NWave;
2400  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
2401 
2402  vector_type<float, 4> topk_weights; // for gemm2 only
2403  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
2404  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
2405  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
2406  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
2407  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
2408  const index_t m_pos = block_m_id * MPerBlock +
2409  m0 * M2 * M1 * M3 * M4 * M5 +
2410  m1 * M2 * M3 * M4 * M5 +
2411  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
2412  if constexpr(MulRoutedWeight)
2413  {
2414  topk_weights =
2415  *c_style_pointer_cast<const vector_type<float, M5>*>(
2416  p_ds_grid[I2] + m_pos);
2417  }
2418  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
2419  constexpr index_t c_offset =
2420  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2421  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
2422  constexpr auto cidx = Number<c_offset>{};
2423 
2424  if constexpr(IsInputGemm) // gate+up fusion
2425  {
2426  if constexpr(ActivationOperation ==
2427  Activation::silu_and_mul)
2428  {
2429  float gate = c_thread_buf[cidx];
2430  float up = c_thread_buf_up[cidx];
2431  if constexpr(MulRoutedWeight)
2432  {
2433  gate = gate * topk_weights.AsType<float>()[m5];
2434  up = up * topk_weights.AsType<float>()[m5];
2435  }
2437  c_thread_buf_fp32(cidx) = gate * up;
2438  }
2439  else if(ActivationOperation == Activation::gelu_and_mul)
2440  {
2441  float gate = c_thread_buf[cidx];
2442  float up = c_thread_buf_up[cidx];
2443  if constexpr(MulRoutedWeight)
2444  {
2445  gate = gate * topk_weights.AsType<float>()[m5];
2446  up = up * topk_weights.AsType<float>()[m5];
2447  }
2449  c_thread_buf_fp32(cidx) = gate * up;
2450  }
2451  }
2452  else
2453  {
2454  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2455  if constexpr(MulRoutedWeight)
2456  {
2457  c_thread_buf_fp32(cidx) =
2458  topk_weights.AsType<float>()[m5] *
2459  c_thread_buf_fp32[cidx];
2460  }
2461  }
2462  });
2463  });
2464  });
2465  });
2466  });
2467  });
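// What the fused epilogue above computes per accumulator element: for the gate+up GEMM
// the gate value is passed through the selected activation (SiLU, i.e. x * sigmoid(x), or
// GELU) and multiplied by the up value, with both factors optionally pre-scaled by the
// token's routed top-k weight read from p_ds_grid[I2] when MulRoutedWeight is set; for
// the down GEMM the accumulator is copied and, if requested, scaled by that weight. The
// result is staged in c_thread_buf_fp32 for the C-shuffle store below.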
2468 
2469  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2470  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2471 
2472  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2473  static_cast<CShuffleDataType*>(p_shared_0),
2474  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2475 
2476  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2477  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2478  make_tuple(
2481  Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
2482  // shuffle
2483  M1, // M1 = MWave
2484  M2, // M2 * M3 * M4 = MPerXdl
2485  M3,
2486  M4,
2487  M5)),
2491  // per shuffle
2492  N1, // N1 = NWave
2493  N2, // N2 = NXdlPack
2494  N3))), // N3 = NPerXdl
2498  Sequence<>{},
2500 
2501  // calculate origin of thread output tensor on global memory
2502  // blockwise GEMM c matrix starting index
2503  const auto c_thread_mtx_on_block =
2504  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2505 
2506  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2507  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2508 
2509  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2510  make_single_stage_tensor_adaptor(
2511  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
2512  make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
2513  make_tuple(Sequence<0>{}));
2514 
2515  const auto m_thread_data_on_block_idx =
2516  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2517  make_multi_index(m_thread_data_on_block));
2518 
2519  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2520  make_single_stage_tensor_adaptor(
2521  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
2522  make_tuple(Sequence<0, 1, 2, 3>{}),
2523  make_tuple(Sequence<0>{}));
2524 
2525  const auto n_thread_data_on_block_idx =
2526  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2527  make_multi_index(n_thread_data_on_block));
2528 
2529  // shuffle: threadwise copy C from VGPR to LDS
2530  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
2531  AccDataType,
2532  CShuffleDataType,
2533  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2534  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2536  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2537  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2538  I1,
2539  I1,
2540  M2,
2541  N2,
2542  M3,
2543  I1,
2544  M5,
2545  I1>,
2547  9,
2548  1,
2550  1,
2551  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2552  make_multi_index(0,
2553  0,
2554  m_thread_data_on_block_idx[I1],
2555  n_thread_data_on_block_idx[I1],
2556  m_thread_data_on_block_idx[I2],
2557  n_thread_data_on_block_idx[I2],
2558  m_thread_data_on_block_idx[I3],
2559  m_thread_data_on_block_idx[I4],
2560  m_thread_data_on_block_idx[I5],
2561  n_thread_data_on_block_idx[I3]),
2563 
2564  using EDataType = CDataType;
2565 
2566  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2567  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2568 
2569  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2570  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2571  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2572 
2573  const auto ds_grid_buf = generate_tuple(
2574  [&](auto i) {
2575  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2576  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2577  },
2578  Number<NumDTensor>{});
2579 
2580  // tuple of reference to C/Ds tensor descriptors
2581  const auto c_ds_desc_refs = concat_tuple_of_reference(
2582  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2583  generate_tie([&](auto i) -> const auto& // return type should be reference
2584  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2585  Number<NumDTensor>{}));
2586 
2587  // tuple of reference to C/Ds tensor descriptors
2588  const auto c_ds_buf_refs = concat_tuple_of_reference(
2589  tie(c_shuffle_block_buf),
2590  generate_tie([&](auto i) -> const auto& // return type should be reference
2591  { return ds_grid_buf[i]; },
2592  Number<NumDTensor>{}));
2593 
2594  // tuple of starting index of C/Ds blockwise copy
2595  const auto idx_c_ds_block_begin =
2598  [&](auto) {
2599  return make_multi_index(block_m_id, 0, block_n_id, 0);
2600  // return make_multi_index(block_work_idx[I0], 0,
2601  // block_work_idx[I1], 0);
2602  },
2603  Number<NumDTensor>{}));
2604 
2605  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2606  c_grid_desc_mblock_mperblock_nblock_nperblock;
2607 
2608  using CDEBlockTransferCluster =
2609  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2610  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2611  constexpr index_t scatter_weight_idx = 3; // FIXME: hard-coded scatter weight index
2612  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2613  ThisThreadBlock,
2614  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2615  Tuple<EDataType>,
2616  decltype(c_ds_desc_refs),
2617  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2618  CElementwiseOperation,
2619  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2620  // Sequence support
2621  // arbitrary type
2622  Sequence<1,
2623  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2624  1,
2625  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2626  CDEBlockTransferCluster,
2627  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2628  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2629  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2630  3, // index_t SrcVectorDim,
2631  3, // index_t DstVectorDim,
2632  CDEShuffleBlockTransferScalarPerVectors,
2637  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2638  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2639  IndexType,
2640  1, // ScatterDim
2641  true, // OutputScatter: false, only use scatter weights
2642  scatter_weight_idx // ScatterWeightIdx: ascale
2643  >{c_ds_desc_refs,
2644  idx_c_ds_block_begin,
2645  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2646  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2647  c_element_op};
2648 
2649  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2650  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2651 
2652  constexpr auto sfc_c_vgpr =
2653  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2654  NXdlPerWave / NXdlPack,
2655  1,
2656  1,
2657  MXdlPack,
2658  NXdlPack,
2659  M2,
2660  1,
2661  M4,
2662  1>,
2664  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2665  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2666  1,
2667  1,
2668  MXdlPack,
2669  NXdlPack,
2670  M2,
2671  1,
2672  M4,
2673  1>>{};
2674 
2675  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2676 
2677  // space filling curve for shuffled blockwise C/D/E
2678  constexpr auto sfc_cde_block =
2679  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2680  Sequence<0, 2, 1, 3>,
2681  Sequence<1,
2682  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2683  1,
2684  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2685 
2686  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2687  constexpr auto EMThreads =
2688  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2689  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2690  constexpr auto ENThreads =
2691  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2692  static_for<0, num_access, 1>{}([&](auto access_id) {
2693  // make sure it's safe to write to LDS
2694  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2695 
2696  auto dstidx = sfc_cde_block.GetIndex(access_id);
2697  const index_t c_token_pos =
2698  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2699  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2700  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2701  IndexType token_offset = fused_token & 0xffffff;
2702  if constexpr(IsInputGemm)
2703  {
2704  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2705  }
2706  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2707  });
2708 
2709  block_sync_lds();
2710 
2711  // each thread write its data from VGPR to LDS
2712  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2713  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2714  c_thread_buf_fp32,
2715  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2716  c_shuffle_block_buf);
2717 
2718  // make sure it's safe to read from LDS
2719  block_sync_lds();
2720 
2721  // each block copy its data from LDS to global
2722  cde_block_copy_lds_and_global.Run(
2723  c_ds_desc_refs,
2724  c_ds_buf_refs,
2725  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2726  tie(c_grid_buf),
2727  scatter_offsets);
2728 
2729  if constexpr(access_id < num_access - 1)
2730  {
2731  constexpr auto cde_lds_and_global_step =
2732  sfc_cde_block.GetForwardStep(access_id);
2733 
2734  // move on Ds
2735  static_for<0, NumDTensor, 1>{}([&](auto i) {
2736  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2737  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2738  });
2739 
2740  // move on E
2741  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2742  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2743  I0,
2744  cde_lds_and_global_step);
2745  }
2746  });
2747  }
2748  }
2749 };
2750 
2751 } // namespace ck