Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp Source File

gridwise_moe_mx_gemm_bns.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
18 
20 
21 #define DEBUG_LOG 0
22 
23 namespace ck {
24 
25 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
26 // pipelines in the same kernel function. Blockers:
27 // 1. Two separate declarations of the __shared__ pointer are the key to making sure data
28 // accesses operate on two distinct LDS chunks.
29 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
30 // not reuse the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
31 
32 enum Activation
33 {
34  gelu_and_mul = 0,
35  silu_and_mul = 1
36 };
37 
38 template <typename GridwiseGemm,
39  bool HasMainKBlockLoop,
40  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
41  index_t MinimumOccupancy = 1,
42  TailNumber TailNum = TailNumber::Even>
43 __global__ void
44 #if CK_USE_LAUNCH_BOUNDS
45 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
46 #endif
47  // __attribute__((amdgpu_waves_per_eu(1, 1)))
48  kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
49 {
50 #if defined(__gfx9__)
51  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52 
53  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
54 
55  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
56  karg.p_sorted_token_ids,
57  karg.p_sorted_expert_ids,
58  karg.p_max_token_id,
59  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
60  karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
61  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
62  karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
63  karg.p_ds_grid,
64  karg.p_c_grid,
65  p_shared,
66  karg,
67  karg.a_element_op,
68  karg.b_element_op,
69  karg.c_element_op);
70 #else
71  ignore = karg;
72 #endif // end of if (defined(__gfx9__))
73 }
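// Hedged illustration (not part of this header): a minimal host-side launch of kernel_moe_mxgemm.
// `Gemm` stands for a concrete instantiation of the gridwise struct below, `karg` for a filled
// Gemm::Argument, and gx/gy for the tile counts produced by Gemm::CalculateGridSize; gridDim.z
// carries the split-K batch id consumed by SplitKBatchOffset above. The block size must match the
// BlockSize the gridwise struct was instantiated with (spelled kBlockSize here as a placeholder).
#if 0
    dim3 grid(gx, gy, karg.KBatch);
    dim3 block(kBlockSize);
    kernel_moe_mxgemm<Gemm, /*HasMainKBlockLoop=*/true, InMemoryDataOperationEnum::Set>
        <<<grid, block, 0, stream>>>(karg);
#endif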
74 
75 #if 0
76 template <typename GridwiseGemm,
77  bool HasMainKBlockLoop,
78  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
79  index_t MinimumOccupancy = 1,
80  TailNumber TailNum = TailNumber::Even>
81 __global__ void
82 #if CK_USE_LAUNCH_BOUNDS
83 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
84 #endif
85  // __attribute__((amdgpu_waves_per_eu(1, 1)))
86  kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
87 {
88 #if defined(__gfx9__)
89  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
90  __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
91 
92  // auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
93 
94  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
95  karg.p_sorted_token_ids,
96  karg.p_sorted_expert_ids,
97  karg.p_max_token_id,
98  karg.p_a_grid,
99  karg.p_a_scale_grid,
100  karg.p_b_grid,
101  karg.p_b_scale_grid,
102  karg.p_ds_grid,
103  karg.p_c_grid,
104  p_shared,
105  p_shared1,
106  karg,
107  karg.a_element_op,
108  karg.b_element_op,
109  karg.c_element_op);
110 #else
111  ignore = karg;
112 #endif // end of if (defined(__gfx9__))
113 }
114 #endif
115 
116 template <typename ALayout,
117  typename BLayout,
118  typename DsLayout,
119  typename CLayout,
120  typename ADataType,
121  typename AScaleDataType,
122  typename BDataType,
123  typename BScaleDataType,
124  typename AccDataType,
125  typename CShuffleDataType,
126  typename DsDataType,
127  typename CDataType,
128  typename AElementwiseOperation,
129  typename BElementwiseOperation,
130  typename CElementwiseOperation,
132  index_t ScaleBlockSize,
133  index_t BlockSize,
134  index_t MPerBlock,
135  index_t NPerBlock,
136  index_t KPerBlock,
137  index_t AK1Value,
138  index_t BK1Value,
139  index_t MPerXdl,
140  index_t NPerXdl,
141  index_t MXdlPerWave,
142  index_t NXdlPerWave,
143  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
144  typename ABlockTransferThreadClusterArrangeOrder,
145  typename ABlockTransferSrcAccessOrder,
146  index_t ABlockTransferSrcVectorDim,
147  index_t ABlockTransferSrcScalarPerVector,
148  index_t ABlockTransferDstScalarPerVector_AK1,
149  bool AThreadTransferSrcResetCoordinateAfterRun,
150  index_t ABlockLdsExtraM,
151  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
152  typename BBlockTransferThreadClusterArrangeOrder,
153  typename BBlockTransferSrcAccessOrder,
154  index_t BBlockTransferSrcVectorDim,
155  index_t BBlockTransferSrcScalarPerVector,
156  index_t BBlockTransferDstScalarPerVector_BK1,
157  bool BThreadTransferSrcResetCoordinateAfterRun,
158  index_t BBlockLdsExtraN,
159  index_t CShuffleMXdlPerWavePerShuffle,
160  index_t CShuffleNXdlPerWavePerShuffle,
161  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
162  typename CDEShuffleBlockTransferScalarPerVectors,
165  index_t ActivationOperation = 0,
166  bool NSwizzle = false,
167  bool IsInputGemm = true,
168  bool MulRoutedWeight = true,
169  typename IndexType = index_t,
170  typename ComputeTypeA = ADataType,
171  typename ComputeTypeB = BDataType>
173 {
174  using LDSTypeA = ADataType;
175  using LDSTypeB = BDataType;
176 
177  static constexpr auto I0 = Number<0>{};
178  static constexpr auto I1 = Number<1>{};
179  static constexpr auto I2 = Number<2>{};
180  static constexpr auto I3 = Number<3>{};
181  static constexpr auto I4 = Number<4>{};
182  static constexpr auto I5 = Number<5>{};
183  static constexpr auto I6 = Number<6>{};
184  static constexpr auto I7 = Number<7>{};
185  static constexpr auto I8 = Number<8>{};
186  static constexpr auto I9 = Number<9>{};
187 
188  static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
189  CDEShuffleBlockTransferScalarPerVectors{}[I0];
190  // K1 should be Number<...>
191  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
192  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
193  static constexpr auto AK1Number = Number<AK1Value>{};
194  static constexpr auto BK1Number = Number<BK1Value>{};
195 
196  static constexpr index_t NumDTensor = DsDataType::Size();
197 
198  static constexpr auto MXdlPack = 2;
199  static constexpr auto NXdlPack = 2;
200  static constexpr auto KXdlPack = 2;
201 
202  static constexpr index_t APackedSize = packed_size_v<ADataType>;
203  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
204 
205  static constexpr bool is_single_rate_mfma = false;
206  static constexpr auto is_scale_mfma = true;
207  using mfma_selector = MfmaSelector<ComputeTypeA,
208  MPerXdl,
209  NPerXdl,
210  ComputeTypeB,
211  is_single_rate_mfma,
212  is_scale_mfma>;
213  static constexpr index_t KPack = math::max(
215 
216  // static constexpr index_t NumTokens = 1;
217  static constexpr index_t SortedTileSize = MPerBlock;
218 
219  static constexpr auto MakeDsGridPointer()
220  {
221  return generate_tuple(
222  [&](auto i) {
223  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
224 
225  return static_cast<const DDataType*>(nullptr);
226  },
228  }
229 
230  using DsGridPointer = decltype(MakeDsGridPointer());
231 
233 
234  __host__ static auto CalculateGridSize(index_t M, index_t N)
235  {
236  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
237  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
238  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
239  const index_t gridy = NSwizzle ? 1 : mblock;
240 
241  return std::make_tuple(gridx, gridy, 1);
242  }
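    // Worked example (illustration only): with M = 512, N = 1024 and MPerBlock = NPerBlock = 128,
    // mblock = 4 and nblock = 8. Without NSwizzle the grid is (x, y) = (8, 4); with NSwizzle the
    // two dimensions are folded into x, giving (x, y) = (32, 1).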
243 
244  __host__ static auto CalculateMPadded(index_t M)
245  {
246  return math::integer_least_multiple(M, MPerBlock);
247  }
248 
249  __host__ static auto CalculateNPadded(index_t N)
250  {
251  return math::integer_least_multiple(N, NPerBlock);
252  }
253 
254  __host__ static auto CalculateKPadded(index_t K)
255  {
256  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
257  }
258 
259  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
260  {
261  auto K_t = K_Batch * KPerBlock;
262  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
263  }
264 
265  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
266  {
267  auto K_t = K_Batch * KPerBlock;
268  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
269  }
270 
271  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
272  {
273  auto K_t = K_Batch * KPerBlock;
274  return (K + K_t - 1) / K_t * KPerBlock;
275  }
276 
277  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
278  {
279  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
280  auto K_t = K_Batch * KReadVec;
281  return (K + K_t - 1) / K_t * KReadVec;
282  }
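    // Worked example (illustration only): with K = 1000, KPerBlock = 64, AK1 = BK1 = 16 and
    // K_Batch = 2, CalculateKPadded(K, 2) = ceil(1000 / 128) * 64 = 512 and
    // CalculateKRead(K, 2) = ceil(1000 / 32) * 16 = 512, i.e. each split-K batch reads 512
    // elements along K; the last batch is later clipped to the remaining 1000 - 512 = 488
    // elements in SplitKBatchOffset.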
283 
284  __host__ static auto CalculateMBlock(index_t M)
285  {
286  return math::integer_divide_ceil(M, MPerBlock);
287  }
288 
289  __host__ static auto CalculateNBlock(index_t N)
290  {
291  return math::integer_divide_ceil(N, NPerBlock);
292  }
293 
294  template <index_t MNXdlPerWave,
295  index_t MNWaves,
296  index_t MNXdlPack,
297  index_t MNPerXdl,
298  typename TileDesc_K0_MN_K1>
299  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
300  {
301  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
302  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
303 
305  TileDesc_K0_MN_K1{},
308  Number<MNWaves>{},
310  Number<MNPerXdl>{}))),
313  }
314 
315  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
316  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
317  {
318  const auto a_grid_desc_mraw_kraw = [&]() {
319  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
320  {
321  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
322  }
323  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
324  {
325  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
326  }
327  }();
328 
330 
331  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
332  GemmSpec == GemmSpecialization::MNKPadding)
333  {
334  // pad both M and K
335  const auto a_grid_desc_m_k =
336  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
338  make_right_pad_transform(K, KPad - K)),
341 
342  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
343  a_grid_desc_m_k,
348 
349  return a_grid_desc_ak0_m_ak1;
350  }
351  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
352  GemmSpec == GemmSpecialization::MNPadding)
353  {
354  // pad M, but not K
355  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
356  a_grid_desc_mraw_kraw,
358  make_right_pad_transform(M, MPad - M)),
361 
362  return a_grid_desc_ak0_m_ak1;
363  }
364  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
365  GemmSpec == GemmSpecialization::NKPadding)
366  {
367  // pad K, but not M
368  const auto a_grid_desc_m_k = transform_tensor_descriptor(
369  a_grid_desc_mraw_kraw,
373 
374  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
375  a_grid_desc_m_k,
380 
381  return a_grid_desc_ak0_m_ak1;
382  }
383  else
384  {
385  // not pad M or K
386  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
387  a_grid_desc_mraw_kraw,
392 
393  return a_grid_desc_ak0_m_ak1;
394  }
395  }
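    // Hedged reading of the un-padded, row-major case above (illustration): element (m, k) of A
    // sits at linear offset m * StrideA + k, and K is unmerged into (ak0, ak1) with
    // k = ak0 * AK1Value + ak1, so the resulting (AK0, M, AK1) descriptor maps
    //     (ak0, m, ak1) -> m * StrideA + ak0 * AK1Value + ak1.
    // The padded specializations only right-pad M and/or K before this unmerge.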
396 
397  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
398  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
399  {
400  const auto b_grid_desc_nraw_kraw = [&]() {
402  {
403  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
404  }
406  {
407  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
408  }
409  }();
410 
412 
413  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
414  GemmSpec != GemmSpecialization::Default),
415  "pk_i4_t does not support padding");
416  static_assert(!(is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t> &&
417  GemmSpec != GemmSpecialization::Default),
418  "f4x2_pk_t does not support padding");
419 
420  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
421  GemmSpec == GemmSpecialization::MNKPadding)
422  {
423  // pad both N and K
424  const auto b_grid_desc_n_k =
425  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
427  make_right_pad_transform(K, KPad - K)),
430 
431  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
432  b_grid_desc_n_k,
437 
438  return b_grid_desc_bk0_n_bk1;
439  }
440  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
441  GemmSpec == GemmSpecialization::MNPadding)
442  {
443  // pad N, but not K
444  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
445  b_grid_desc_nraw_kraw,
447  make_right_pad_transform(N, NPad - N)),
450 
451  return b_grid_desc_bk0_n_bk1;
452  }
453  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
454  GemmSpec == GemmSpecialization::MKPadding)
455  {
456  // pad K, but not N
457  const auto b_grid_desc_n_k = transform_tensor_descriptor(
458  b_grid_desc_nraw_kraw,
462 
463  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
464  b_grid_desc_n_k,
469 
470  return b_grid_desc_bk0_n_bk1;
471  }
472  else
473  {
474  // not pad N or K
475  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
476  b_grid_desc_nraw_kraw,
481 
482  return b_grid_desc_bk0_n_bk1;
483  }
484  }
485 
486  template <typename ABlockDesc_AK0_M_AK1>
487  __host__ __device__ static constexpr auto
488  MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
489  {
490  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
491 
492  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl>(
493  ABlockDesc_AK0_M_AK1{});
494  }
495 
496  template <typename BBlockDesc_BK0_N_BK1>
497  __host__ __device__ static constexpr auto
498  MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
499  {
500  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
501 
502  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl>(
503  BBlockDesc_BK0_N_BK1{});
504  }
505 
506  template <typename ELayout>
507  __host__ __device__ static auto MakeCGridDescriptor_M_N(
508  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
509  {
510  const auto c_grid_desc_mraw_nraw = [&]() {
512  {
513  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
514  }
516  {
517  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
518  }
519  }();
520 
521  // pad M and N
522  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
524  make_right_pad_transform(N, NPad - N)),
527  }
528 
529  template <typename DLayout>
530  __host__ __device__ static auto
532  {
533  const auto c_grid_desc_mraw_nraw = [&]() {
535  {
536  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
537  }
539  {
540  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
541  }
542  }();
543 
544  // pad M and N
545  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
547  make_right_pad_transform(N, NPad - N)),
550  }
551 
552  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
553  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
554  {
555  return generate_tuple(
556  [&](auto i) {
557  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
558  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
559  },
561  }
562 
563  template <typename DsGridDesc>
565  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
566  {
567  return generate_tuple(
568  [&](auto i) {
570  ds_grid_desc_m_n[i], MBlock, NBlock);
571  },
573  }
574 
575  struct Problem
576  {
577  __host__ Problem(index_t NumTokens_,
578  index_t TopK_,
579  index_t M_,
580  index_t N_,
581  index_t K_,
582  index_t StrideA_,
583  index_t StrideScaleA_,
584  index_t StrideB_,
585  index_t StrideScaleB_,
586  std::array<index_t, NumDTensor> StrideDs_,
587  index_t StrideC_,
588  index_t KBatch_)
589  : NumTokens{NumTokens_},
590  TopK{TopK_},
591  M{M_},
592  N{N_},
593  K{K_},
594  StrideA{StrideA_},
595  StrideScaleA{StrideScaleA_},
596  StrideB{StrideB_},
597  StrideScaleB{StrideScaleB_},
598  StrideDs{StrideDs_},
599  StrideC{StrideC_},
600  KBatch{KBatch_},
603  KRead{CalculateKRead(K_, KBatch_)},
604  KPadded{CalculateKPadded(K_, KBatch_)},
605  AK0{CalculateAK0Padded(K_, KBatch_)},
606  BK0{CalculateBK0Padded(K_, KBatch_)},
607  MBlock{CalculateMBlock(M_)},
609  {
610  }
611 
612  __host__ void Print() const
613  {
614  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
615  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
616  << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
617  << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
618  << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
619  << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
620  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
621  << ", " << "NBlock: " << NBlock << "}" << std::endl;
622  }
623 
633  std::array<index_t, NumDTensor> StrideDs;
644  };
645 
646  // Argument
648  {
649  __host__ Argument(const index_t* p_sorted_token_ids_,
650  const index_t* p_sorted_expert_ids_,
651  const index_t* p_max_token_id_,
652  const ADataType* p_a_grid_,
653  const AScaleDataType* p_a_scale_grid_,
654  const BDataType* p_b_grid_,
655  const BScaleDataType* p_b_scale_grid_,
656  std::array<const void*, NumDTensor> p_ds_grid_,
657  CDataType* p_c_grid_,
658  index_t NumTokens_,
659  index_t TopK_,
660  index_t M_,
661  index_t N_,
662  index_t K_,
663  index_t StrideA_,
664  index_t StrideScaleA_,
665  index_t StrideB_,
666  index_t StrideScaleB_,
667  std::array<index_t, NumDTensor> StrideDs_,
668  index_t StrideC_,
669  index_t k_batch_,
670  AElementwiseOperation a_element_op_,
671  BElementwiseOperation b_element_op_,
672  CElementwiseOperation c_element_op_)
673  : Problem{NumTokens_,
674  TopK_,
675  M_,
676  N_,
677  K_ / APackedSize,
678  StrideA_ / APackedSize,
679  StrideScaleA_,
680  StrideB_ / BPackedSize,
681  StrideScaleB_,
682  StrideDs_,
683  StrideC_,
684  k_batch_},
685  p_sorted_token_ids{p_sorted_token_ids_},
686  p_sorted_expert_ids{p_sorted_expert_ids_},
687  p_max_token_id{p_max_token_id_},
688  p_a_grid{p_a_grid_},
689  p_a_scale_grid{p_a_scale_grid_},
690  p_b_grid{p_b_grid_},
691  p_b_scale_grid{p_b_scale_grid_},
692  p_ds_grid{},
693  p_c_grid{p_c_grid_},
694  a_element_op{a_element_op_},
695  b_element_op{b_element_op_},
696  c_element_op{c_element_op_}
697  {
698 
699  // populate pointer, desc for Ds
700  static_for<0, NumDTensor, 1>{}([&](auto i) {
701  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
702 
703  // D pointer
704  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
705  });
706  }
707 
711  const ADataType* p_a_grid;
712  const AScaleDataType* p_a_scale_grid;
713  const BDataType* p_b_grid;
714  const BScaleDataType* p_b_scale_grid;
716  CDataType* p_c_grid;
717 
718  const AElementwiseOperation a_element_op;
719  const BElementwiseOperation b_element_op;
720  const CElementwiseOperation c_element_op;
721  };
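    // Hedged illustration: filling an Argument on the host. All pointer, size and stride names
    // below are placeholders; K, StrideA and StrideB are passed in (unpacked) element counts and
    // divided by APackedSize/BPackedSize inside the constructor above. Each p_sorted_token_ids
    // entry packs the token id in its low 24 bits and the top-k slot in its high 8 bits, as
    // decoded in Run().
#if 0
    std::array<const void*, NumDTensor> ds_ptrs{/* ... */};
    std::array<index_t, NumDTensor>     ds_strides{/* ... */};
    Argument karg(p_sorted_token_ids, p_sorted_expert_ids, p_max_token_id,
                  p_a, p_a_scale, p_b, p_b_scale, ds_ptrs, p_c,
                  NumTokens, TopK, M, N, K,
                  StrideA, StrideScaleA, StrideB, StrideScaleB, ds_strides, StrideC,
                  /*k_batch*/ 1,
                  AElementwiseOperation{}, BElementwiseOperation{}, CElementwiseOperation{});
#endif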
722 
724  {
725  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
726  {
727  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
728  {
729  a_k_split_offset = k_id * karg.KRead;
730  }
731  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
732  {
733  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
734  }
735 
736  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
737  {
738  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
739  }
740  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
741  {
742  // KPack * NLane * KLane * K0 * N0
743  b_k_split_offset = k_id * karg.KRead;
744  }
745 
746  // Calculate A scale offset
747  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
748  {
749  a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
750  }
751  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
752  {
754  k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
755  }
756 
757  // Calculate B scale offset
758  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
759  {
761  k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
762  }
763  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
764  {
765  b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
766  }
767 
768  if(k_id < karg.KBatch - 1)
769  {
770  karg.K = karg.KRead;
771  }
772  else
773  {
774  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
775  }
776  }
777 
782  };
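    // Worked example (illustration only): row-major A, column-major B, KBatch = 2, KRead = 512.
    // The block with k_id = 0 gets a_k_split_offset = b_k_split_offset = 0 and computes
    // karg.K = 512; the block with k_id = 1 starts at element offset 512 in both A and B and
    // computes the remaining K - 512 elements. The scale offsets use the same k_id * KRead,
    // divided by ScaleBlockSize / APackedSize (or BPackedSize) because one scale covers a whole
    // scale block.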
783 
784  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
785  {
786  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
787  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
788  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
789 
790  // A matrix in LDS memory, dst of blockwise copy
791  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
792  {
796  }
797  // The xor tensor transformation requires extra, unnecessary VGPR usage, which can cause
798  // register spills in some cases.
800  {
801  constexpr auto a_lds_block_desc =
804 
805  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
806  a_lds_block_desc,
812 
813  return a_lds_block_desc_permuted;
814  }
815  else // ColumnMajor A
816  {
817  // The kfold and mpair dimensions are not always required.
818  // More dimensions in merge_transform increase the difficulty of generating immediate-argument
819  // offsets for the compiler.
820  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
821  constexpr auto M1 = MPerBlock / M0;
822 
823  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
824  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
825  constexpr auto KThreadRead = WaveSize / MPerXdl;
826  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
827 
828  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
829  ? 1
830  : 128 / (AK1Number * M0 * sizeof(ADataType));
831  constexpr auto KThreadReadPerm =
832  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
833  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
834  : KThreadRead;
835 
836  // 1<=mpair<=m0
837  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
838  ? 1
839  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
840  ? M0
841  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
842 
843  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
847  Number<kfold * M0 / mpair>{},
848  Number<mpair>{},
849  AK1Number));
850 
851  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
852  a_lds_block_desc,
853  make_tuple(
857  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
860  make_tuple(
862  make_tuple(
864 
865  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
866  a_lds_block_desc_permuted,
867  make_tuple(
875  Sequence<1>{},
876  Sequence<2>{},
877  Sequence<3>{},
878  Sequence<4>{},
879  Sequence<5>{}),
881  Sequence<2>{},
882  Sequence<0, 3>{},
883  Sequence<4, 5>{},
884  Sequence<6>{},
885  Sequence<7>{}));
886 
887  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
888  a_lds_block_desc_unmerged,
891  Number<KThreadWrite / kfold / KThreadReadPerm>{},
892  Number<kfold>{},
899 
900  return a_lds_block_desc_ak0_m_ak1;
901  }
902  }
903 
904  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
905  {
906  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
907  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
908  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
909 
910  // B matrix in LDS memory, dst of blockwise copy
911  if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
912  {
916  }
918  {
919  // NLdsLayer * K0 as logical Bank
920  constexpr auto b_lds_block_desc =
923 
924  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
925  b_lds_block_desc,
931 
932  return b_lds_block_desc_permuted;
933  }
934  else // RowMajor B
935  {
936  constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
937  constexpr auto N1 = NPerBlock / N0;
938 
939  constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
940  constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
941  constexpr auto KThreadRead = WaveSize / NPerXdl;
942  constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
943 
944  constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
945  ? 1
946  : 128 / (BK1Number * N0 * sizeof(BDataType));
947  constexpr auto KThreadReadPerm =
948  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
949  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
950  : KThreadRead;
951 
952  // 1<=npair<=n0
953  constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
954  ? 1
955  : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
956  ? N0
957  : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
958 
959  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
963  Number<kfold * N0 / npair>{},
964  Number<npair>{},
965  BK1Number));
966 
967  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
968  b_lds_block_desc,
969  make_tuple(
973  make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
976  make_tuple(
978  make_tuple(
980 
981  constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
982  b_lds_block_desc_permuted,
983  make_tuple(
991  Sequence<1>{},
992  Sequence<2>{},
993  Sequence<3>{},
994  Sequence<4>{},
995  Sequence<5>{}),
997  Sequence<2>{},
998  Sequence<0, 3>{},
999  Sequence<4, 5>{},
1000  Sequence<6>{},
1001  Sequence<7>{}));
1002 
1003  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1004  b_lds_block_desc_unmerged,
1007  Number<KThreadWrite / kfold / KThreadReadPerm>{},
1008  Number<kfold>{},
1015 
1016  return b_lds_block_desc_bk0_n_bk1;
1017  }
1018  }
1019 
1021  {
1022  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1023  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1024 
1025  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1027  make_tuple(I1,
1029  I1,
1031 
1032  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1033  }
1034 
1037  BlkGemmPipelineVer,
1038  BlkGemmPipeSched,
1039  BlockSize,
1040  ScaleBlockSize,
1041  ADataType,
1042  AScaleDataType,
1043  BDataType,
1044  BScaleDataType,
1045  ComputeTypeA,
1046  AccDataType,
1053  ABlockTransferSrcScalarPerVector,
1054  BBlockTransferSrcScalarPerVector,
1055  MPerBlock,
1056  NPerBlock,
1057  KPerBlock,
1058  MPerXdl,
1059  NPerXdl,
1060  MXdlPerWave,
1061  NXdlPerWave,
1062  KPack,
1063  IsInputGemm>())>;
1064 
1065  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1066  {
1067  // LDS allocation for A and B: be careful of alignment
1068  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1069  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1070 
1071  // lds max alignment
1072  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1073 
1074  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1075  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1076 
1077  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1078  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1079 
1080  // LDS allocation for C shuffle in LDS
1081  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1083 
1084  constexpr auto c_block_size =
1085  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1086 
1087  if constexpr(IsInputGemm)
1088  {
1089  return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1090  b_block_space_size_aligned * sizeof(BDataType)) *
1091  2,
1092  c_block_size * sizeof(CShuffleDataType));
1093  }
1094  else
1095  {
1096  return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1097  b_block_space_size_aligned * sizeof(BDataType)),
1098  c_block_size * sizeof(CShuffleDataType));
1099  }
1100  }
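    // Rough sizing example (illustration only, ignoring the LDS-padding variants above): with
    // MPerBlock = NPerBlock = KPerBlock = 128 and 1-byte A/B element types, the A and B tiles are
    // about 128 * 128 = 16 KiB each, so the IsInputGemm case (gate + up weights, hence the
    // factor 2) needs ~64 KiB for A/B versus ~32 KiB otherwise; the returned size is the max of
    // that and the C-shuffle tile.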
1101 
1102  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1103  __host__ static constexpr bool CheckValidity(const Argument& karg)
1104  {
1105  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1106  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1107  "Invalid tuning param!");
1108 
1109  static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1110  "KPerBlock should be multiple of ScaleBlockSize");
1111 
1112  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1117  {
1118  if(!(karg.M % MPerBlock == 0))
1119  {
1120  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1121  {
1122  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1123  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1124  << std::endl;
1125  }
1126  return false;
1127  }
1128  }
1129 
1130  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1135  {
1136  if(!(karg.N % NPerBlock == 0))
1137  {
1138  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1139  {
1140  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1141  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1142  << std::endl;
1143  }
1144  return false;
1145  }
1146  }
1147 
1148  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1152  {
1153  auto K_t = karg.KBatch * KPerBlock;
1154  if(!(karg.K % K_t == 0))
1155  {
1156  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1157  {
1158  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1159  << karg.K << " " << __FILE__ << ":" << __LINE__
1160  << ", in function: " << __func__ << std::endl;
1161  }
1162  return false;
1163  }
1164  }
1165  else
1166  {
1167  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1168  auto K_t = karg.KBatch * KReadVec;
1169  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1170  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1171  {
1172  return false;
1173  }
1174  }
1175 
1177  {
1178  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1179  {
1180  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1181  {
1182  std::cout << "Arg K (" << karg.K
1183  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1184  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1185  << __LINE__ << ", in function: " << __func__ << std::endl;
1186  }
1187  return false;
1188  }
1189  }
1190  else
1191  {
1192  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1193  {
1194  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1195  {
1196  std::cout << "Arg M (" << karg.M
1197  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1198  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1199  << __LINE__ << ", in function: " << __func__ << std::endl;
1200  }
1201  return false;
1202  }
1203  }
1204 
1206  {
1207  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1208  {
1209  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1210  {
1211  std::cout << "Arg N (" << karg.N
1212  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1213  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1214  << __LINE__ << ", in function: " << __func__ << std::endl;
1215  }
1216  return false;
1217  }
1218  }
1219  else
1220  {
1221  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1222  {
1223  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1224  {
1225  std::cout << "Arg K (" << karg.K
1226  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1227  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1228  << __LINE__ << ", in function: " << __func__ << std::endl;
1229  }
1230  return false;
1231  }
1232  }
1233 
1235  {
1237  {
1238  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1239  {
1240  std::cout << "Arg N (" << karg.N
1241  << ") value is not a multiple of "
1242  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1244  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1245  << std::endl;
1246  }
1247  return false;
1248  }
1249  }
1250  else
1251  {
1253  {
1254  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1255  {
1256  std::cout << "Arg M (" << karg.M
1257  << ") value is not a multiple of "
1258  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1260  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1261  << std::endl;
1262  }
1263  return false;
1264  }
1265  }
1266 
1267 
1268  // check gridwise gemm pipeline
1269 #if 0
1270  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1271 
1272  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1273  {
1274  return false;
1275  }
1276 #endif
1277  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1278  return true;
1279  }
1280 
1281  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1282  {
1283  const index_t num_loop = K / KPerBlock;
1284 
1285  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1286  }
1287 
1288  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1289  {
1290  const index_t num_loop = K / KPerBlock;
1291 
1292  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1293  }
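    // Hedged dispatch note (illustration): the host-side wrapper typically evaluates these two
    // helpers on the per-split K range (KRead for all but the last split-K batch) and then picks
    // the matching compile-time instantiation, e.g. a runtime switch over
    //   kernel_moe_mxgemm<Gemm, /*HasMainKBlockLoop*/, CGlobalMemoryDataOperation, MinimumOccupancy, /*TailNum*/>,
    // since HasMainKBlockLoop and TailNum are template parameters of the kernel.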
1294 
1295  template <typename CGridDesc>
1296  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1297  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1298  {
1299  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1300  c_grid_desc_m_n,
1305 
1306  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1307  }
1308 
1309  // return block_id to C matrix tile idx (m0, n0) mapping
1310  // if arch = gfx942
1311  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1312  // NPerBlock>;
1313 
1315  static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
1316  static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
1317  static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
1318  "A scale pack data type too large!");
1319  static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
1320  "B scale pack data type too large!");
1321 
1322  template <bool HasMainKBlockLoop,
1323  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1324  TailNumber TailNum = TailNumber::Odd>
1325  __device__ static void Run(const index_t* p_sorted_token_ids,
1326  const index_t* p_sorted_expert_ids,
1327  const index_t* p_max_token_id,
1328  const ADataType* p_a_grid,
1329  const AScaleDataType* p_a_scale_grid,
1330  const BDataType* p_b_grid,
1331  const BScaleDataType* p_b_scale_grid,
1332  DsGridPointer& p_ds_grid,
1333  CDataType* p_c_grid,
1334  void* p_shared,
1335  const Problem& problem,
1336  AElementwiseOperation a_element_op,
1337  BElementwiseOperation b_element_op,
1338  CElementwiseOperation c_element_op)
1339  {
1340  ignore = b_element_op;
1341  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1342  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1343  problem.MPadded,
1344  problem.K,
1345  problem.KPadded,
1346  problem.StrideA,
1347  problem.AK0);
1348  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1349  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1350  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1351  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1352  problem.MPadded,
1353  problem.N,
1354  problem.NPadded,
1355  problem.StrideC);
1356 
1357  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
1358  make_tuple(problem.M / (MXdlPack * MPerXdl),
1359  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1360  (KXdlPack * 64 / MPerXdl),
1361  64 * KXdlPack * MXdlPack / scale_pack_size_a));
1362 
1363  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
1364  make_tuple(problem.N / (NXdlPack * NPerXdl),
1365  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1366  (KXdlPack * 64 / NPerXdl),
1367  64 * KXdlPack * NXdlPack / scale_pack_size_b));
1368 
1369  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1371  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1372 
1373  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1374  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1375  if(expert_block_id * MPerBlock >= max_token_id)
1376  return;
1377  const index_t expert_id =
1378  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1379 
1380  const auto block_mn = [&]() -> std::pair<int, int> {
1381  if constexpr(NSwizzle)
1382  {
1383  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1384  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1385  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1386  const index_t expert_swizzle =
1387  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1388  const index_t bid_new = blockIdx.x - prefix_block;
1389  const index_t nid = __builtin_amdgcn_readfirstlane(
1390  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1391  const index_t mid =
1392  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1393  return {nid, mid};
1394  }
1395  else
1396  {
1397  return {blockIdx.x, blockIdx.y};
1398  }
1399  }();
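        // Worked example of the NSwizzle mapping above (illustration only): with
        // expert_swizzle = ecnt = 3 and bid_new = 10,
        //   nid = 10 % 8 + 10 / (8 * 3) * 8 = 2 and mid = ecnt_prefix + (10 / 8) % 3 = ecnt_prefix + 1,
        // so consecutive blocks sweep a group of 8 N-tiles across all of this expert's M-tiles
        // before moving on to the next group of 8 N-tiles.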
1400 
1401  const index_t block_n_id = block_mn.first;
1402  const index_t block_m_id = block_mn.second;
1403  const index_t token0 =
1404  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1405 
1406  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1407  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1408  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1409  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1410  constexpr auto AKThreads = AK0Threads * AK1Threads;
1411  constexpr auto AMRepeats = MPerBlock / AMThreads;
1412  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1413 
1414  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1415  return;
1417  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1418  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1419  index_t token_offset = fused_token & 0xffffff;
1420  if constexpr(!IsInputGemm)
1421  {
1422  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1423  }
1424  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1425  });
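        // Illustration of the fused token encoding consumed above: an entry 0x02000123 in
        // p_sorted_token_ids means token 0x123 routed to top-k slot 2. For the first GEMM the A
        // gather row is the token id itself; for the second GEMM (!IsInputGemm) it is
        // token * TopK + slot, and either way the row index is multiplied by K to form the
        // element offset.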
1426 
1427  const index_t expert_stride =
1428  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1429  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1430  problem.N * (IsInputGemm ? 2 : 1) *
1431  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
1432 
1433  // N0, K0, Blocksize*KPack
1434  const index_t n_block_data_idx_on_grid =
1435  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1436 
1437  // Grid buffer creation
1438  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1439  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1440  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1441  p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1442 
1443  // A, B scale buffer
1444  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1445  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1446  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1447  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
1448  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1449 
1450  // lds max alignment
1451  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1452 
1453  // A matrix in LDS memory, dst of blockwise copy
1454  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1455 
1456  // B matrix in LDS memory, dst of blockwise copy
1457  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1458 
1459  // A matrix blockwise copy
1460  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1462  AElementwiseOperation,
1466  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1467  ABlockTransferThreadClusterArrangeOrder,
1468  ADataType,
1469  ADataType,
1470  decltype(a_grid_desc_ak0_m_ak1),
1471  decltype(a_block_desc_ak0_m_ak1),
1472  ABlockTransferSrcAccessOrder,
1474  ABlockTransferSrcVectorDim,
1475  2,
1476  ABlockTransferSrcScalarPerVector,
1477  ABlockTransferDstScalarPerVector_AK1,
1478  1,
1479  1,
1480  AThreadTransferSrcResetCoordinateAfterRun,
1481  true,
1482  IndexType,
1483  1,
1484  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1485  make_multi_index(0, 0, 0),
1486  a_element_op,
1487  a_block_desc_ak0_m_ak1,
1488  make_multi_index(0, 0, 0),
1490  gather_offsets);
1491 
1492  // B matrix blockwise copy
1493  auto b_blockwise_copy =
1495  BElementwiseOperation,
1499  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1500  BBlockTransferThreadClusterArrangeOrder,
1501  BDataType,
1502  BDataType,
1503  decltype(b_grid_desc_bk0_n_bk1),
1504  decltype(b_block_desc_bk0_n_bk1),
1505  BBlockTransferSrcAccessOrder,
1507  BBlockTransferSrcVectorDim,
1508  2,
1509  BBlockTransferSrcScalarPerVector,
1510  BBlockTransferDstScalarPerVector_BK1,
1511  1,
1512  1,
1513  BThreadTransferSrcResetCoordinateAfterRun,
1514  true,
1515  BlockwiseGemmPipe::GlobalBufferNum>(
1516  b_grid_desc_bk0_n_bk1,
1517  make_multi_index(0, n_block_data_idx_on_grid, 0),
1518  b_element_op,
1519  b_block_desc_bk0_n_bk1,
1520  make_multi_index(0, 0, 0),
1522 
1523  // LDS allocation for A and B: be careful of alignment
1524  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1525  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1526 
1527  // Cast after lds
1528  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1529  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1530 
1531  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1532  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1533  a_block_space_size_aligned * sizeof(ADataType)),
1534  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1535 
1536  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1537  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1538 
1539  // Blockwise GEMM pipeline
1540  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1541  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1542  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1543  decltype(c_thread_buf) c_thread_buf_up;
1544 
1546  float,
1547  c_thread_buf.num_of_v_,
1548  c_thread_buf.s_per_v,
1549  true>
1550  c_thread_buf_fp32;
1551 
1552  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1553  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1554  KPerBlock);
1555 
1556  // a and b scale processing
1557  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1558  const auto waveId_m = wave_idx[I0];
1559  const auto waveId_n = wave_idx[I1];
1560 
1561  auto thread_offset_shuffled =
1562  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1563 
1564  auto a_thread_offset_m = waveId_m;
1565 
1566  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1567  AScaleDataType,
1568  AScaleDataType,
1569  decltype(a_scale_grid_desc_am_ak),
1570  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1571  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1572  Sequence<0, 1, 2>, // DimAccessOrder
1573  2, // SrcVectorDim
1574  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1575  1, // SrcScalarStrideInVector
1576  true>(a_scale_grid_desc_am_ak,
1577  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1578  0,
1579  thread_offset_shuffled / scale_pack_size_a));
1580 
1581  // B scale load
1582  auto b_thread_offset_n = waveId_n;
1583 
1584  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1585  BScaleDataType,
1586  BScaleDataType,
1587  decltype(b_scale_grid_desc_bn_ak),
1588  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1589  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1590  Sequence<0, 1, 2>, // DimAccessOrder
1591  2, // SrcVectorDim
1592  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1593  1, // SrcScalarStrideInVector
1594  true>(b_scale_grid_desc_bn_ak,
1595  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1596  0,
1597  thread_offset_shuffled / scale_pack_size_b));
1598 
1599  if constexpr(IsInputGemm)
1600  {
1601  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1602  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1603  auto b_block_buf_up = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1604  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1605  a_block_space_size_aligned * sizeof(ADataType) +
1606  b_block_space_size_aligned * sizeof(BDataType)),
1607  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1608 
1609  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
1610  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1611  p_b_grid_up + expert_id * expert_stride,
1612  b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1613 
1614  auto b_blockwise_copy_up =
1616  BElementwiseOperation,
1620  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1621  BBlockTransferThreadClusterArrangeOrder,
1622  BDataType,
1623  BDataType,
1624  decltype(b_grid_desc_bk0_n_bk1),
1625  decltype(b_block_desc_bk0_n_bk1),
1626  BBlockTransferSrcAccessOrder,
1628  BBlockTransferSrcVectorDim,
1629  2,
1630  BBlockTransferSrcScalarPerVector,
1631  BBlockTransferDstScalarPerVector_BK1,
1632  1,
1633  1,
1634  BThreadTransferSrcResetCoordinateAfterRun,
1635  true,
1636  BlockwiseGemmPipe::GlobalBufferNum>(
1637  b_grid_desc_bk0_n_bk1,
1638  make_multi_index(0, n_block_data_idx_on_grid, 0),
1639  b_element_op,
1640  b_block_desc_bk0_n_bk1,
1641  make_multi_index(0, 0, 0),
1643 
1644  const BScaleDataType* p_b_scale_grid_up =
1645  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
1646  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1647  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
1648  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1649 
1650  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1651  BScaleDataType,
1652  BScaleDataType,
1653  decltype(b_scale_grid_desc_bn_ak),
1654  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1655  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1656  Sequence<0, 1, 2>, // DimAccessOrder
1657  2, // SrcVectorDim
1658  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1659  1, // SrcScalarStrideInVector
1660  true>(
1661  b_scale_grid_desc_bn_ak,
1662  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1663  0,
1664  thread_offset_shuffled / scale_pack_size_b));
1665 
1666  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1667  // A
1668  a_grid_desc_ak0_m_ak1,
1669  a_block_desc_ak0_m_ak1,
1670  a_blockwise_copy,
1671  a_grid_buf,
1672  a_block_buf,
1673  a_block_slice_copy_step,
1674  // Gate and Up
1675  b_grid_desc_bk0_n_bk1,
1676  b_block_desc_bk0_n_bk1,
1677  b_blockwise_copy,
1678  b_blockwise_copy_up,
1679  b_grid_buf,
1680  b_grid_buf_up,
1681  b_block_buf,
1682  b_block_buf_up,
1683  b_block_slice_copy_step,
1684  // C
1685  c_thread_buf,
1686  c_thread_buf_up,
1687  // A scale
1688  a_scale_grid_desc_am_ak,
1689  a_scale_thread_copy,
1690  a_scale_grid_buf,
1691  // Gate and Up scale
1692  b_scale_grid_desc_bn_ak,
1693  b_scale_thread_copy,
1694  b_scale_thread_copy_up,
1695  b_scale_grid_buf,
1696  b_scale_grid_buf_up,
1697  num_k_block_main_loop);
1698  }
1699  else
1700  {
1701  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1702  a_grid_desc_ak0_m_ak1, // A
1703  a_block_desc_ak0_m_ak1,
1704  a_blockwise_copy,
1705  a_grid_buf,
1706  a_block_buf,
1707  a_block_slice_copy_step,
1708  b_grid_desc_bk0_n_bk1, // B
1709  b_block_desc_bk0_n_bk1,
1710  b_blockwise_copy,
1711  b_grid_buf,
1712  b_block_buf,
1713  b_block_slice_copy_step,
1714  c_thread_buf, // C
1715  a_scale_grid_desc_am_ak, // A scale
1716  a_scale_thread_copy,
1717  a_scale_grid_buf,
1718  b_scale_grid_desc_bn_ak, // B scale
1719  b_scale_thread_copy,
1720  b_scale_grid_buf,
1721  num_k_block_main_loop);
1722  }
1723 
1724  // shuffle C and write out
1725  {
1726  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1727  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1728  "wrong!");
1729  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
1730  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
1731  "wrong!");
1732 
1733  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1734  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1735 
1736  // TODO: hacky, fix it!
1737  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1738  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1739 
1740  // TODO: hacky, fix it!
1741  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1742  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1743  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1744 
1745  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1746  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1747  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1748  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1749  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1750  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1751  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1752  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1753  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
1754  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
1755 
1756  // mul scales
1757  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
1758  static_assert(M5 == 4);
1759  const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id
1760  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
1761 
1762  vector_type<float, 4> topk_weights; // for gemm2 only
1763  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
1764  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
1765  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
1766  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
1767  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
1768  const index_t m_pos = block_m_id * MPerBlock +
1769  m0 * M2 * M1 * M3 * M4 * M5 +
1770  m1 * M2 * M3 * M4 * M5 +
1771  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
1772 
1773  if constexpr(MulRoutedWeight)
1774  {
1775  topk_weights =
1776  *c_style_pointer_cast<const vector_type<float, M5>*>(
1777  p_ds_grid[I2] + m_pos);
1778  }
1779  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
1780  constexpr index_t c_offset =
1781  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1782  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
1783  constexpr auto cidx = Number<c_offset>{};
1784 
1785  if constexpr(IsInputGemm) // gate/up (gu) fusion
1786  {
1787  if constexpr(ActivationOperation ==
1788  Activation::silu_and_mul)
1789  {
1790  float gate = c_thread_buf[cidx];
1791  float up = c_thread_buf_up[cidx];
1792  if constexpr(MulRoutedWeight)
1793  {
1794  gate = gate * topk_weights.AsType<float>()[m5];
1795  up = up * topk_weights.AsType<float>()[m5];
1796  }
1797  tensor_operation::element_wise::Silu{}(gate, gate);
1798  c_thread_buf_fp32(cidx) = gate * up;
1799  }
1800  else if(ActivationOperation == Activation::gelu_and_mul)
1801  {
1802  float gate = c_thread_buf[cidx];
1803  float up = c_thread_buf_up[cidx];
1804  if constexpr(MulRoutedWeight)
1805  {
1806  gate = gate * topk_weights.AsType<float>()[m5];
1807  up = up * topk_weights.AsType<float>()[m5];
1808  }
1809  tensor_operation::element_wise::Gelu{}(gate, gate);
1810  c_thread_buf_fp32(cidx) = gate * up;
1811 
1812  /*float gate = c_thread_buf[cidx];
1813  float up = c_thread_buf_up[cidx];
1814  if constexpr(MulRoutedWeight)
1815  {
1816  gate = gate * topk_weights.AsType<float>()[m5];
1817  //up = up * topk_weights.AsType<float>()[m5];
1818  }
1819  tensor_operation::element_wise::Gelu{}(gate, gate);
1820  c_thread_buf_fp32(cidx) = up;*/
1821  }
1822  }
1823  else
1824  {
1825  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1826  if constexpr(MulRoutedWeight)
1827  {
1828  c_thread_buf_fp32(cidx) =
1829  topk_weights.AsType<float>()[m5] *
1830  c_thread_buf_fp32[cidx];
1831  }
1832  }
1833  });
1834  });
1835  });
1836  });
1837  });
1838  });
1839 
1840  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1842 
1843  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1844  static_cast<CShuffleDataType*>(p_shared),
1845  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1846 
1847  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1848  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1849  make_tuple(
1853  // per shuffle
1854  M1, // M1 = MWave
1855  M2, // M2 = MXdlPack
1856  M3, // M3 * M4 * M5 = MPerXdl
1857  M4,
1858  M5)),
1862  // per shuffle
1863  N1, // N1 = NWave
1864  N2, // N2 = NXdlPack
1865  N3))), // N3 = NPerXdl
1869  Sequence<>{},
1871 
1872  // calculate origin of thread output tensor on global memory
1873  // blockwise GEMM c matrix starting index
1874  const auto c_thread_mtx_on_block =
1875  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1876 
1877  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1878  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1879 
1880  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1882  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
1884  make_tuple(Sequence<0>{}));
1885 
1886  const auto m_thread_data_on_block_idx =
1887  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1888  make_multi_index(m_thread_data_on_block));
1889 
1890  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1892  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
1894  make_tuple(Sequence<0>{}));
1895 
1896  const auto n_thread_data_on_block_idx =
1897  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1898  make_multi_index(n_thread_data_on_block));
1899 
1900  // shuffle: threadwise copy C from VGPR to LDS
1901  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1902  AccDataType,
1903  CShuffleDataType,
1904  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1905  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 1906  tensor_operation::element_wise::PassThrough,
 1907  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1908  CShuffleNXdlPerWavePerShuffle / NXdlPack,
1909  I1,
1910  I1,
1911  M2,
1912  N2,
1913  M3,
1914  I1,
1915  M5,
1916  I1>,
 1917  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
 1918  9,
 1919  1,
 1920  InMemoryDataOperationEnum::Set,
 1921  1,
1922  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1923  make_multi_index(0,
1924  0,
1925  m_thread_data_on_block_idx[I1],
1926  n_thread_data_on_block_idx[I1],
1927  m_thread_data_on_block_idx[I2],
1928  n_thread_data_on_block_idx[I2],
1929  m_thread_data_on_block_idx[I3],
1930  m_thread_data_on_block_idx[I4],
1931  m_thread_data_on_block_idx[I5],
1932  n_thread_data_on_block_idx[I3]),
 1933  tensor_operation::element_wise::PassThrough{}};
 1934 
1935  using EDataType = CDataType;
1936 
1937  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1938  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1939 
1940  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
 1941  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
 1942  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1943 
1944  const auto ds_grid_buf = generate_tuple(
1945  [&](auto i) {
1946  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1947  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1948  },
1949  Number<NumDTensor>{});
1950 
1951  // tuple of reference to C/Ds tensor descriptors
1952  const auto c_ds_desc_refs = concat_tuple_of_reference(
1953  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1954  generate_tie([&](auto i) -> const auto& // return type should be reference
1955  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1956  Number<NumDTensor>{}));
1957 
1958  // tuple of reference to C/Ds tensor descriptors
1959  const auto c_ds_buf_refs = concat_tuple_of_reference(
1960  tie(c_shuffle_block_buf),
1961  generate_tie([&](auto i) -> const auto& // return type should be reference
1962  { return ds_grid_buf[i]; },
1963  Number<NumDTensor>{}));
1964 
1965  // tuple of starting index of C/Ds blockwise copy
1966  const auto idx_c_ds_block_begin =
 1967  container_concat(make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
 1968  generate_tuple(
 1969  [&](auto) {
1970  return make_multi_index(block_m_id, 0, block_n_id, 0);
1971  // return make_multi_index(block_work_idx[I0], 0,
1972  // block_work_idx[I1], 0);
1973  },
1974  Number<NumDTensor>{}));
1975 
1976  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1977  c_grid_desc_mblock_mperblock_nblock_nperblock;
1978 
1979  using CDEBlockTransferCluster =
1980  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1981  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1982  constexpr index_t scatter_weight_idx = 3; // hack fix felix
1983  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
 1984  ThisThreadBlock,
 1985  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
 1986  Tuple<EDataType>,
 1987  decltype(c_ds_desc_refs),
1988  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1989  CElementwiseOperation,
1990  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
1991  // Sequence support
 1992  // arbitrary type
1993  Sequence<1,
1994  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1995  1,
1996  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1997  CDEBlockTransferCluster,
1998  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1999  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2000  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2001  3, // index_t SrcVectorDim,
2002  3, // index_t DstVectorDim,
2003  CDEShuffleBlockTransferScalarPerVectors,
 2004  CShuffleBlockTransferScalarPerVector_NPerBlock,
 2005  sequence_merge_t<
 2006  Sequence<true>,
 2007  uniform_sequence_gen_t<NumDTensor,
 2008  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2009  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2010  IndexType,
2011  1, // ScatterDim
2012  true, // OutputScatter: false, only use scatter weights
2013  scatter_weight_idx // ScatterWeightIdx: ascale
2014  >{c_ds_desc_refs,
2015  idx_c_ds_block_begin,
2016  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2017  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2018  c_element_op};
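 // The scatter transfer treats dim 1 (the MPerBlock dimension) as the scatter dimension:
 // each row of the shuffled tile is written to the global row given by scatter_offsets
 // below, while the Ds are read with the regular block offsets from idx_c_ds_block_begin.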
2019 
2020  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2021  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2022 
2023  constexpr auto sfc_c_vgpr =
2024  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2025  NXdlPerWave / NXdlPack,
2026  1,
2027  1,
2028  MXdlPack,
2029  NXdlPack,
2030  M2,
2031  1,
2032  M4,
2033  1>,
 2034  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
 2035  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2036  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2037  1,
2038  1,
2039  MXdlPack,
2040  NXdlPack,
2041  M2,
2042  1,
2043  M4,
2044  1>>{};
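 // sfc_c_vgpr walks the per-thread accumulator in shuffle-sized pieces; it must produce the
 // same number of accesses as the block-level curve sfc_cde_block defined below (checked by
 // the static_assert that follows).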
2045 
2046  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2047 
2048  // space filling curve for shuffled blockwise C/D/E
2049  constexpr auto sfc_cde_block =
 2050  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
 2051  Sequence<0, 2, 1, 3>,
 2052  Sequence<1,
2053  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2054  1,
2055  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2056 
2057  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2058  constexpr auto EMThreads =
2059  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2060  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2061  constexpr auto ENThreads =
2062  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
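 // The CDE transfer cluster is split into EMThreads x ENThreads threads; each thread covers
 // EMRepeats consecutive rows of the shuffle tile, so EMRepeats scatter offsets are gathered
 // per access in the loop below.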
2063  static_for<0, num_access, 1>{}([&](auto access_id) {
2064  // make sure it's safe to write to LDS
 2065  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
 2066 
2067  auto dstidx = sfc_cde_block.GetIndex(access_id);
2068  const index_t c_token_pos =
2069  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2070  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2071  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2072  IndexType token_offset = fused_token & 0xffffff;
2073  if constexpr(IsInputGemm)
2074  {
2075  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2076  }
2077  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2078  });
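 // Each entry of p_sorted_token_ids packs the destination row as (topk_slot << 24) | token_id.
 // For the gate/up (input) GEMM the row becomes token_id * TopK + topk_slot; e.g., with
 // TopK == 8 (illustrative value only), fused_token = (2 << 24) | 5 scatters this row to
 // global row 5 * 8 + 2 = 42, i.e. element offset 42 * problem.N in the output.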
2079 
2080  block_sync_lds();
2081 
2082  // each thread write its data from VGPR to LDS
2083  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2084  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2085  c_thread_buf_fp32,
2086  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2087  c_shuffle_block_buf);
2088 
2089  // make sure it's safe to read from LDS
2090  block_sync_lds();
2091 
2092  // each block copy its data from LDS to global
2093  cde_block_copy_lds_and_global.Run(
2094  c_ds_desc_refs,
2095  c_ds_buf_refs,
2096  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2097  tie(c_grid_buf),
2098  scatter_offsets);
2099 
2100  if constexpr(access_id < num_access - 1)
2101  {
2102  constexpr auto cde_lds_and_global_step =
2103  sfc_cde_block.GetForwardStep(access_id);
2104 
2105  // move on Ds
2106  static_for<0, NumDTensor, 1>{}([&](auto i) {
2107  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2108  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2109  });
2110 
2111  // move on E
2112  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2113  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2114  I0,
2115  cde_lds_and_global_step);
2116  }
2117  });
2118  }
2119  }
2120 
2121 #if 0
2122  template <bool HasMainKBlockLoop,
2123  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2124  TailNumber TailNum = TailNumber::Odd>
2125  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
2126  const index_t* p_sorted_expert_ids,
2127  const index_t* p_max_token_id,
2128  const ADataType* p_a_grid,
2129  const AScaleDataType* p_a_scale_grid,
2130  const BDataType* p_b_grid,
2131  const BScaleDataType* p_b_scale_grid,
2132  DsGridPointer& p_ds_grid,
2133  CDataType* p_c_grid,
2134  void* p_shared,
2135  void* p_shared1,
2136  const Problem& problem,
2137  AElementwiseOperation a_element_op,
2138  BElementwiseOperation b_element_op,
2139  CElementwiseOperation c_element_op)
2140  {
2141  ignore = b_element_op;
2142  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2143  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2144  problem.MPadded,
2145  problem.K,
2146  problem.KPadded,
2147  problem.StrideA,
2148  problem.AK0);
2149  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2150  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2151  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2152  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2153  problem.MPadded,
2154  problem.N,
2155  problem.NPadded,
2156  problem.StrideC);
2157 
2158  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
2159  make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerXdl),
2160  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2161  (KXdlPack * 64 / MPerXdl),
2162  64 * KXdlPack * MXdlPack / scale_pack_size_a));
2163 
2164  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
2165  make_tuple(problem.N / (NXdlPack * NPerXdl),
2166  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2167  (KXdlPack * 64 / NPerXdl),
2168  64 * KXdlPack * NXdlPack / scale_pack_size_b));
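 // The MX scale tensors carry one scale per ScaleBlockSize input elements along K and are
 // stored pre-shuffled: dim0 indexes (M or N) tiles of XdlPack * PerXdl rows, dim1 groups of
 // KXdlPack * 64 / PerXdl scale blocks, and dim2 the 64 lanes * KXdlPack * XdlPack scales a
 // wave reads per tile (packed by scale_pack_size).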
2169 
2170  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 2171  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
 2172  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2173  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2174  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
2175  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2176  if(expert_block_id * MPerBlock >= max_token_id)
2177  return;
2178  const index_t expert_id =
2179  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2180  const auto block_mn = [&]() -> std::pair<int, int> {
2181  if constexpr(NSwizzle)
2182  {
2183  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2184  const index_t prefix_block = ecnt_prefix * problem.NBlock;
2185  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2186  const index_t expert_swizzle =
2187  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2188  const index_t bid_new = blockIdx.x - prefix_block;
2189  const index_t nid = __builtin_amdgcn_readfirstlane(
2190  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2191  const index_t mid =
2192  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2193  return {nid, mid};
2194  }
2195  else
2196  {
2197  return {blockIdx.x, blockIdx.y};
2198  }
2199  }();
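 // With NSwizzle the 1-D block id is remapped per expert: consecutive ids cover 8 N-blocks at
 // a time, cycling through the expert's M-blocks (expert_swizzle) before advancing to the
 // next group of 8 N-blocks; without NSwizzle blockIdx.x/blockIdx.y map directly to (n, m).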
2200 
2201  const index_t block_n_id = block_mn.first;
2202  const index_t block_m_id = block_mn.second;
2203  const index_t token0 =
2204  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2205 
2206  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2207  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2208  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2209  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2210  constexpr auto AKThreads = AK0Threads * AK1Threads;
2211  constexpr auto AMRepeats = MPerBlock / AMThreads;
2212  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
2213 
2214  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2215  return;
2216  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
2217  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2218  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
2219  index_t token_offset = fused_token & 0xffffff;
2220  if constexpr(!IsInputGemm)
2221  {
2222  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2223  }
2224  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2225  });
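 // Rows of A are gathered by token: for the down (output) GEMM each sorted entry expands to
 // row token_id * TopK + topk_slot of the TopK-replicated activations, and the offset is
 // scaled by problem.K, the A row stride.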
2226 
2227  const index_t expert_stride =
2228  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2229  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2230  problem.N * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
2231 
2232  // N0, K0, Blocksize*KPack
2233  const index_t n_block_data_idx_on_grid =
2234  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
2235 
2236  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2237  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2238 
2239  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2240  p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2241 
2242  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2243  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2244  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2245  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2246  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2247 
2248  // A matrix in LDS memory, dst of blockwise copy
2249  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2250 
2251  // B matrix in LDS memory, dst of blockwise copy
2252  // dummy
2253  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2254  // A matrix blockwise copy
2255  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
 2256  ThisThreadBlock,
 2257  AElementwiseOperation,
 2258  tensor_operation::element_wise::PassThrough,
 2259  InMemoryDataOperationEnum::Set,
 2260  Sequence<AK0Number, MPerBlock, AK1Number>,
2261  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2262  ABlockTransferThreadClusterArrangeOrder,
2263  ADataType,
2264  LDSTypeA,
2265  decltype(a_grid_desc_ak0_m_ak1),
2266  decltype(a_block_desc_ak0_m_ak1),
2267  ABlockTransferSrcAccessOrder,
2268  Sequence<0, 1, 2>,
2269  ABlockTransferSrcVectorDim,
2270  2,
2271  ABlockTransferSrcScalarPerVector,
2272  ABlockTransferDstScalarPerVector_AK1,
2273  1,
2274  1,
2275  AThreadTransferSrcResetCoordinateAfterRun,
2276  true,
2277  IndexType,
2278  1,
2279  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
2280  make_multi_index(0, 0, 0),
2281  a_element_op,
2282  a_block_desc_ak0_m_ak1,
2283  make_multi_index(0, 0, 0),
 2284  tensor_operation::element_wise::PassThrough{},
 2285  gather_offsets);
2286 
2287  // Thread-wise copy
2288  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2289  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2290  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2291  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2292  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2293  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2294 
2295  auto b_blockwise_copy =
2296  ThreadwiseTensorSliceTransfer_v2<BDataType,
2297  BDataType,
2298  decltype(b_grid_desc_bpreshuffled),
2299  decltype(b_block_desc_bk0_n_bk1),
2300  Sequence<Number<NXdlPerWave / NXdlPack>{},
2301  I1,
2302  Number<NXdlPack>{},
2303  Number<KRepeat>{},
2304  Number<BK1Value>{}>,
2305  Sequence<1, 2, 0, 3, 4>,
2306  4,
2307  BBlockTransferSrcScalarPerVector,
2308  BThreadTransferSrcResetCoordinateAfterRun,
2309  true>(
2310  b_grid_desc_bpreshuffled,
2311  make_multi_index(n_block_data_idx_on_grid,
2312  get_warp_local_1d_id() % NWave,
2313  0,
2314  0,
2315  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2316 
2317  // LDS allocation for A and B: be careful of alignment
2318  // Cast after lds
2319  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2320  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2321  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2322  static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2323  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2324 
2325  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2326  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
2327 
2328  // Blockwise GEMM pipeline
2329  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2330  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2331  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2332  decltype(c_thread_buf) c_thread_buf_up;
2333 
2334  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
2335  float,
2336  c_thread_buf.num_of_v_,
2337  c_thread_buf.s_per_v,
2338  true>
2339  c_thread_buf_fp32;
2340 
2341  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2342  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2343  KPerBlock);
2344 
2345  // a and b scale processing
2346  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2347  const auto waveId_m = wave_idx[I0];
2348  const auto waveId_n = wave_idx[I1];
2349 
2350  auto thread_offset_shuffled =
2351  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2352 
2353  auto a_thread_offset_m = waveId_m;
2354 
 2355  // get each thread's offset in the scale tensor
2356  const index_t token_scale_pos = block_m_id * MPerBlock;
2357  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2358  return;
2359 
2360  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2361  AScaleDataType,
2362  AScaleDataType,
2363  decltype(a_scale_grid_desc_am_ak),
2364  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2365  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2366  Sequence<0, 1, 2>, // DimAccessOrder
2367  2, // SrcVectorDim
2368  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2369  1, // SrcScalarStrideInVector
2370  true>(a_scale_grid_desc_am_ak,
2371  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2372  0,
2373  thread_offset_shuffled / scale_pack_size_a));
2374 
2375  // B scale load
2376  auto b_thread_offset_n = waveId_n;
2377 
2378  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2379  BScaleDataType,
2380  BScaleDataType,
2381  decltype(b_scale_grid_desc_bn_ak),
2382  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2383  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2384  Sequence<0, 1, 2>, // DimAccessOrder
2385  2, // SrcVectorDim
2386  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2387  1, // SrcScalarStrideInVector
2388  true>(b_scale_grid_desc_bn_ak,
2389  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2390  0,
2391  thread_offset_shuffled / scale_pack_size_b));
2392 
2393  if constexpr(IsInputGemm)
2394  {
2395  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2396  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2397  p_b_grid_up + expert_id * expert_stride / BPackedSize,
2398  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2399  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2400  BDataType,
2401  BDataType,
2402  decltype(b_grid_desc_bpreshuffled),
2403  decltype(b_block_desc_bk0_n_bk1),
2404  Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
2405  Sequence<1, 2, 0, 3>,
2406  3,
2407  BBlockTransferSrcScalarPerVector,
2408  BThreadTransferSrcResetCoordinateAfterRun,
2409  true>(b_grid_desc_bpreshuffled,
2410  make_multi_index(n_block_data_idx_on_grid,
2411  get_warp_local_1d_id() % NWave,
2412  0,
2413  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2414  const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
2415  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2416  p_b_scale_grid_up + expert_id * expert_scale_stride,
2417  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2418  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2419  BScaleDataType,
2420  BScaleDataType,
2421  decltype(b_scale_grid_desc_bn_ak),
2422  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2423  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2424  Sequence<0, 1, 2>, // DimAccessOrder
2425  2, // SrcVectorDim
2426  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
2427  1, // SrcScalarStrideInVector
2428  true>(
2429  b_scale_grid_desc_bn_ak,
2430  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2431  0,
2432  thread_offset_shuffled / scale_pack_size_b));
2433 
2434  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2435  a_grid_desc_ak0_m_ak1,
2436  a_block_desc_ak0_m_ak1,
2437  a_blockwise_copy,
2438  a_grid_buf,
2439  a_block_bufs,
2440  a_block_slice_copy_step,
2441  b_grid_desc_bpreshuffled,
2442  b_block_desc_bk0_n_bk1,
2443  b_blockwise_copy,
2444  b_blockwise_copy_up,
2445  b_grid_buf,
2446  b_grid_buf_up,
2447  b_block_bufs,
2448  b_block_slice_copy_step,
2449  c_thread_buf,
2450  c_thread_buf_up,
2451  a_scale_grid_desc_am_ak,
2452  a_scale_thread_copy,
2453  a_scale_grid_buf,
2454  b_scale_grid_desc_bn_ak,
2455  b_scale_thread_copy,
2456  b_scale_thread_copy_up,
2457  b_scale_grid_buf,
2458  b_scale_grid_buf_up,
2459  num_k_block_main_loop);
2460  }
2461  else
2462  {
2463  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2464  a_grid_desc_ak0_m_ak1,
2465  a_block_desc_ak0_m_ak1,
2466  a_blockwise_copy,
2467  a_grid_buf,
2468  a_block_bufs,
2469  a_block_slice_copy_step,
2470  b_grid_desc_bpreshuffled,
2471  b_block_desc_bk0_n_bk1,
2472  b_blockwise_copy,
2473  b_grid_buf,
2474  b_block_bufs,
2475  b_block_slice_copy_step,
2476  c_thread_buf,
2477  a_scale_grid_desc_am_ak,
2478  a_scale_thread_copy,
2479  a_scale_grid_buf,
2480  b_scale_grid_desc_bn_ak,
2481  b_scale_thread_copy,
2482  b_scale_grid_buf,
2483  num_k_block_main_loop);
2484  }
2485 
2486  // shuffle C and write out
2487  {
2488  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2489  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2490  "wrong!");
2491 
2492  // TODO: hacky, fix it!
2493  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2494  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2495 
2496  // TODO: hacky, fix it!
2497  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2498  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2499  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2500 
2501  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2502  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2503  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2504  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2505  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2506  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2507  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2508  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2509 
2510  // mul scales
2511 
2512  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
2513  static_assert(M4 == 4);
2514  const index_t m1 = get_warp_local_1d_id() / NWave;
2515  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
2516 
2517  vector_type<float, 4> topk_weights; // for gemm2 only
2518  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2519  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2520  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
2521  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2522  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2523  if constexpr(MulRoutedWeight)
2524  {
2525  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
2526  p_ds_grid[I2] + m_pos);
2527  }
2528  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
2529  constexpr index_t c_offset =
2530  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2531  make_tuple(m0 / MXdlPack,
2532  n0 / NXdlPack,
2533  m0 % MXdlPack,
2534  n0 % NXdlPack,
2535  m2 * M4 + m4));
2536  constexpr auto cidx = Number<c_offset>{};
2537 
2538  if constexpr(IsInputGemm) // gu fusion
2539  {
2540  if constexpr(ActivationOperation == Activation::silu_and_mul)
2541  {
2542  float gate = c_thread_buf[cidx];
2543  float up = c_thread_buf_up[cidx];
2544  if constexpr(MulRoutedWeight)
2545  {
2546  gate = gate * topk_weights.AsType<float>()[m4];
2547  up = up * topk_weights.AsType<float>()[m4];
2548  }
2549  tensor_operation::element_wise::Silu{}(gate, gate);
2550  c_thread_buf_fp32(cidx) = gate * up;
2551  }
2552  else if(ActivationOperation == Activation::gelu_and_mul)
2553  {
2554  float gate = c_thread_buf[cidx];
2555  float up = c_thread_buf_up[cidx];
2556  if constexpr(MulRoutedWeight)
2557  {
2558  gate = gate * topk_weights.AsType<float>()[m4];
2559  up = up * topk_weights.AsType<float>()[m4];
2560  }
2561  tensor_operation::element_wise::Gelu{}(gate, gate);
2562  c_thread_buf_fp32(cidx) = gate * up;
2563  }
2564  }
2565  else
2566  {
2567  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2568  if constexpr(MulRoutedWeight)
2569  {
2570  c_thread_buf_fp32(cidx) =
2571  topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
2572  }
2573  }
2574  });
2575  });
2576  });
2577  });
2578 
2579  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 2580  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
 2581 
2582  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2583  static_cast<CShuffleDataType*>(p_shared),
2584  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2585 
2586  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2587  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 2588  make_tuple(make_freeze_transform(I0),
 2589  make_unmerge_transform(make_tuple(
 2590  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per
2591  // shuffle
2592  M1, // M1 = MWave
2593  M2, // M2 * M3 * M4 = MPerXdl
2594  M3,
2595  M4)),
 2596  make_freeze_transform(I0),
 2597  make_unmerge_transform(make_tuple(
 2598  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per
2599  // shuffle
2600  N1, // N1 = NWave
2601  N2))), // N2 = NPerXdl
2602  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2603  make_tuple(
2604  Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
2605 
2606  // calculate origin of thread output tensor on global memory
2607  // blockwise GEMM c matrix starting index
2608  const auto c_thread_mtx_on_block =
2609  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2610 
2611  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2612  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2613 
2614  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
 2615  make_single_stage_tensor_adaptor(
 2616  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2617  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
2618  make_tuple(Sequence<0>{}));
2619 
2620  const auto m_thread_data_on_block_idx =
2621  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2622  make_multi_index(m_thread_data_on_block));
2623 
2624  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 2625  make_single_stage_tensor_adaptor(
 2626  make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
 2627  make_tuple(Sequence<0, 1, 2>{}),
2628  make_tuple(Sequence<0>{}));
2629 
2630  const auto n_thread_data_on_block_idx =
2631  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2632  make_multi_index(n_thread_data_on_block));
2633 
2634  // shuffle: threadwise copy C from VGPR to LDS
2635  auto c_thread_copy_vgpr_to_lds =
2636  ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
2637  CShuffleDataType,
2638  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2639  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 2640  tensor_operation::element_wise::PassThrough,
 2641  Sequence<CShuffleMXdlPerWavePerShuffle,
2642  CShuffleNXdlPerWavePerShuffle,
2643  I1,
2644  I1,
2645  M2,
2646  I1,
2647  M4,
2648  I1>,
2649  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2650  7,
2651  1,
 2652  InMemoryDataOperationEnum::Set,
 2653  1,
2654  true>{
2655  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2656  make_multi_index(0,
2657  0,
2658  m_thread_data_on_block_idx[I1],
2659  n_thread_data_on_block_idx[I1],
2660  m_thread_data_on_block_idx[I2],
2661  m_thread_data_on_block_idx[I3],
2662  m_thread_data_on_block_idx[I4],
2663  n_thread_data_on_block_idx[I2]),
 2664  tensor_operation::element_wise::PassThrough{}};
 2665 
2666  using EDataType = CDataType;
2667 
2668  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2669  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2670 
2671  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
 2672  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
 2673  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2674 
2675  const auto ds_grid_buf = generate_tuple(
2676  [&](auto i) {
2677  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2678  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2679  },
2680  Number<NumDTensor>{});
2681 
2682  // tuple of reference to C/Ds tensor descriptors
2683  const auto c_ds_desc_refs = concat_tuple_of_reference(
2684  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2685  generate_tie([&](auto i) -> const auto& // return type should be reference
2686  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2687  Number<NumDTensor>{}));
2688 
2689  // tuple of reference to C/Ds tensor descriptors
2690  const auto c_ds_buf_refs = concat_tuple_of_reference(
2691  tie(c_shuffle_block_buf),
2692  generate_tie([&](auto i) -> const auto& // return type should be reference
2693  { return ds_grid_buf[i]; },
2694  Number<NumDTensor>{}));
2695 
2696  // tuple of starting index of C/Ds blockwise copy
2697  const auto idx_c_ds_block_begin =
 2698  container_concat(make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
 2699  generate_tuple(
 2700  [&](auto) {
2701  return make_multi_index(block_m_id, 0, block_n_id, 0);
2702  // return make_multi_index(block_work_idx[I0], 0,
2703  // block_work_idx[I1], 0);
2704  },
2705  Number<NumDTensor>{}));
2706 
2707  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2708  c_grid_desc_mblock_mperblock_nblock_nperblock;
2709 
2710  using CDEBlockTransferCluster =
2711  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2712  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2713  constexpr index_t scatter_weight_idx = 3; // hack fix felix
2714  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
 2715  ThisThreadBlock,
 2716  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2717  Tuple<EDataType>,
2718  decltype(c_ds_desc_refs),
2719  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2720  CElementwiseOperation,
2721  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2722  // Sequence support
 2723  // arbitrary type
2724  Sequence<1,
2725  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2726  1,
2727  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2728  CDEBlockTransferCluster,
2729  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2730  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2731  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2732  3, // index_t SrcVectorDim,
2733  3, // index_t DstVectorDim,
2734  CDEShuffleBlockTransferScalarPerVectors,
 2735  CShuffleBlockTransferScalarPerVector_NPerBlock,
 2736  sequence_merge_t<
 2737  Sequence<true>,
 2738  uniform_sequence_gen_t<NumDTensor,
 2739  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2740  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2741  IndexType,
2742  1, // ScatterDim
2743  true, // OutputScatter: false, only use scatter weights
2744  scatter_weight_idx // ScatterWeightIdx: ascale
2745  >{c_ds_desc_refs,
2746  idx_c_ds_block_begin,
2747  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2748  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2749  c_element_op};
2750 
2751  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2752  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2753  constexpr auto sfc_c_vgpr =
2754  SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
2755  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2756  Sequence<CShuffleMXdlPerWavePerShuffle,
2757  CShuffleNXdlPerWavePerShuffle,
2758  1,
2759  1,
2760  M2,
2761  1,
2762  M4,
2763  1>>{};
2764 
2765  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2766 
2767  // space filling curve for shuffled blockwise C/D/E
2768  constexpr auto sfc_cde_block =
2769  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2770  Sequence<0, 2, 1, 3>,
2771  Sequence<1,
2772  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2773  1,
2774  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2775 
2776  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2777  constexpr auto EMThreads =
2778  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2779  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2780  constexpr auto ENThreads =
2781  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2782  static_for<0, num_access, 1>{}([&](auto access_id) {
2783  // make sure it's safe to write to LDS
2784  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2785 
2786  auto dstidx = sfc_cde_block.GetIndex(access_id);
2787  const index_t c_token_pos =
2788  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2789  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2790  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2791  IndexType token_offset = fused_token & 0xffffff;
2792  if constexpr(IsInputGemm)
2793  {
2794  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2795  }
2796  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2797  });
2798 
2799  block_sync_lds();
2800 
2801  // each thread write its data from VGPR to LDS
2802  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2803  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2804  c_thread_buf_fp32,
2805  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2806  c_shuffle_block_buf);
2807 
2808  // make sure it's safe to read from LDS
2809  block_sync_lds();
2810 
2811  // each block copy its data from LDS to global
2812  cde_block_copy_lds_and_global.Run(
2813  c_ds_desc_refs,
2814  c_ds_buf_refs,
2815  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2816  tie(c_grid_buf),
2817  scatter_offsets);
2818 
2819  if constexpr(access_id < num_access - 1)
2820  {
2821  constexpr auto cde_lds_and_global_step =
2822  sfc_cde_block.GetForwardStep(access_id);
2823 
2824  // move on Ds
2825  static_for<0, NumDTensor, 1>{}([&](auto i) {
2826  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2827  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2828  });
2829 
2830  // move on E
2831  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2832  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2833  I0,
2834  cde_lds_and_global_step);
2835  }
2836  });
2837  }
2838  }
2839 #endif
2840 };
2841 
2842 } // namespace ck