Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp
gridwise_moe_gemm_blockscale.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
19 #define DEBUG_LOG 0
20 
21 namespace ck {
22 
23 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
24 // pipelines into the same kernel function. Blockers:
25 // 1. Two separate declarations of the __shared__ pointer are the key to making sure data accesses
26 // operate on two LDS chunks.
27 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may not
28 // use the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
29 
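// [Editorial sketch, not part of the original header] The blockers above, shown as a minimal
// kernel skeleton. All names here (sketch_kernel_2lds, run_ab_pipeline, write_c_through_lds,
// LDS_BYTES) are hypothetical placeholders, not CK APIs.
#if 0
__global__ void sketch_kernel_2lds()
{
    // (1) Two separate __shared__ declarations are needed so the pipeline can steer loads and
    //     stores to two distinct LDS chunks (ping-pong buffering across K iterations).
    __shared__ char lds_ping[LDS_BYTES];
    __shared__ char lds_pong[LDS_BYTES];
    run_ab_pipeline(lds_ping, lds_pong);

    // (2) __shared__ storage declared inside a device helper would stay allocated until the
    //     shader ends, so the C-shuffle stage could not reuse the LDS that held the A/B tiles;
    //     declaring the buffers at kernel scope, as above, lets C shuffle alias the same bytes.
    write_c_through_lds(lds_ping);
}
#endif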
30 enum Activation
31 {
32  gelu_and_mul = 0,
33  silu_and_mul = 1
34 };
35 
36 template <typename GridwiseGemm,
37  bool HasMainKBlockLoop,
38  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
39  index_t MinimumOccupancy = 1,
40  TailNumber TailNum = TailNumber::Even>
41 __global__ void
42 #if CK_USE_LAUNCH_BOUNDS
43 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
44 #endif
45  // __attribute__((amdgpu_waves_per_eu(1, 1)))
46  kernel_moe_gemm(typename GridwiseGemm::Argument karg)
47 {
48 #if defined(__gfx9__)
49  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
50 
51  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
52 
53  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
54  karg.p_sorted_token_ids,
55  karg.p_sorted_expert_ids,
56  karg.p_max_token_id,
57  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
58  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
59  karg.p_ds_grid,
60  karg.p_c_grid,
61  karg.p_a_scale_grid,
62  karg.p_b_scale_grid,
63  p_shared,
64  karg,
65  karg.a_element_op,
66  karg.b_element_op,
67  karg.c_element_op);
68 #else
69  ignore = karg;
70 #endif // end of if (defined(__gfx9__))
71 }
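// [Editorial sketch, not part of the original header] A minimal host-side launch of
// kernel_moe_gemm, assuming a fully specialized GridwiseGemm type named Gridwise, an
// already-populated Gridwise::Argument named karg, and host-side BlockSize / stream values.
// These names are assumptions for illustration only.
#if 0
const auto [gridx, gridy, gridz] = Gridwise::CalculateGridSize(karg.M, karg.N);
// Split-K batches go in the z grid dimension (consumed by SplitKBatchOffset via blockIdx.z);
// with karg.KBatch > 1 the C write would normally use InMemoryDataOperationEnum::AtomicAdd.
kernel_moe_gemm<Gridwise,
                /*HasMainKBlockLoop=*/true,
                InMemoryDataOperationEnum::Set,
                /*MinimumOccupancy=*/1,
                TailNumber::Even>
    <<<dim3(gridx, gridy, karg.KBatch), dim3(BlockSize), 0, stream>>>(karg);
#endif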
72 
73 template <typename GridwiseGemm,
74  bool HasMainKBlockLoop,
75  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
76  index_t MinimumOccupancy = 1,
77  TailNumber TailNum = TailNumber::Even>
78 __global__ void
79 #if CK_USE_LAUNCH_BOUNDS
80 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
81 #endif
82  // __attribute__((amdgpu_waves_per_eu(1, 1)))
83  kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
84 {
85 #if defined(__gfx9__)
86  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
87  __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
88 
89  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
90 
91  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
92  karg.p_sorted_token_ids,
93  karg.p_sorted_expert_ids,
94  karg.p_max_token_id,
95  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
96  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
97  karg.p_ds_grid,
98  karg.p_c_grid,
99  karg.p_a_scale_grid,
100  karg.p_b_scale_grid,
101  p_shared,
102  p_shared1,
103  karg,
104  karg.a_element_op,
105  karg.b_element_op,
106  karg.c_element_op);
107 #else
108  ignore = karg;
109 #endif // end of if (defined(__gfx9__))
110 }
111 
112 template <typename ALayout,
113  typename BLayout,
114  typename DsLayout,
115  typename CLayout,
116  typename ADataType,
117  typename BDataType,
118  typename AccDataType,
119  typename CShuffleDataType,
120  typename DsDataType,
121  typename CDataType,
122  typename AElementwiseOperation,
123  typename BElementwiseOperation,
124  typename CElementwiseOperation,
126  index_t BlockSize,
127  index_t ScaleBlockM,
128  index_t ScaleBlockN,
129  index_t ScaleBlockK,
130  index_t MPerBlock,
131  index_t NPerBlock,
132  index_t KPerBlock,
133  index_t AK1Value,
134  index_t BK1Value,
135  index_t MPerXdl,
136  index_t NPerXdl,
137  index_t MXdlPerWave,
138  index_t NXdlPerWave,
139  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
140  typename ABlockTransferThreadClusterArrangeOrder,
141  typename ABlockTransferSrcAccessOrder,
142  index_t ABlockTransferSrcVectorDim,
143  index_t ABlockTransferSrcScalarPerVector,
144  index_t ABlockTransferDstScalarPerVector_AK1,
145  bool AThreadTransferSrcResetCoordinateAfterRun,
146  index_t ABlockLdsExtraM,
147  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
148  typename BBlockTransferThreadClusterArrangeOrder,
149  typename BBlockTransferSrcAccessOrder,
150  index_t BBlockTransferSrcVectorDim,
151  index_t BBlockTransferSrcScalarPerVector,
152  index_t BBlockTransferDstScalarPerVector_BK1,
153  bool BThreadTransferSrcResetCoordinateAfterRun,
154  index_t BBlockLdsExtraN,
155  index_t CShuffleMXdlPerWavePerShuffle,
156  index_t CShuffleNXdlPerWavePerShuffle,
157  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
158  typename CDEShuffleBlockTransferScalarPerVectors,
161  index_t ActivationOperation = 0,
162  bool NSwizzle = false,
163  bool IsInputGemm = true,
164  bool MulRoutedWeight = true,
165  typename IndexType = index_t,
166  typename ComputeTypeA = CDataType,
167  typename ComputeTypeB = ComputeTypeA,
168  typename LDSTypeA = ADataType,
169  typename LDSTypeB = BDataType>
171 {
172  using AScaleType = float;
173  using BScaleType = float;
174 
175  static constexpr auto I0 = Number<0>{};
176  static constexpr auto I1 = Number<1>{};
177  static constexpr auto I2 = Number<2>{};
178  static constexpr auto I3 = Number<3>{};
179  static constexpr auto I4 = Number<4>{};
180  static constexpr auto I5 = Number<5>{};
181  static constexpr auto I6 = Number<6>{};
182  static constexpr auto I7 = Number<7>{};
183 
184  static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
185  CDEShuffleBlockTransferScalarPerVectors{}[I0];
186  // K1 should be Number<...>
187  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
188  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
189  static constexpr auto AK1Number = Number<AK1Value>{};
190  static constexpr auto BK1Number = Number<BK1Value>{};
191  static constexpr auto BlockSizeNumber = Number<BlockSize>{};
192 
193  static constexpr index_t NumDTensor = DsDataType::Size();
194 
196  static constexpr index_t KPack =
198  static constexpr index_t KGroup = []() {
200  // On gfx950 we have an mfma that requires 32 f8 elements as input,
201  // split into 2 groups of 16 f8 elements.
202  // The 2 groups are not contiguous in the B preshuffled layout,
203  // and we do not want them to be contiguous there,
204  // because a memory instruction can only read 16 f8 elements at a time.
205  return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
206  else
207  return 1;
208  }();
209  static constexpr index_t KLane =
211  static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
212  static constexpr index_t NLane = NPerXdl;
213  static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
214  // static constexpr index_t NumTokens = 1;
215  static constexpr index_t SortedTileSize = MPerBlock;
216 
217  static constexpr auto MakeDsGridPointer()
218  {
219  return generate_tuple(
220  [&](auto i) {
221  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
222 
223  return static_cast<const DDataType*>(nullptr);
224  },
226  }
227 
228  using DsGridPointer = decltype(MakeDsGridPointer());
229 
231 
232  static constexpr index_t APackedSize = []() {
234  return 2;
235  else
236  return 1;
237  }();
238 
239  static constexpr index_t BPackedSize = []() {
241  return 2;
242  else
243  return 1;
244  }();
245 
246  __host__ static auto CalculateGridSize(index_t M, index_t N)
247  {
248  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
249  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
250  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
251  const index_t gridy = NSwizzle ? 1 : mblock;
252  return std::make_tuple(gridx, gridy, 1);
253  }
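// [Editorial worked example, not part of the original header] CalculateGridSize with assumed
// tile sizes MPerBlock = 128, NPerBlock = 128 and a problem of M = 1000, N = 512:
//   mblock = ceil(1000 / 128) = 8, nblock = ceil(512 / 128) = 4
//   NSwizzle == false : grid = (gridx = nblock = 4, gridy = mblock = 8, 1)
//   NSwizzle == true  : grid = (gridx = nblock * mblock = 32, gridy = 1, 1)
static_assert((1000 + 128 - 1) / 128 == 8, "mblock for the example above");
static_assert((512 + 128 - 1) / 128 == 4, "nblock for the example above");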
254 
255  __host__ __device__ static auto CalculateMPadded(index_t M)
256  {
257  return math::integer_least_multiple(M, MPerBlock);
258  }
259 
260  __host__ __device__ static auto CalculateNPadded(index_t N)
261  {
262  return math::integer_least_multiple(N, NPerBlock);
263  }
264 
265  __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
266  {
267  return math::integer_divide_ceil(N, NLane);
268  }
269  __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
270  {
272  }
273 
274  __host__ __device__ static auto CalculateKPadded(index_t K)
275  {
276  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
277  }
278 
279  __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
280  {
281  auto K_t = K_Batch * KPerBlock;
282  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
283  }
284 
285  __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
286  {
287  auto K_t = K_Batch * KPerBlock;
288  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
289  }
290 
291  __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
292  {
293  auto K_t = K_Batch * KPerBlock;
294  return (K + K_t - 1) / K_t * KPerBlock;
295  }
296 
297  __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
298  {
299  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
300  auto K_t = K_Batch * KReadVec;
301  return (K + K_t - 1) / K_t * KReadVec;
302  }
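// [Editorial worked example, not part of the original header] CalculateKRead with assumed
// AK1Value = BK1Value = 8 (so KReadVec = lcm(8, 8) = 8), K = 1000 and K_Batch = 4:
//   K_t = 4 * 8 = 32, KRead = ceil(1000 / 32) * 8 = 32 * 8 = 256
// i.e. the first K_Batch - 1 splits each read 256 K-elements and the last split reads the
// remaining 1000 - 3 * 256 = 232 (see SplitKBatchOffset below).
static_assert((1000 + 32 - 1) / 32 * 8 == 256, "KRead for the example above");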
303 
304  __host__ __device__ static auto CalculateMBlock(index_t M)
305  {
306  return math::integer_divide_ceil(M, MPerBlock);
307  }
308 
309  __host__ __device__ static auto CalculateNBlock(index_t N)
310  {
311  return math::integer_divide_ceil(N, NPerBlock);
312  }
313 
314  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
315  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
316  {
317  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
318  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
319 
321  TileDesc_K0_MN_K1{},
327  }
328 
329  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
330  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
331  {
332  const auto a_grid_desc_mraw_kraw = [&]() {
333  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
334  {
335  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
336  }
337  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
338  {
339  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
340  }
341  }();
342 
344 
345  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
346  GemmSpec == GemmSpecialization::MNKPadding)
347  {
348  // pad both M and K
349  const auto a_grid_desc_m_k =
350  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
352  make_right_pad_transform(K, KPad - K)),
355 
356  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
357  a_grid_desc_m_k,
362 
363  return a_grid_desc_ak0_m_ak1;
364  }
365  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
366  GemmSpec == GemmSpecialization::MNPadding)
367  {
368  // pad M, but not K
369  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
370  a_grid_desc_mraw_kraw,
372  make_right_pad_transform(M, MPad - M)),
375 
376  return a_grid_desc_ak0_m_ak1;
377  }
378  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
379  GemmSpec == GemmSpecialization::NKPadding)
380  {
381  // pad K, but not M
382  const auto a_grid_desc_m_k = transform_tensor_descriptor(
383  a_grid_desc_mraw_kraw,
387 
388  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
389  a_grid_desc_m_k,
394 
395  return a_grid_desc_ak0_m_ak1;
396  }
397  else
398  {
399  // not pad M or K
400  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
401  a_grid_desc_mraw_kraw,
406 
407  return a_grid_desc_ak0_m_ak1;
408  }
409  }
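// [Editorial worked example, not part of the original header] Shape of the returned
// (AK0, M, AK1) descriptor, assuming KPerBlock = 64, AK1Value = 8, K = 256, KBatch = 1 and no
// padding: AK0 = CalculateAK0Padded(256, 1) = ceil(256 / 64) * (64 / 8) = 32, so the K axis of
// A is viewed as AK0 * AK1 = 32 * 8 = 256 elements, i.e. the original K extent.
static_assert((256 + 64 - 1) / 64 * (64 / 8) == 32, "AK0 for the example above");
static_assert(32 * 8 == 256, "AK0 * AK1 recovers K in the example above");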
410 
411  __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
412  {
413  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
414  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
415  constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack / KGroup>{};
417  make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
418  make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
419  }
420 
421  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
422  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
423  {
424  const auto b_grid_desc_nraw_kraw = [&]() {
426  {
427  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
428  }
430  {
431  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
432  }
433  }();
434 
436 
437  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
438  GemmSpec != GemmSpecialization::Default),
439  "pk_i4_t does not support padding");
440 
441  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
442  GemmSpec == GemmSpecialization::MNKPadding)
443  {
444  // pad both N and K
445  const auto b_grid_desc_n_k =
446  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
448  make_right_pad_transform(K, KPad - K)),
451 
452  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
453  b_grid_desc_n_k,
458 
459  return b_grid_desc_bk0_n_bk1;
460  }
461  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
462  GemmSpec == GemmSpecialization::MNPadding)
463  {
464  // pad N, but not K
465  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
466  b_grid_desc_nraw_kraw,
468  make_right_pad_transform(N, NPad - N)),
471 
472  return b_grid_desc_bk0_n_bk1;
473  }
474  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
475  GemmSpec == GemmSpecialization::MKPadding)
476  {
477  // pad K, but not N
478  const auto b_grid_desc_n_k = transform_tensor_descriptor(
479  b_grid_desc_nraw_kraw,
483 
484  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
485  b_grid_desc_n_k,
490 
491  return b_grid_desc_bk0_n_bk1;
492  }
493  else
494  {
495  // not pad N or K
496  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
497  b_grid_desc_nraw_kraw,
502 
503  return b_grid_desc_bk0_n_bk1;
504  }
505  }
506 
507  template <typename ABlockDesc_AK0_M_AK1>
508  __host__ __device__ static constexpr auto
509  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
510  {
511  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
512 
513  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
514  }
515 
516  template <typename BBlockDesc_BK0_N_BK1>
517  __host__ __device__ static constexpr auto
518  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
519  {
520  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
521  }
522 
523  template <typename ELayout>
524  __host__ __device__ static auto MakeCGridDescriptor_M_N(
525  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
526  {
527  const auto c_grid_desc_mraw_nraw = [&]() {
529  {
530  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
531  }
533  {
534  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
535  }
536  }();
537 
538  // pad M and N
539  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
541  make_right_pad_transform(N, NPad - N)),
544  }
545 
546  template <typename DLayout>
547  __host__ __device__ static auto
549  {
550  const auto c_grid_desc_mraw_nraw = [&]() {
552  {
553  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
554  }
556  {
557  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
558  }
559  }();
560 
561  // pad M and N
562  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
564  make_right_pad_transform(N, NPad - N)),
567  }
568 
569  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
570  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
571  {
572  return generate_tuple(
573  [&](auto i) {
574  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
575  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
576  },
578  }
579 
580  template <typename DsGridDesc>
582  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
583  {
584  return generate_tuple(
585  [&](auto i) {
587  ds_grid_desc_m_n[i], MBlock, NBlock);
588  },
590  }
591 
592  using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
593 
594  struct Problem
595  {
596  __host__ __device__ Problem(index_t NumTokens_,
597  index_t TopK_,
598  index_t M_,
599  index_t N_,
600  index_t K_,
601  index_t StrideA_,
602  index_t StrideB_,
603  std::array<index_t, NumDTensor> StrideDs_,
604  index_t StrideC_,
605  index_t KBatch_)
606  : NumTokens{NumTokens_},
607  TopK{TopK_},
608  M{M_},
609  N{N_},
610  K{K_},
611  StrideA{StrideA_},
612  StrideB{StrideB_},
613  StrideDs{StrideDs_},
614  StrideC{StrideC_},
615  KBatch{KBatch_},
618  KRead{CalculateKRead(K_, KBatch_)},
619  KPadded{CalculateKPadded(K_, KBatch_)},
620  AK0{CalculateAK0Padded(K_, KBatch_)},
621  BK0{CalculateBK0Padded(K_, KBatch_)},
622  MBlock{CalculateMBlock(M_)},
624  {
625  }
626 
627  __host__ void Print() const
628  {
629  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
630  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
631  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
632  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
633  << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
634  << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
635  << "NBlock: " << NBlock << "}" << std::endl;
636  }
637 
645  std::array<index_t, NumDTensor> StrideDs;
656  };
657 
658  // Argument
660  {
661  __host__ Argument(const index_t* p_sorted_token_ids_,
662  const index_t* p_sorted_expert_ids_,
663  const index_t* p_max_token_id_,
664  const ADataType* p_a_grid_,
665  const BDataType* p_b_grid_,
666  std::array<const void*, NumDTensor> p_ds_grid_,
667  CDataType* p_c_grid_,
668  index_t NumTokens_,
669  index_t TopK_,
670  index_t M_,
671  index_t N_,
672  index_t K_,
673  index_t StrideA_,
674  index_t StrideB_,
675  std::array<index_t, NumDTensor> StrideDs_,
676  index_t StrideC_,
677  const AScaleType* p_a_scale_grid_,
678  const BScaleType* p_b_scale_grid_,
679  index_t k_batch_,
680  AElementwiseOperation a_element_op_,
681  BElementwiseOperation b_element_op_,
682  CElementwiseOperation c_element_op_)
683  : Problem{NumTokens_,
684  TopK_,
685  M_,
686  N_,
687  K_,
688  StrideA_,
689  StrideB_,
690  StrideDs_,
691  StrideC_,
692  k_batch_},
693  p_sorted_token_ids{p_sorted_token_ids_},
694  p_sorted_expert_ids{p_sorted_expert_ids_},
695  p_max_token_id{p_max_token_id_},
696  p_a_grid{p_a_grid_},
697  p_b_grid{p_b_grid_},
698  p_ds_grid{},
699  p_c_grid{p_c_grid_},
700  p_a_scale_grid{p_a_scale_grid_},
701  p_b_scale_grid{p_b_scale_grid_},
702  a_element_op{a_element_op_},
703  b_element_op{b_element_op_},
704  c_element_op{c_element_op_}
705  {
706 
707  // populate pointer, desc for Ds
708  static_for<0, NumDTensor, 1>{}([&](auto i) {
709  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
710 
711  // D pointer
712  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
713  });
714  }
715 
719  const ADataType* p_a_grid;
720  const BDataType* p_b_grid;
722  CDataType* p_c_grid;
723 
726 
727  const AElementwiseOperation a_element_op;
728  const BElementwiseOperation b_element_op;
729  const CElementwiseOperation c_element_op;
730  };
731 
733  {
734  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
735  {
736  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
737  {
738  a_k_split_offset = k_id * karg.KRead / APackedSize;
739  }
740  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
741  {
742  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
743  }
744 
745  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
746  {
747  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
748  }
749  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
750  {
751  // KPack * NLane * KLane * K0 * N0
752  b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
753  }
754 
755  if(k_id < karg.KBatch - 1)
756  {
757  karg.K = karg.KRead;
758  }
759  else
760  {
761  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
762  }
763  }
764 
767  };
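// [Editorial worked example, not part of the original header] Continuing the split-K example
// from CalculateKRead (K = 1000, KBatch = 4, KRead = 256), SplitKBatchOffset gives
//   k_id 0..2 : karg.K is set to KRead = 256
//   k_id 3    : karg.K is set to 1000 - 256 * (4 - 1) = 232
// and, for a row-major A with APackedSize = 1, the element offset into A advances by
// k_id * 256 along K.
static_assert(1000 - 256 * (4 - 1) == 232, "last split's K for the example above");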
768 
769  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
770  {
771  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
772  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
773  // A matrix in LDS memory, dst of blockwise copy
774  if constexpr(ABlockLdsExtraM)
775  {
779  }
780  // The XOR tensor transformation requires additional, unnecessary VGPR usage and can cause
781  // register spills in some cases.
783  {
784  constexpr auto a_lds_block_desc =
787 
788  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
789  a_lds_block_desc,
795 
796  return a_lds_block_desc_permuted;
797  }
798  else // ColumnMajor A
799  {
800  // The kfold and mpair dimensions are not always required.
801  // More dimensions in merge_transform increase the difficulty of generating immediate-argument
802  // offsets for the compiler.
803  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
804  constexpr auto M1 = MPerBlock / M0;
805 
806  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
807  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
808  constexpr auto KThreadRead = WaveSize / MPerXdl;
809  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
810 
811  constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
812  ? 1
813  : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
814  constexpr auto KThreadReadPerm =
815  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
816  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
817  : KThreadRead;
818 
819  // 1<=mpair<=n0
820  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
821  ? 1
822  : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
823  ? M0
824  : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
825 
826  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
830  Number<kfold * M0 / mpair>{},
831  Number<mpair>{},
832  AK1Number));
833 
834  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
835  a_lds_block_desc,
836  make_tuple(
840  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
843  make_tuple(
845  make_tuple(
847 
848  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
849  a_lds_block_desc_permuted,
850  make_tuple(
858  Sequence<1>{},
859  Sequence<2>{},
860  Sequence<3>{},
861  Sequence<4>{},
862  Sequence<5>{}),
864  Sequence<2>{},
865  Sequence<0, 3>{},
866  Sequence<4, 5>{},
867  Sequence<6>{},
868  Sequence<7>{}));
869 
870  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
871  a_lds_block_desc_unmerged,
874  Number<KThreadWrite / kfold / KThreadReadPerm>{},
875  Number<kfold>{},
882 
883  return a_lds_block_desc_ak0_m_ak1;
884  }
885  }
886 
887  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
888  {
889  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
892  }
893 
895  {
896  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
897 
898  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
900  make_tuple(I1,
902  I1,
904 
905  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
906  }
907 
910  BlkGemmPipelineVer,
911  BlkGemmPipeSched,
912  BlockSize,
913  ADataType,
914  BDataType,
915  ComputeTypeA,
916  AccDataType,
923  ABlockTransferSrcScalarPerVector,
924  BBlockTransferSrcScalarPerVector,
925  MPerBlock,
926  NPerBlock,
927  KPerBlock,
928  ScaleBlockM,
929  ScaleBlockN,
930  ScaleBlockK,
931  MPerXdl,
932  NPerXdl,
933  MXdlPerWave,
934  NXdlPerWave,
935  KPack,
936  IsInputGemm>())>;
937 
938  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
939  {
940  // LDS allocation for A and B: be careful of alignment
941  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
942  // lds max alignment
943  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
944 
945  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
946  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
947 
948  // LDS allocation for C shuffle in LDS
949  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
951 
952  constexpr auto c_block_size =
953  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
954 
955  return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize,
956  c_block_size * sizeof(CShuffleDataType));
957  }
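// [Editorial worked example, not part of the original header] GetSharedMemoryNumberOfByte with
// assumed parameters KPerBlock = 64, AK1Value = 8 (AK0Number = 8), MPerBlock = 128,
// LDSTypeA = half_t (2 bytes), APackedSize = 1, no ABlockLdsExtraM padding, and a C-shuffle tile
// of (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) x (CShuffleNXdlPerWavePerShuffle * NWave
// * NPerXdl) = 64 x 64 elements of CShuffleDataType = float (4 bytes):
//   A tile    : 8 * 128 * 8 elements * 2 bytes = 16384 bytes
//   C shuffle : 64 * 64 elements * 4 bytes     = 16384 bytes
//   result    : max(16384, 16384) = 16384 bytes of LDS
static_assert(8 * 128 * 8 * 2 == 16384, "A-tile bytes for the example above");
static_assert(64 * 64 * 4 == 16384, "C-shuffle bytes for the example above");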
958 
959  // The block_id to matrix tile index (m0, n0) mapping is controlled by {M01, N01}
960  __host__ static constexpr bool CheckValidity(const Argument& karg)
961  {
962  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
963  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
964  "Invalid tuning param!");
965 
966  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
971  {
972  if(!(karg.M % MPerBlock == 0))
973  {
974 #if DEBUG_LOG
975  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
976  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
977  << std::endl;
978 
979 #endif // DEBUG_LOG
980  return false;
981  }
982  }
983 
984  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
989  {
990  if(!(karg.N % NPerBlock == 0))
991  {
992 #if DEBUG_LOG
993  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
994  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
995  << std::endl;
996 
997 #endif // DEBUG_LOG
998  return false;
999  }
1000  }
1001 
1002  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1006  {
1007 
1008  auto K_t = karg.KBatch * KPerBlock;
1009  if(!(karg.K % K_t == 0))
1010  {
1011 #if DEBUG_LOG
1012  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1013  << karg.K << " " << __FILE__ << ":" << __LINE__
1014  << ", in function: " << __func__ << std::endl;
1015 
1016 #endif // DEBUG_LOG
1017  return false;
1018  }
1019  }
1020  else
1021  {
1022  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1023  auto K_t = karg.KBatch * KReadVec;
1024  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1025  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1026  {
1027  return false;
1028  }
1029  }
1030 
1032  {
1033  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1034  {
1035 #if DEBUG_LOG
1036  std::cout << "Arg K (" << karg.K
1037  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1038  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1039  << __LINE__ << ", in function: " << __func__ << std::endl;
1040 
1041 #endif // DEBUG_LOG
1042  return false;
1043  }
1044  }
1045  else
1046  {
1047  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1048  {
1049 #if DEBUG_LOG
1050  std::cout << "Arg M (" << karg.M
1051  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1052  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1053  << __LINE__ << ", in function: " << __func__ << std::endl;
1054 
1055 #endif // DEBUG_LOG
1056  return false;
1057  }
1058  }
1059 
1061  {
1062  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1063  {
1064 #if DEBUG_LOG
1065  std::cout << "Arg N (" << karg.N
1066  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1067  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1068  << __LINE__ << ", in function: " << __func__ << std::endl;
1069 
1070 #endif // DEBUG_LOG
1071  return false;
1072  }
1073  }
1074  else
1075  {
1076  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1077  {
1078 #if DEBUG_LOG
1079  std::cout << "Arg K (" << karg.K
1080  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1081  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1082  << __LINE__ << ", in function: " << __func__ << std::endl;
1083 
1084 #endif // DEBUG_LOG
1085  return false;
1086  }
1087  }
1088 
1090  {
1092  {
1093 #if DEBUG_LOG
1094  std::cout << "Arg N (" << karg.N
1095  << ") value is not a multiple of "
1096  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1097  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1098  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1099 
1100 #endif // DEBUG_LOG
1101  return false;
1102  }
1103  }
1104  else
1105  {
1107  {
1108 #if DEBUG_LOG
1109  std::cout << "Arg M (" << karg.M
1110  << ") value is not a multiple of "
1111  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1112  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1113  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1114 
1115 #endif // DEBUG_LOG
1116  return false;
1117  }
1118  }
1119 
1120  // check gridwise gemm pipeline
1121 #if 0
1122  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1123 
1124  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1125  {
1126  return false;
1127  }
1128 #endif
1129  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1130  return true;
1131  }
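// [Editorial sketch, not part of the original header] Typical host-side use of CheckValidity
// before launching, assuming a fully specialized GridwiseGemm type named Gridwise and an
// Argument named karg (hypothetical names):
#if 0
if(!Gridwise::CheckValidity(karg))
{
    // tuning parameters or padding specialization do not match this problem size
    return false; // skip this kernel instance
}
#endif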
1132 
1133  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1134  {
1135  const index_t num_loop = K / KPerBlock;
1136 
1137  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1138  }
1139 
1140  __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1141  {
1142  const index_t num_loop = K / KPerBlock;
1143 
1144  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1145  }
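// [Editorial sketch, not part of the original header] The two helpers above are typically used
// on the host to pick the non-type template parameters of kernel_moe_gemm. A hedged sketch,
// assuming a specialized type Gridwise and a launch helper named launch (both hypothetical):
#if 0
const bool has_main_loop = Gridwise::CalculateHasMainKBlockLoop(karg.K);
const TailNumber tail    = Gridwise::CalculateKBlockLoopTailNum(karg.K);
if(has_main_loop && tail == TailNumber::Even)
    launch(kernel_moe_gemm<Gridwise, true, InMemoryDataOperationEnum::Set, 1, TailNumber::Even>, karg);
else if(has_main_loop && tail == TailNumber::Odd)
    launch(kernel_moe_gemm<Gridwise, true, InMemoryDataOperationEnum::Set, 1, TailNumber::Odd>, karg);
// ... remaining (has_main_loop, tail) combinations are dispatched the same way
#endif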
1146 
1147  template <typename CGridDesc>
1149  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1150  {
1151  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1152  c_grid_desc_m_n,
1157 
1158  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1159  }
1160 
1161  // return block_id to C matrix tile idx (m0, n0) mapping
1162  // if arch = gfx942
1163  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1164  // NPerBlock>;
1165 
1166  template <bool HasMainKBlockLoop,
1167  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1168  TailNumber TailNum = TailNumber::Odd>
1169  __device__ static void Run(const index_t* p_sorted_token_ids,
1170  const index_t* p_sorted_expert_ids,
1171  const index_t* p_max_token_id,
1172  const ADataType* p_a_grid,
1173  const BDataType* p_b_grid,
1174  DsGridPointer& p_ds_grid,
1175  CDataType* p_c_grid,
1176  const AScaleType* p_a_scale_grid,
1177  const BScaleType* p_b_scale_grid,
1178  void* p_shared,
1179  const Problem& problem,
1180  AElementwiseOperation a_element_op,
1181  BElementwiseOperation b_element_op,
1182  CElementwiseOperation c_element_op)
1183  {
1184  ignore = b_element_op;
1185  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1186  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1187  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1188  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1189  problem.MPadded,
1190  problem.K,
1191  problem.KPadded,
1192  problem.StrideA,
1193  problem.AK0);
1194  const auto b_grid_desc_bpreshuffled =
1195  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1196  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1197  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1198  problem.MPadded,
1199  problem.N,
1200  problem.NPadded,
1201  problem.StrideC);
1202 
1203  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1204  make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens
1205  : problem.NumTokens * problem.TopK,
1206  ScaleBlockM),
1207  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1208  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1209  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1210  make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
1211  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1212  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1213 
1214  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1216  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1217  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1218  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1219  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1220  if(expert_block_id * MPerBlock >= max_token_id)
1221  return;
1222  const index_t expert_id =
1223  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1224  const auto block_mn = [&]() -> std::pair<int, int> {
1225  if constexpr(NSwizzle)
1226  {
1227  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1228  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1229  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1230  const index_t expert_swizzle =
1231  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1232  const index_t bid_new = blockIdx.x - prefix_block;
1233  const index_t nid = __builtin_amdgcn_readfirstlane(
1234  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1235  const index_t mid =
1236  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1237  return {nid, mid};
1238  }
1239  else
1240  {
1241  return {blockIdx.x, blockIdx.y};
1242  }
1243  }();
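    // [Editorial note] With NSwizzle enabled, the 1-D block id is remapped so that strips of 8
    // consecutive N-blocks are walked across all M-blocks owned by this expert before moving to
    // the next strip: e.g. with expert_swizzle = 2, blocks 0..7 cover n 0..7 of the expert's
    // first M-block, blocks 8..15 cover n 0..7 of its second M-block, and blocks 16..23 return
    // to the first M-block for n 8..15, presumably to improve reuse of the same B columns
    // across the expert's M-blocks.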
1244  const index_t block_n_id = block_mn.first;
1245  const index_t block_m_id = block_mn.second;
1246  const index_t token0 =
1247  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1248 
1249  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1250  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1251  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1252  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1253  constexpr auto AKThreads = AK0Threads * AK1Threads;
1254  constexpr auto AMRepeats = MPerBlock / AMThreads;
1255  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1256 
1257  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1258  return;
1260  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1261  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1262  index_t token_offset = fused_token & 0xffffff;
1263  if constexpr(!IsInputGemm)
1264  {
1265  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1266  }
1267  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1268  });
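    // [Editorial note] p_sorted_token_ids packs two fields into one 32-bit entry: the low
    // 24 bits are the token index and the high 8 bits are the top-k slot, which is why the
    // gather-offset code above masks with 0xffffff and the non-input-GEMM path rescales the
    // offset by TopK plus the slot. For example, fused_token = 0x02000015 decodes to token
    // 0x15 in top-k slot 2.
    static_assert((0x02000015 & 0xffffff) == 0x15 && (0x02000015 >> 24) == 2,
                  "decode of the example fused token above");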
1269  const index_t expert_stride =
1270  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1271  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1272  math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
1273  math::integer_divide_ceil(problem.K, ScaleBlockK));
1274 
1275  // N0, K0, Blocksize*KPack
1276  const index_t n_block_data_idx_on_grid =
1277  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1278 
1279  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1280  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1281  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1282  p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
1283  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1284 
1285  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1286  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1287  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1288  p_b_scale_grid + expert_id * expert_scale_stride,
1289  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1290 
1291  // A matrix in LDS memory, dst of blockwise copy
1292  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1293 
1294  // B matrix in LDS memory, dst of blockwise copy
1295  // dummy
1296  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1297  // A matrix blockwise copy
1298  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1300  AElementwiseOperation,
1304  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1305  ABlockTransferThreadClusterArrangeOrder,
1306  ADataType,
1307  LDSTypeA,
1308  decltype(a_grid_desc_ak0_m_ak1),
1309  decltype(a_block_desc_ak0_m_ak1),
1310  ABlockTransferSrcAccessOrder,
1312  ABlockTransferSrcVectorDim,
1313  2,
1314  ABlockTransferSrcScalarPerVector,
1315  ABlockTransferDstScalarPerVector_AK1,
1316  1,
1317  1,
1318  AThreadTransferSrcResetCoordinateAfterRun,
1319  true,
1320  IndexType,
1321  1,
1322  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1323  make_multi_index(0, 0, 0),
1324  a_element_op,
1325  a_block_desc_ak0_m_ak1,
1326  make_multi_index(0, 0, 0),
1328  gather_offsets);
1329 
1330  // Thread-wise copy
1331  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1332  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1333  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1334 
1335  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1336  BDataType,
1337  BDataType,
1338  decltype(b_grid_desc_bpreshuffled),
1339  decltype(b_block_desc_bk0_n_bk1),
1342  3,
1343  BBlockTransferSrcScalarPerVector,
1344  BThreadTransferSrcResetCoordinateAfterRun,
1345  true>(b_grid_desc_bpreshuffled,
1346  make_multi_index(n_block_data_idx_on_grid,
1348  0,
1349  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1350 
1351  // LDS allocation for A and B: be careful of alignment
1352  // Cast after lds
1353  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1354  static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1355 
1356  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1357  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1358 
1359  // Blockwise GEMM pipeline
1360  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1361  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1362  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1363  decltype(c_thread_buf) c_thread_buf_up;
1364 
1365  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1366  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1367  KPerBlock);
1368 
1369  constexpr index_t ScaleSliceSizeM = MXdlPerWave;
1370  constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
1371  constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
1372 
1373  // ScaleSliceSizeK is last dimension in A/B scale for vector memory access
1374  // ScaleSliceSizeK is first dimension in C scale for packed math
1375  constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
1377 
1378  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1379  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1380  constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
1381  auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
1382  (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl;
1383 
1384  constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
1386 
1387  constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
1389 
1390  // get each thread's offset in the scale tensor
1391  // A scale
1392  const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
1393 
1394  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
1395  return;
1396  StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
1397  static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
1398  const index_t fused_token =
1399  p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
1400  index_t token_offset = fused_token & 0xffffff;
1401  if constexpr(!IsInputGemm)
1402  {
1403  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1404  }
1405  scale_gather_offsets(m0) =
1406  token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK);
1407  });
1408 
1409  auto a_scale_thread_copy =
1411  AScaleType,
1412  decltype(a_scale_grid_desc_am_ak),
1413  decltype(a_scale_thread_desc),
1416  1,
1417  ScaleSliceSizeK,
1418  1,
1419  false,
1420  MXdlPerWave>(
1421  a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets);
1422 
1423  auto b_scale_thread_copy =
1425  BScaleType,
1426  decltype(b_scale_grid_desc_bn_ak),
1427  decltype(b_scale_thread_desc),
1430  1,
1431  ScaleSliceSizeK,
1432  1,
1433  false>(
1434  b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1435 
1436  // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
1437  constexpr auto a_scale_thread_slice_copy_step =
1438  make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK));
1439  constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
1440 
1441  constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
1442  if constexpr(IsInputGemm)
1443  {
1444  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1445  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1446  p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
1447  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1448  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1449  BDataType,
1450  BDataType,
1451  decltype(b_grid_desc_bpreshuffled),
1452  decltype(b_block_desc_bk0_n_bk1),
1455  3,
1456  BBlockTransferSrcScalarPerVector,
1457  BThreadTransferSrcResetCoordinateAfterRun,
1458  true>(b_grid_desc_bpreshuffled,
1459  make_multi_index(n_block_data_idx_on_grid,
1461  0,
1462  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1463  const BScaleType* p_b_scale_grid_up =
1464  p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
1465  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1466  p_b_scale_grid_up + expert_id * expert_scale_stride,
1467  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1468  auto b_scale_thread_copy_up =
1470  BScaleType,
1471  decltype(b_scale_grid_desc_bn_ak),
1472  decltype(b_scale_thread_desc),
1475  1,
1476  ScaleSliceSizeK,
1477  1,
1478  false>(
1479  b_scale_grid_desc_bn_ak,
1480  make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1481 
1482  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1483  a_grid_desc_ak0_m_ak1,
1484  a_block_desc_ak0_m_ak1,
1485  a_blockwise_copy,
1486  a_grid_buf,
1487  a_block_buf,
1488  a_block_slice_copy_step,
1489 
1490  b_grid_desc_bpreshuffled,
1491  b_block_desc_bk0_n_bk1,
1492  b_blockwise_copy,
1493  b_blockwise_copy_up,
1494  b_grid_buf,
1495  b_grid_buf_up,
1496  b_block_buf,
1497  b_block_slice_copy_step,
1498 
1499  c_scale_thread_desc,
1500  c_thread_buf,
1501  c_thread_buf_up,
1502 
1503  a_scale_grid_desc_am_ak,
1504  a_scale_thread_desc,
1505  a_scale_thread_copy,
1506  a_scale_grid_buf,
1507  a_scale_thread_slice_copy_step,
1508 
1509  b_scale_grid_desc_bn_ak,
1510  b_scale_thread_desc,
1511  b_scale_thread_copy,
1512  b_scale_thread_copy_up,
1513  b_scale_grid_buf,
1514  b_scale_grid_buf_up,
1515  b_scale_thread_slice_copy_step,
1516 
1517  num_k_block_main_loop);
1518  }
1519  else
1520  {
1521  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1522  a_grid_desc_ak0_m_ak1,
1523  a_block_desc_ak0_m_ak1,
1524  a_blockwise_copy,
1525  a_grid_buf,
1526  a_block_buf,
1527  a_block_slice_copy_step,
1528 
1529  b_grid_desc_bpreshuffled,
1530  b_block_desc_bk0_n_bk1,
1531  b_blockwise_copy,
1532  b_grid_buf,
1533  b_block_buf,
1534  b_block_slice_copy_step,
1535 
1536  c_scale_thread_desc,
1537  c_thread_buf,
1538 
1539  a_scale_grid_desc_am_ak,
1540  a_scale_thread_desc,
1541  a_scale_thread_copy,
1542  a_scale_grid_buf,
1543  a_scale_thread_slice_copy_step,
1544 
1545  b_scale_grid_desc_bn_ak,
1546  b_scale_thread_desc,
1547  b_scale_thread_copy,
1548  b_scale_grid_buf,
1549  b_scale_thread_slice_copy_step,
1550 
1551  num_k_block_main_loop);
1552  }
1553 
1554  // shuffle C and write out
1555  {
1556  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1557  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1558  "wrong!");
1559 
1560  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1561 
1562  // transposed XDL
1563  // TODO: hacky, fix it!
1564  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
1565  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1566 
1567  // TODO: hacky, fix it!
1568  // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
1569  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
1570  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1571 
1572  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
1573  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
1574  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
1575  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
1576  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
1577  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
1578  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
1579  constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
1580 
1581  static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
1582  static_assert(M0 * M1 * M2 == MPerBlock);
1583  static_assert(N4 == 4 || N4 == 8);
1584  const index_t m1 = get_warp_local_1d_id() / NWave;
1585  const index_t m2 = threadIdx.x % get_warp_size() % M2;
1586 
1587  float topk_weight;
1588  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1589  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1590  if constexpr(MulRoutedWeight)
1591  {
1592  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
1593  topk_weight = p_ds_grid[I0][m_pos];
1594  }
1595  static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
1596  static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
1597  constexpr index_t c_offset =
1598  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1599  make_tuple(m0, n0, n2 * N4 + n4));
1600  constexpr auto cidx = Number<c_offset>{};
1601  if constexpr(IsInputGemm) // gate-up (gu) fusion, elementwise
1602  {
1603  if constexpr(ActivationOperation == Activation::silu_and_mul)
1604  {
1605  float gate = c_thread_buf[cidx];
1606  float up = c_thread_buf_up[cidx];
1607  if constexpr(MulRoutedWeight)
1608  {
1609  gate = gate * topk_weight;
1610  up = up * topk_weight;
1611  }
1613  {
1614  gate *= 16;
1615  up *= 16;
1616  }
1618  c_thread_buf(cidx) = gate * up;
1619  }
1620  else if(ActivationOperation == Activation::gelu_and_mul)
1621  {
1622  float gate = c_thread_buf[cidx];
1623  float up = c_thread_buf_up[cidx];
1624  if constexpr(MulRoutedWeight)
1625  {
1626  gate = gate * topk_weight;
1627  up = up * topk_weight;
1628  }
1630  {
1631  gate *= 16;
1632  up *= 16;
1633  }
1635  c_thread_buf(cidx) = gate * up;
1636  }
1637  }
1638  else
1639  {
1640  if constexpr(MulRoutedWeight)
1641  {
1642  c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
1643  }
1644  }
1645  });
1646  });
1647  });
1648  });
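            // [Editorial note] Both activation branches above multiply the (optionally
            // top-k-weighted) gate and up partial results; in the usual gated-FFN formulation
            // the gate value additionally passes through SiLU (silu(x) = x * sigmoid(x)) or
            // GELU before the product, matching the Activation enum at the top of this file.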
1649 
1650  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1652 
1653  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1654  static_cast<CShuffleDataType*>(p_shared),
1655  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1656 
1657  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
1658  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1659  make_tuple(
1662  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1663  M1, // M1 = MWave
1664  M2)), // M2 = MPerXdl
1667  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1668  N1, // N1 = NWave
1669  N2, // N2 * N3 * N4 = NPerXdl
1670  N3,
1671  N4))),
1673  make_tuple(
1675 
1676  // calculate origin of thread output tensor on global memory
1677  // blockwise GEMM c matrix starting index
1678  const auto c_thread_mtx_on_block =
1679  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1680 
1681  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1682  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1683 
1684  const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
1688  make_tuple(Sequence<0>{}));
1689 
1690  const auto m_thread_data_on_block_idx =
1691  m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
1692  make_multi_index(m_thread_data_on_block));
1693 
1694  const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
1696  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
1698  make_tuple(Sequence<0>{}));
1699 
1700  const auto n_thread_data_on_block_idx =
1701  n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
1702  make_multi_index(n_thread_data_on_block));
1703 
1704  // shuffle: threadwise copy C from VGPR to LDS
1705  auto c_thread_copy_vgpr_to_lds =
1707  CShuffleDataType,
1708  decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1709  decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1711  Sequence<CShuffleMXdlPerWavePerShuffle,
1712  CShuffleNXdlPerWavePerShuffle,
1713  I1,
1714  I1,
1715  I1,
1716  N2,
1717  I1,
1718  N4>,
1720  7,
1721  1,
1723  1,
1724  true>{
1725  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1726  make_multi_index(0,
1727  0,
1728  m_thread_data_on_block_idx[I1],
1729  n_thread_data_on_block_idx[I1],
1730  m_thread_data_on_block_idx[I2],
1731  n_thread_data_on_block_idx[I2],
1732  n_thread_data_on_block_idx[I3],
1733  n_thread_data_on_block_idx[I4]),
1735 
1736  using EDataType = CDataType;
1737 
1738  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1739  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1740 
1741  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1743  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1744 
1745  const auto ds_grid_buf = generate_tuple(
1746  [&](auto i) {
1747  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
1748  const DDataType* ptr_ = p_ds_grid[i];
1749  // Hack logic here to support different kinds of strides. TODO: fix it.
1750  // a-scale: (t, 1); b-scale: (E, N, 1), move ptr to E
1751  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1752  ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
1753  },
1754  Number<NumDTensor>{});
1755 
1756  // tuple of reference to C/Ds tensor descriptors
1757  const auto c_ds_desc_refs = concat_tuple_of_reference(
1758  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1759  generate_tie([&](auto i) -> const auto& // return type should be reference
1760  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1761  Number<NumDTensor>{}));
1762 
1763  // tuple of reference to C/Ds tensor descriptors
1764  const auto c_ds_buf_refs = concat_tuple_of_reference(
1765  tie(c_shuffle_block_buf),
1766  generate_tie([&](auto i) -> const auto& // return type should be reference
1767  { return ds_grid_buf[i]; },
1768  Number<NumDTensor>{}));
1769 
1770  // tuple of starting index of C/Ds blockwise copy
1771  const auto idx_c_ds_block_begin =
1774  [&](auto) {
1775  return make_multi_index(block_m_id, 0, block_n_id, 0);
1776  // return make_multi_index(block_work_idx[I0], 0,
1777  // block_work_idx[I1], 0);
1778  },
1779  Number<NumDTensor>{}));
1780 
1781  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1782  c_grid_desc_mblock_mperblock_nblock_nperblock;
1783 
1784  using CDEBlockTransferCluster =
1785  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1786  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1787  constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix
1788  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1790  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1792  decltype(c_ds_desc_refs),
1793  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1794  CElementwiseOperation,
1795  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1796  // support arbitrary type
1797  Sequence<1,
1798  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1799  1,
1800  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1801  CDEBlockTransferCluster,
1802  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1803  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1804  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1805  3, // index_t SrcVectorDim,
1806  3, // index_t DstVectorDim,
1807  CDEShuffleBlockTransferScalarPerVectors,
1812  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1813  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1814  IndexType,
1815  1, // ScatterDim
1816  true, // OutputScatter: false, only use scatter weights
1817  scatter_weight_idx // ScatterWeightIdx: ascale
1818  >{c_ds_desc_refs,
1819  idx_c_ds_block_begin,
1820  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1821  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1822  c_element_op};
1823 
1824  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1825  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1826  // space filling curve for threadwise C in VGPR
1827  constexpr auto sfc_c_vgpr =
1830  Sequence<CShuffleMXdlPerWavePerShuffle,
1831  CShuffleNXdlPerWavePerShuffle,
1832  1,
1833  1,
1834  1,
1835  N2,
1836  1,
1837  N4>>{};
1838 
1839  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1840 
1841  // space filling curve for shuffled blockwise C/D/E
1842  constexpr auto sfc_cde_block =
1845  Sequence<1,
1846  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1847  1,
1848  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1849 
1850  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1851  constexpr auto EMThreads =
1852  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1853  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1854  constexpr auto ENThreads =
1855  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1856  static_for<0, num_access, 1>{}([&](auto access_id) {
1857  // make sure it's safe to write to LDS
1859 
1860  auto dstidx = sfc_cde_block.GetIndex(access_id);
1861  const index_t c_token_pos =
1862  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1863  static_for<0, EMRepeats, 1>{}([&](auto m0) {
1864  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1865  index_t token_offset = fused_token & 0xffffff;
1866  if constexpr(IsInputGemm)
1867  {
1868  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1869  }
1870  scatter_offsets(m0) = token_offset * problem.N;
1871  });
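The bit manipulation in the loop above is the MoE row-scatter step: each entry of p_sorted_token_ids packs a token index in the low 24 bits and a top-k slot in the high bits, and the resulting row offset decides where this block's output rows land. A hedged standalone sketch; decode_fused_token, output_row_offset and the sample sizes are illustrative names, not part of this header:

    #include <cassert>
    #include <cstdint>

    struct DecodedToken { int32_t token_id; int32_t topk_slot; };

    inline DecodedToken decode_fused_token(int32_t fused_token)
    {
        // low 24 bits: token index; high 8 bits: top-k slot
        return {fused_token & 0xffffff, (fused_token >> 24) & 0xff};
    }

    // Row offset into a row-major [NumTokens * TopK, N] output, matching the
    // IsInputGemm branch above where each token owns TopK consecutive rows.
    inline int64_t output_row_offset(int32_t fused_token, int32_t TopK, int32_t N)
    {
        const DecodedToken t = decode_fused_token(fused_token);
        return (int64_t(t.token_id) * TopK + t.topk_slot) * N;
    }

    int main()
    {
        const int32_t fused = (3 << 24) | 42; // token 42, routed to its 4th expert slot
        assert(decode_fused_token(fused).token_id == 42);
        assert(decode_fused_token(fused).topk_slot == 3);
        assert(output_row_offset(fused, /*TopK=*/8, /*N=*/4096) == (42LL * 8 + 3) * 4096);
        return 0;
    }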
1872 
1873  block_sync_lds();
1874 
1875  // each thread writes its data from VGPR to LDS
1876  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1877  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1878  c_thread_buf,
1879  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1880  c_shuffle_block_buf);
1881 
1882  // make sure it's safe to read from LDS
1883  block_sync_lds();
1884 
1885  // each block copies its data from LDS to global
1886  cde_block_copy_lds_and_global.Run(
1887  c_ds_desc_refs,
1888  c_ds_buf_refs,
1889  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1890  tie(c_grid_buf),
1891  scatter_offsets);
1892 
1893  if constexpr(access_id < num_access - 1)
1894  {
1895  constexpr auto cde_lds_and_global_step =
1896  sfc_cde_block.GetForwardStep(access_id);
1897 
1898  // move on Ds
1899  static_for<0, NumDTensor, 1>{}([&](auto i) {
1900  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1901  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1902  });
1903 
1904  // move on E
1905  cde_block_copy_lds_and_global.MoveDstSliceWindow(
1906  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1907  I0,
1908  cde_lds_and_global_step);
1909  }
1910  });
1911  }
1912  }
1913 
1914  template <bool HasMainKBlockLoop,
1915  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1916  TailNumber TailNum = TailNumber::Odd>
1917  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1918  const index_t* p_sorted_expert_ids,
1919  const index_t* p_max_token_id,
1920  const ADataType* p_a_grid,
1921  const BDataType* p_b_grid,
1922  DsGridPointer& p_ds_grid,
1923  CDataType* p_c_grid,
1924  const AScaleType* p_a_scale_grid,
1925  const BScaleType* p_b_scale_grid,
1926  void* p_shared,
1927  void* p_shared1,
1928  const Problem& problem,
1929  AElementwiseOperation a_element_op,
1930  BElementwiseOperation b_element_op,
1931  CElementwiseOperation c_element_op)
1932  {
1933  ignore = b_element_op;
1934  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1935  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1936  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1937  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1938  problem.MPadded,
1939  problem.K,
1940  problem.KPadded,
1941  problem.StrideA,
1942  problem.AK0);
1943  const auto b_grid_desc_bpreshuffled =
1944  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1945  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1946  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1947  problem.MPadded,
1948  problem.N,
1949  problem.NPadded,
1950  problem.StrideC);
1951 
1952  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1953  make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens
1954  : problem.NumTokens * problem.TopK,
1955  ScaleBlockM),
1956  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1957  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1958  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1959  make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
1960  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1961  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
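The two naive descriptors above describe the per-block quantization scales: one A scale per ScaleBlockM x ScaleBlockK tile of activations and one B scale per ScaleBlockN x ScaleBlockK tile of weights, both row-major with the K-block index fastest. A small sketch with assumed sizes; the ScaleBlock* values here are examples, not this header's template arguments:

    #include <cstdio>

    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    int main()
    {
        const int M = 4096, N = 7168, K = 2048;                  // illustrative problem
        const int ScaleBlockM = 1, ScaleBlockN = 128, ScaleBlockK = 128;

        const int a_rows = ceil_div(M, ScaleBlockM);             // one row per token when ScaleBlockM == 1
        const int a_cols = ceil_div(K, ScaleBlockK);
        const int b_rows = ceil_div(N, ScaleBlockN);
        const int b_cols = ceil_div(K, ScaleBlockK);

        // Strides match make_naive_tensor_descriptor(..., make_tuple(ceil(K/ScaleBlockK), 1)).
        std::printf("A scales: %d x %d, strides (%d, 1)\n", a_rows, a_cols, a_cols);
        std::printf("B scales: %d x %d, strides (%d, 1)\n", b_rows, b_cols, b_cols);
        return 0;
    }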
1962  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1964  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1965  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1966  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1967  if(expert_block_id * MPerBlock >= max_token_id)
1968  return;
1969  const index_t expert_id =
1970  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1971  const auto block_mn = [&]() -> std::pair<int, int> {
1972  if constexpr(NSwizzle)
1973  {
1974  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1975  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1976  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1977  const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
1978  const index_t bid_new = blockIdx.x - prefix_block;
1979  const index_t nid = __builtin_amdgcn_readfirstlane(
1980  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1981  const index_t mid =
1982  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1983  return {nid, mid};
1984  }
1985  else
1986  {
1987  return {blockIdx.x, blockIdx.y};
1988  }
1989  }();
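When NSwizzle is enabled, the lambda above remaps the flat blockIdx.x into an (n, m) pair so that groups of eight N-blocks walk through all M-blocks owned by the current expert before advancing in N, which can improve data reuse between neighbouring blocks. A hedged host-side replica; the group width of 8 comes from the code above, while the sample counts are made up:

    #include <cstdio>
    #include <utility>

    std::pair<int, int> swizzle_block(int bid, int ecnt_prefix, int ecnt, int n_block)
    {
        const int expert_swizzle = ecnt > 0 ? ecnt : 1;          // M-blocks of this expert
        const int bid_new        = bid - ecnt_prefix * n_block;  // block id relative to the expert
        const int nid = bid_new % 8 + bid_new / (8 * expert_swizzle) * 8;
        const int mid = ecnt_prefix + bid_new / 8 % expert_swizzle;
        return {nid, mid};                                       // {block_n_id, block_m_id}
    }

    int main()
    {
        for(int bid = 0; bid < 16; ++bid)
        {
            const auto [nid, mid] = swizzle_block(bid, /*ecnt_prefix=*/0, /*ecnt=*/2, /*n_block=*/32);
            std::printf("bid=%2d -> n=%2d m=%2d\n", bid, nid, mid);
        }
        return 0;
    }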
1990  const index_t block_n_id = block_mn.first;
1991  const index_t block_m_id = block_mn.second;
1992 
1993  const index_t token0 =
1994  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1995 
1996  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1997  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1998  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1999  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2000  constexpr auto AKThreads = AK0Threads * AK1Threads;
2001  constexpr auto AMRepeats = MPerBlock / AMThreads;
2002  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
2003 
2004  if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
2005  token0 >= problem.NumTokens)
2006  return;
2008  gather_offsets; //= p_sorted_token_ids[token_pos];
2009  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2010  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
2011  index_t token_offset = fused_token & 0xffffff;
2012  if constexpr(!IsInputGemm)
2013  {
2014  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2015  }
2016  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2017  });
2018  const index_t expert_stride =
2019  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2020  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2021  math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
2022  math::integer_divide_ceil(problem.K, ScaleBlockK));
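expert_stride and expert_scale_stride above give the distance between consecutive experts' weights and weight scales; for the fused gate/up ("input") GEMM both are doubled and the up-projection half starts at expert_stride / 2, which is how the *_up buffers later in this function are addressed. A hedged arithmetic sketch with illustrative sizes, ignoring element packing:

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const int64_t N = 7168, K = 2048;
        const bool    is_input_gemm = true;                        // gate+up fused GEMM
        const int64_t expert_stride = N * K * (is_input_gemm ? 2 : 1);

        const int     expert_id = 3;
        const int64_t gate_base = expert_id * expert_stride;       // start of gate weights
        const int64_t up_base   = gate_base + expert_stride / 2;   // start of up weights (fused case)

        std::printf("expert %d: gate @ %lld, up @ %lld elements\n",
                    expert_id, (long long)gate_base, (long long)up_base);
        return 0;
    }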
2023  // N0, K0, Blocksize*KPack
2024  const index_t n_block_data_idx_on_grid =
2025  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
2026 
2027  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2028  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2029  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2030  p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
2031  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2032 
2033  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2034  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2035  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2036  p_b_scale_grid + expert_id * expert_scale_stride,
2037  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2038 
2039  // A matrix in LDS memory, dst of blockwise copy
2040  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2041 
2042  // B matrix in LDS memory, dst of blockwise copy
2043  // dummy
2044  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2045  // A matrix blockwise copy
2046  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
2048  AElementwiseOperation,
2052  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2053  ABlockTransferThreadClusterArrangeOrder,
2054  ADataType,
2055  LDSTypeA,
2056  decltype(a_grid_desc_ak0_m_ak1),
2057  decltype(a_block_desc_ak0_m_ak1),
2058  ABlockTransferSrcAccessOrder,
2060  ABlockTransferSrcVectorDim,
2061  2,
2062  ABlockTransferSrcScalarPerVector,
2063  ABlockTransferDstScalarPerVector_AK1,
2064  1,
2065  1,
2066  AThreadTransferSrcResetCoordinateAfterRun,
2067  true,
2068  IndexType,
2069  1,
2070  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
2071  make_multi_index(0, 0, 0),
2072  a_element_op,
2073  a_block_desc_ak0_m_ak1,
2074  make_multi_index(0, 0, 0),
2076  gather_offsets);
2077 
2078  // Thread-wise copy
2079  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2080  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2081  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2082  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2083  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2084  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2085 
2086  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
2087  BDataType,
2088  BDataType,
2089  decltype(b_grid_desc_bpreshuffled),
2090  decltype(b_block_desc_bk0_n_bk1),
2093  3,
2094  BBlockTransferSrcScalarPerVector,
2095  BThreadTransferSrcResetCoordinateAfterRun,
2096  true>(b_grid_desc_bpreshuffled,
2097  make_multi_index(n_block_data_idx_on_grid,
2099  0,
2100  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2101 
2102  // LDS allocation for A and B: be careful of alignment
2103  // LDS is handed in as raw bytes; cast it to the element type here
2104  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2105  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2106  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2107  static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2108  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
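a_block_bufs pairs the two __shared__ allocations passed in as p_shared and p_shared1, which is what distinguishes Run_2Lds from Run: the pipeline can prefetch the next A tile into one LDS buffer while the MFMA stage still reads the other. A minimal sketch of that ping-pong pattern in plain host code, not the pipeline's real interface:

    #include <cstdio>

    int main()
    {
        int lds[2] = {0, 0};                    // stand-ins for the two LDS buffers
        int produced = 0;

        // Prologue: fill buffer 0 before the first compute step.
        lds[0] = produced++;

        for(int iter = 0; iter < 6; ++iter)
        {
            const int compute_buf = iter % 2;
            const int load_buf    = (iter + 1) % 2;

            lds[load_buf] = produced++;         // global -> LDS copy of the next K tile
            std::printf("iter %d: compute on tile %d, loading tile %d\n",
                        iter, lds[compute_buf], lds[load_buf]);
            // (a real pipeline separates these stages with barriers)
        }
        return 0;
    }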
2109 
2110  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2111  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
2112 
2113  // Blockwise GEMM pipeline
2114  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2115  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2116  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2117  decltype(c_thread_buf) c_thread_buf_up;
2118 
2119  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2120  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2121  KPerBlock);
2122 
2123  // scale
2124  constexpr index_t ScaleSliceSizeM = MXdlPerWave;
2125  constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
2126  constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
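The three ScaleSliceSize* constants above are how many scale entries the per-thread scale copies step through along each dimension for one block tile. A tiny check with assumed tile sizes, not this header's defaults:

    #include <cstdio>

    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    int main()
    {
        const int MXdlPerWave = 4, NPerBlock = 128, KPerBlock = 128;
        const int ScaleBlockN = 128, ScaleBlockK = 128;

        const int ScaleSliceSizeM = MXdlPerWave;                       // one A-scale row per M repeat
        const int ScaleSliceSizeN = ceil_div(NPerBlock, ScaleBlockN);  // B-scale columns per tile
        const int ScaleSliceSizeK = ceil_div(KPerBlock, ScaleBlockK);  // scale steps along K per tile

        std::printf("slice M=%d N=%d K=%d\n", ScaleSliceSizeM, ScaleSliceSizeN, ScaleSliceSizeK);
        return 0;
    }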
2127 
2128  // ScaleSliceSizeK is last dimension in A/B scale for vector memory access
2129  // ScaleSliceSizeK is first dimension in C scale for packed math
2130  constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
2132 
2133  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
2134  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
2135  constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
2136  auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
2137  (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl;
2138 
2139  constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
2141 
2142  constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
2144 
2145  // get each thread's offset in the scale tensor
2146  // A scale
2147  const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
2148 
2149  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2150  return;
2151  StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
2152  static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
2153  const index_t fused_token =
2154  p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
2155  index_t token_offset = fused_token & 0xffffff;
2156  if constexpr(!IsInputGemm)
2157  {
2158  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2159  }
2160  scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
2161  math::integer_divide_ceil(problem.K, ScaleBlockK);
2162  });
2163 
2164  auto a_scale_thread_copy =
2166  AScaleType,
2167  decltype(a_scale_grid_desc_am_ak),
2168  decltype(a_scale_thread_desc),
2171  1,
2172  ScaleSliceSizeK,
2173  1,
2174  false,
2175  MXdlPerWave>(
2176  a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets);
2177 
2178  auto b_scale_thread_copy =
2180  BScaleType,
2181  decltype(b_scale_grid_desc_bn_ak),
2182  decltype(b_scale_thread_desc),
2185  1,
2186  ScaleSliceSizeK,
2187  1,
2188  false>(
2189  b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
2190 
2191  // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
2192  constexpr auto a_scale_thread_slice_copy_step =
2193  make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK));
2194  constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
2195 
2196  constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
2197  if constexpr(IsInputGemm)
2198  {
2199  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2200  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2201  p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
2202  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2203  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2204  BDataType,
2205  BDataType,
2206  decltype(b_grid_desc_bpreshuffled),
2207  decltype(b_block_desc_bk0_n_bk1),
2210  3,
2211  BBlockTransferSrcScalarPerVector,
2212  BThreadTransferSrcResetCoordinateAfterRun,
2213  true>(b_grid_desc_bpreshuffled,
2214  make_multi_index(n_block_data_idx_on_grid,
2216  0,
2217  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2218  const BScaleType* p_b_scale_grid_up =
2219  p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
2220  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2221  p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
2222  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2223  auto b_scale_thread_copy_up =
2225  BScaleType,
2226  decltype(b_scale_grid_desc_bn_ak),
2227  decltype(b_scale_thread_desc),
2230  1,
2231  ScaleSliceSizeK,
2232  1,
2233  false>(
2234  b_scale_grid_desc_bn_ak,
2235  make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
2236 
2237  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2238  a_grid_desc_ak0_m_ak1,
2239  a_block_desc_ak0_m_ak1,
2240  a_blockwise_copy,
2241  a_grid_buf,
2242  a_block_bufs,
2243  a_block_slice_copy_step,
2244  b_grid_desc_bpreshuffled,
2245  b_block_desc_bk0_n_bk1,
2246  b_blockwise_copy,
2247  b_blockwise_copy_up,
2248  b_grid_buf,
2249  b_grid_buf_up,
2250  b_block_bufs,
2251  b_block_slice_copy_step,
2252  c_scale_thread_desc,
2253  c_thread_buf,
2254  c_thread_buf_up,
2255  a_scale_grid_desc_am_ak,
2256  a_scale_thread_desc,
2257  a_scale_thread_copy,
2258  a_scale_grid_buf,
2259  a_scale_thread_slice_copy_step,
2260  b_scale_grid_desc_bn_ak,
2261  b_scale_thread_desc,
2262  b_scale_thread_copy,
2263  b_scale_thread_copy_up,
2264  b_scale_grid_buf,
2265  b_scale_grid_buf_up,
2266  b_scale_thread_slice_copy_step,
2267  num_k_block_main_loop);
2268  }
2269  else
2270  {
2271  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2272  a_grid_desc_ak0_m_ak1,
2273  a_block_desc_ak0_m_ak1,
2274  a_blockwise_copy,
2275  a_grid_buf,
2276  a_block_bufs,
2277  a_block_slice_copy_step,
2278  b_grid_desc_bpreshuffled,
2279  b_block_desc_bk0_n_bk1,
2280  b_blockwise_copy,
2281  b_grid_buf,
2282  b_block_bufs,
2283  b_block_slice_copy_step,
2284  c_scale_thread_desc,
2285  c_thread_buf,
2286  a_scale_grid_desc_am_ak,
2287  a_scale_thread_desc,
2288  a_scale_thread_copy,
2289  a_scale_grid_buf,
2290  a_scale_thread_slice_copy_step,
2291  b_scale_grid_desc_bn_ak,
2292  b_scale_thread_desc,
2293  b_scale_thread_copy,
2294  b_scale_grid_buf,
2295  b_scale_thread_slice_copy_step,
2296  num_k_block_main_loop);
2297  }
2298 
2299  // shuffle C and write out
2300  {
2301 
2302  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2303  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2304  "wrong!");
2305 
2306  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2307 
2308  // transposed XDL
2309  // TODO: hacky, fix it!
2310  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
2311  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2312 
2313  // TODO: hacky, fix it!
2314  // only used to get lengths
2315  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
2316  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2317 
2318  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
2319  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
2320  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
2321  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
2322  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
2323  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
2324  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
2325  constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
2326 
2327  static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
2328  static_assert(M0 * M1 * M2 == MPerBlock);
2329  static_assert(N4 == 4 || N4 == 8);
2330  const index_t m1 = get_warp_local_1d_id() / NWave;
2331  const index_t m2 = threadIdx.x % get_warp_size() % M2;
2332 
2333  float topk_weight;
2334  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2335  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2336  if constexpr(MulRoutedWeight)
2337  {
2338  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
2339  topk_weight = p_ds_grid[I0][m_pos];
2340  }
2341  static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
2342  static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
2343  constexpr index_t c_offset =
2344  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2345  make_tuple(m0, n0, n2 * N4 + n4));
2346  constexpr auto cidx = Number<c_offset>{};
2347  if constexpr(IsInputGemm) // gate/up (gu) fusion, applied elementwise
2348  {
2349  if constexpr(ActivationOperation == Activation::silu_and_mul)
2350  {
2351  float gate = c_thread_buf[cidx];
2352  float up = c_thread_buf_up[cidx];
2353  if constexpr(MulRoutedWeight)
2354  {
2355  gate = gate * topk_weight;
2356  up = up * topk_weight;
2357  }
2359  {
2360  gate *= 16;
2361  up *= 16;
2362  }
2364  c_thread_buf(cidx) = gate * up;
2365  }
2366  else if(ActivationOperation == Activation::gelu_and_mul)
2367  {
2368  float gate = c_thread_buf[cidx];
2369  float up = c_thread_buf_up[cidx];
2370  if constexpr(MulRoutedWeight)
2371  {
2372  gate = gate * topk_weight;
2373  up = up * topk_weight;
2374  }
2376  {
2377  gate *= 16;
2378  up *= 16;
2379  }
2381  c_thread_buf(cidx) = gate * up;
2382  }
2383  }
2384  else
2385  {
2386  if constexpr(MulRoutedWeight)
2387  {
2388  c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
2389  }
2390  }
2391 
2392  });
2393  });
2394  });
2395  });
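As the enum names suggest, silu_and_mul and gelu_and_mul select the fused epilogue applied above: the gate accumulator is passed through SiLU or GELU and multiplied by the matching up accumulator, after both are scaled by the routing weight when MulRoutedWeight is set. A plain-float reference of those two compositions, assuming the common tanh-based GELU approximation; the kernel itself applies its own elementwise functors to the accumulator registers:

    #include <cmath>
    #include <cstdio>

    float silu_and_mul(float gate, float up)
    {
        return gate / (1.0f + std::exp(-gate)) * up;   // SiLU(gate) * up
    }

    float gelu_and_mul(float gate, float up)
    {
        const float c = 0.7978845608028654f;           // sqrt(2/pi), tanh approximation
        const float g = 0.5f * gate * (1.0f + std::tanh(c * (gate + 0.044715f * gate * gate * gate)));
        return g * up;                                 // GELU(gate) * up
    }

    int main()
    {
        std::printf("silu_and_mul(1.5, 2.0) = %f\n", silu_and_mul(1.5f, 2.0f));
        std::printf("gelu_and_mul(1.5, 2.0) = %f\n", gelu_and_mul(1.5f, 2.0f));
        return 0;
    }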
2396 
2397  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2399 
2400  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2401  static_cast<CShuffleDataType*>(p_shared),
2402  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2403 
2404  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
2405  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2406  make_tuple(
2409  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2410  M1, // M1 = MWave
2411  M2)), // M2 = MPerXdl
2414  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2415  N1, // N1 = NWave
2416  N2, // N2 * N3 * N4 = NPerXdl
2417  N3,
2418  N4))),
2420  make_tuple(
2422 
2423  // calculate origin of thread output tensor on global memory
2424  // blockwise GEMM c matrix starting index
2425  const auto c_thread_mtx_on_block =
2426  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2427 
2428  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2429  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2430 
2431  const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
2435  make_tuple(Sequence<0>{}));
2436 
2437  const auto m_thread_data_on_block_idx =
2438  m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
2439  make_multi_index(m_thread_data_on_block));
2440 
2441  const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
2443  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
2445  make_tuple(Sequence<0>{}));
2446 
2447  const auto n_thread_data_on_block_idx =
2448  n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
2449  make_multi_index(n_thread_data_on_block));
2450 
2451  // shuffle: threadwise copy C from VGPR to LDS
2452  auto c_thread_copy_vgpr_to_lds =
2454  CShuffleDataType,
2455  decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2456  decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2458  Sequence<CShuffleMXdlPerWavePerShuffle,
2459  CShuffleNXdlPerWavePerShuffle,
2460  I1,
2461  I1,
2462  I1,
2463  N2,
2464  I1,
2465  N4>,
2467  7,
2468  1,
2470  1,
2471  true>{
2472  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2473  make_multi_index(0,
2474  0,
2475  m_thread_data_on_block_idx[I1],
2476  n_thread_data_on_block_idx[I1],
2477  m_thread_data_on_block_idx[I2],
2478  n_thread_data_on_block_idx[I2],
2479  n_thread_data_on_block_idx[I3],
2480  n_thread_data_on_block_idx[I4]),
2482 
2483  using EDataType = CDataType;
2484 
2485  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2486  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2487 
2488  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2490  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2491 
2492  const auto ds_grid_buf = generate_tuple(
2493  [&](auto i) {
2494  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2495  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2496  },
2497  Number<NumDTensor>{});
2498 
2499  // tuple of reference to C/Ds tensor descriptors
2500  const auto c_ds_desc_refs = concat_tuple_of_reference(
2501  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2502  generate_tie([&](auto i) -> const auto& // return type should be reference
2503  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2504  Number<NumDTensor>{}));
2505 
2506  // tuple of reference to C/Ds tensor descriptors
2507  const auto c_ds_buf_refs = concat_tuple_of_reference(
2508  tie(c_shuffle_block_buf),
2509  generate_tie([&](auto i) -> const auto& // return type should be reference
2510  { return ds_grid_buf[i]; },
2511  Number<NumDTensor>{}));
2512 
2513  // tuple of starting index of C/Ds blockwise copy
2514  const auto idx_c_ds_block_begin =
2517  [&](auto) {
2518  return make_multi_index(block_m_id, 0, block_n_id, 0);
2519  // return make_multi_index(block_work_idx[I0], 0,
2520  // block_work_idx[I1], 0);
2521  },
2522  Number<NumDTensor>{}));
2523 
2524  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2525  c_grid_desc_mblock_mperblock_nblock_nperblock;
2526 
2527  using CDEBlockTransferCluster =
2528  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2529  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2530  constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // FIXME(felix): both branches currently select index 1
2531  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2533  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2535  decltype(c_ds_desc_refs),
2536  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2537  CElementwiseOperation,
2538  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
2539  // support arbitrary type
2540  Sequence<1,
2541  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2542  1,
2543  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2544  CDEBlockTransferCluster,
2545  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2546  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2547  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2548  3, // index_t SrcVectorDim,
2549  3, // index_t DstVectorDim,
2550  CDEShuffleBlockTransferScalarPerVectors,
2555  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2556  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2557  IndexType,
2558  1, // ScatterDim
2559  true, // OutputScatter: false, only use scatter weights
2560  scatter_weight_idx // ScatterWeightIdx: ascale
2561  >{c_ds_desc_refs,
2562  idx_c_ds_block_begin,
2563  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2564  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2565  c_element_op};
2566 
2567  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2568  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2569  // space filling curve for threadwise C in VGPR
2570  constexpr auto sfc_c_vgpr =
2573  Sequence<CShuffleMXdlPerWavePerShuffle,
2574  CShuffleNXdlPerWavePerShuffle,
2575  1,
2576  1,
2577  1,
2578  N2,
2579  1,
2580  N4>>{};
2581 
2582  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2583 
2584  // space filling curve for shuffled blockwise C/D/E
2585  constexpr auto sfc_cde_block =
2588  Sequence<1,
2589  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2590  1,
2591  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2592 
2593  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2594  constexpr auto EMThreads =
2595  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2596  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2597  constexpr auto ENThreads =
2598  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2599  static_for<0, num_access, 1>{}([&](auto access_id) {
2600  // make sure it's safe to write to LDS
2602  scatter_offsets; //= p_sorted_token_ids[c_token_pos];
2603 
2604  auto dstidx = sfc_cde_block.GetIndex(access_id);
2605  const index_t c_token_pos =
2606  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2607  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2608  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2609  index_t token_offset = fused_token & 0xffffff;
2610  if constexpr(IsInputGemm)
2611  {
2612  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2613  }
2614  scatter_offsets(m0) = token_offset * problem.N;
2615  });
2616 
2617  block_sync_lds();
2618 
2619  // each thread writes its data from VGPR to LDS
2620  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2621  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2622  c_thread_buf,
2623  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2624  c_shuffle_block_buf);
2625 
2626  // make sure it's safe to read from LDS
2627  block_sync_lds();
2628 
2629  // each block copies its data from LDS to global
2630  cde_block_copy_lds_and_global.Run(
2631  c_ds_desc_refs,
2632  c_ds_buf_refs,
2633  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2634  tie(c_grid_buf),
2635  scatter_offsets);
2636 
2637  if constexpr(access_id < num_access - 1)
2638  {
2639  constexpr auto cde_lds_and_global_step =
2640  sfc_cde_block.GetForwardStep(access_id);
2641 
2642  // move on Ds
2643  static_for<0, NumDTensor, 1>{}([&](auto i) {
2644  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2645  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2646  });
2647 
2648  // move on E
2649  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2650  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2651  I0,
2652  cde_lds_and_global_step);
2653  }
2654  });
2655  }
2656  }
2657 };
2658 
2659 } // namespace ck