Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp Source File
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
18 
19 #define DEBUG_LOG 0
20 
21 namespace ck {
22 
23 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
24 // pipelines in the same kernel function. Blockers:
25 // 1. Two separate declarations of __shared__ pointers are required to make sure data accesses
26 // operate on two distinct LDS chunks.
27 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
28 // not reuse the same LDS buffer when __shared__ is declared inside the block GEMM pipeline (blkgemmpipe)
29 
30 enum Activation
31 {
32  gelu_and_mul = 0,
33  silu_and_mul = 1
34 };
35 
36 template <typename GridwiseGemm,
37  bool HasMainKBlockLoop,
38  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
39  index_t MinimumOccupancy = 1,
40  TailNumber TailNum = TailNumber::Even>
41 __global__ void
42 #if CK_USE_LAUNCH_BOUNDS
43 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
44 #endif
45  // __attribute__((amdgpu_waves_per_eu(1, 1)))
46  kernel_moe_gemm(typename GridwiseGemm::Argument karg)
47 {
48 #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
49  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
50  {
51  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52 
53  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
54 
55  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
56  karg.p_sorted_token_ids,
57  karg.p_sorted_expert_ids,
58  karg.p_max_token_id,
59  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
60  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
61  karg.p_ds_grid,
62  karg.p_c_grid,
63  karg.p_a_scale_grid,
64  karg.p_b_scale_grid,
65  p_shared,
66  karg,
67  karg.a_element_op,
68  karg.b_element_op,
69  karg.c_element_op);
70  }
71 #else
72  ignore = karg;
73 #endif // defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
74 }
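// Illustrative host-side launch of kernel_moe_gemm (a sketch only, not part of this header;
// `Gridwise` stands for a fully specialized instantiation of the gridwise GEMM struct defined
// below, `arg` for a populated Gridwise::Argument, and `stream` for a caller-owned hipStream_t):
//
//   const auto [gx, gy, gz] = Gridwise::CalculateGridSize(arg.M, arg.N);
//   dim3 grid(gx, gy, arg.KBatch); // blockIdx.z selects the split-K batch (see SplitKBatchOffset)
//   dim3 block(256);               // must equal the BlockSize template argument of Gridwise
//   kernel_moe_gemm<Gridwise,
//                   /*HasMainKBlockLoop=*/true, // pick via Gridwise::CalculateHasMainKBlockLoop
//                   InMemoryDataOperationEnum::Set>
//       <<<grid, block, 0, stream>>>(arg);
//
// kernel_moe_gemm_2lds below follows the same pattern but allocates two LDS buffers so the
// pipeline can double-buffer the A tile.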
75 
76 template <typename GridwiseGemm,
77  bool HasMainKBlockLoop,
78  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
79  index_t MinimumOccupancy = 1,
80  TailNumber TailNum = TailNumber::Even>
81 __global__ void
82 #if CK_USE_LAUNCH_BOUNDS
83 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
84 #endif
85  // __attribute__((amdgpu_waves_per_eu(1, 1)))
86  kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
87 {
88 #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
89  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
90  {
91  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
92  __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
93 
94  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
95 
96  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
97  karg.p_sorted_token_ids,
98  karg.p_sorted_expert_ids,
99  karg.p_max_token_id,
100  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
101  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
102  karg.p_ds_grid,
103  karg.p_c_grid,
104  karg.p_a_scale_grid,
105  karg.p_b_scale_grid,
106  p_shared,
107  p_shared1,
108  karg,
109  karg.a_element_op,
110  karg.b_element_op,
111  karg.c_element_op);
112  }
113 #else
114  ignore = karg;
115 #endif // defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
116 }
117 
118 template <typename ALayout,
119  typename BLayout,
120  typename DsLayout,
121  typename CLayout,
122  typename ADataType,
123  typename BDataType,
124  typename AccDataType,
125  typename CShuffleDataType,
126  typename DsDataType,
127  typename CDataType,
128  typename AElementwiseOperation,
129  typename BElementwiseOperation,
130  typename CElementwiseOperation,
132  index_t BlockSize,
133  index_t ScaleBlockM,
134  index_t ScaleBlockN,
135  index_t ScaleBlockK,
136  index_t MPerBlock,
137  index_t NPerBlock,
138  index_t KPerBlock,
139  index_t AK1Value,
140  index_t BK1Value,
141  index_t MPerXdl,
142  index_t NPerXdl,
143  index_t MXdlPerWave,
144  index_t NXdlPerWave,
145  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
146  typename ABlockTransferThreadClusterArrangeOrder,
147  typename ABlockTransferSrcAccessOrder,
148  index_t ABlockTransferSrcVectorDim,
149  index_t ABlockTransferSrcScalarPerVector,
150  index_t ABlockTransferDstScalarPerVector_AK1,
151  bool AThreadTransferSrcResetCoordinateAfterRun,
152  index_t ABlockLdsExtraM,
153  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
154  typename BBlockTransferThreadClusterArrangeOrder,
155  typename BBlockTransferSrcAccessOrder,
156  index_t BBlockTransferSrcVectorDim,
157  index_t BBlockTransferSrcScalarPerVector,
158  index_t BBlockTransferDstScalarPerVector_BK1,
159  bool BThreadTransferSrcResetCoordinateAfterRun,
160  index_t BBlockLdsExtraN,
161  index_t CShuffleMXdlPerWavePerShuffle,
162  index_t CShuffleNXdlPerWavePerShuffle,
163  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
164  typename CDEShuffleBlockTransferScalarPerVectors,
167  index_t ActivationOperation = 0,
168  bool NSwizzle = false,
169  bool IsInputGemm = true,
170  bool MulRoutedWeight = true,
171  typename IndexType = index_t,
172  typename ComputeTypeA = CDataType,
173  typename ComputeTypeB = ComputeTypeA,
174  typename LDSTypeA = ADataType,
175  typename LDSTypeB = BDataType>
177 {
178  using AScaleType = float;
179  using BScaleType = float;
180 
181  static constexpr auto I0 = Number<0>{};
182  static constexpr auto I1 = Number<1>{};
183  static constexpr auto I2 = Number<2>{};
184  static constexpr auto I3 = Number<3>{};
185  static constexpr auto I4 = Number<4>{};
186  static constexpr auto I5 = Number<5>{};
187  static constexpr auto I6 = Number<6>{};
188  static constexpr auto I7 = Number<7>{};
189 
190  static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
 191  CDEShuffleBlockTransferScalarPerVectors{}[I0];
192  // K1 should be Number<...>
193  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
194  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
195  static constexpr auto AK1Number = Number<AK1Value>{};
196  static constexpr auto BK1Number = Number<BK1Value>{};
197  static constexpr auto BlockSizeNumber = Number<BlockSize>{};
198 
199  static constexpr index_t NumDTensor = DsDataType::Size();
200 
202  static constexpr index_t KPack =
204  static constexpr index_t KGroup = []() {
206  // On gfx950, there is an mfma that requires 32 f8 elements as input,
207  // split into 2 groups of 16 f8 elements.
208  // The 2 groups are not contiguous in the B preshuffled layout,
209  // and we do not want them to be contiguous in the B preshuffled layout,
210  // because a memory instruction can only read 16 f8 elements at a time.
211  return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
212  else
213  return 1;
214  }();
215  static constexpr index_t KLane =
217  static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
218  static constexpr index_t NLane = NPerXdl;
219  static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
220  // static constexpr index_t NumTokens = 1;
221  static constexpr index_t SortedTileSize = MPerBlock;
222 
223  static constexpr auto MakeDsGridPointer()
224  {
225  return generate_tuple(
226  [&](auto i) {
227  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
228 
229  return static_cast<const DDataType*>(nullptr);
230  },
232  }
233 
234  using DsGridPointer = decltype(MakeDsGridPointer());
235 
237 
238  static constexpr index_t APackedSize = []() {
240  return 2;
241  else
242  return 1;
243  }();
244 
245  static constexpr index_t BPackedSize = []() {
247  return 2;
248  else
249  return 1;
250  }();
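// Note on the packed-size lambdas above (the element-type test on the elided lines is an
// assumption, not shown in this listing): APackedSize/BPackedSize are expected to evaluate to 2
// for packed sub-byte element types such as pk_i4_t, which store two values per byte, and to 1
// otherwise. They are used below to turn logical element counts into packed-storage offsets,
// e.g. a_k_split_offset = k_id * karg.KRead / APackedSize in SplitKBatchOffset.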
251 
252  __host__ static auto CalculateGridSize(index_t M, index_t N)
253  {
254  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
255  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
256  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
257  const index_t gridy = NSwizzle ? 1 : mblock;
258  return std::make_tuple(gridx, gridy, 1);
259  }
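// Worked example for CalculateGridSize (illustrative values only): with MPerBlock = 128,
// NPerBlock = 128, M = 512, N = 1024:
//   mblock = ceil(512 / 128) = 4, nblock = ceil(1024 / 128) = 8
//   NSwizzle == false: grid = (nblock, mblock, 1) = (8, 4, 1)
//   NSwizzle == true : grid = (nblock * mblock, 1, 1) = (32, 1, 1)
// The third component returned here is always 1; the launcher supplies the split-K batch count
// in gridDim.z, which the kernels above consume through blockIdx.z / SplitKBatchOffset.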
260 
261  __host__ __device__ static auto CalculateMPadded(index_t M)
262  {
263  return math::integer_least_multiple(M, MPerBlock);
264  }
265 
266  __host__ __device__ static auto CalculateNPadded(index_t N)
267  {
268  return math::integer_least_multiple(N, NPerBlock);
269  }
270 
271  __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
272  {
273  return math::integer_divide_ceil(N, NLane);
274  }
275  __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
276  {
278  }
279 
280  __host__ __device__ static auto CalculateKPadded(index_t K)
281  {
282  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
283  }
284 
285  __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
286  {
287  auto K_t = K_Batch * KPerBlock;
288  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
289  }
290 
291  __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
292  {
293  auto K_t = K_Batch * KPerBlock;
294  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
295  }
296 
297  __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
298  {
299  auto K_t = K_Batch * KPerBlock;
300  return (K + K_t - 1) / K_t * KPerBlock;
301  }
302 
303  __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
304  {
305  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
306  auto K_t = K_Batch * KReadVec;
307  return (K + K_t - 1) / K_t * KReadVec;
308  }
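// Worked example for the split-K helpers above (illustrative values: KPerBlock = 64,
// AK1Value = BK1Value = 16, K = 384, K_Batch = 2):
//   CalculateKPadded(384, 2):  K_t = 2 * 64 = 128 -> ceil(384 / 128) * 64        = 192
//   CalculateAK0Padded(384, 2):                      ceil(384 / 128) * (64 / 16) = 12
//   CalculateKRead(384, 2):    KReadVec = lcm(16, 16) = 16, K_t = 2 * 16 = 32
//                              -> ceil(384 / 32) * 16                            = 192
// i.e. each of the two split-K batches reads a 192-element slice of K.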
309 
310  __host__ __device__ static auto CalculateMBlock(index_t M)
311  {
312  return math::integer_divide_ceil(M, MPerBlock);
313  }
314 
315  __host__ __device__ static auto CalculateNBlock(index_t N)
316  {
317  return math::integer_divide_ceil(N, NPerBlock);
318  }
319 
320  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
321  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
322  {
323  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
324  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
325 
327  TileDesc_K0_MN_K1{},
333  }
334 
335  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
336  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
337  {
338  const auto a_grid_desc_mraw_kraw = [&]() {
339  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
340  {
341  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
342  }
343  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
344  {
345  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
346  }
347  }();
348 
350 
351  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
352  GemmSpec == GemmSpecialization::MNKPadding)
353  {
354  // pad both M and K
355  const auto a_grid_desc_m_k =
356  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
358  make_right_pad_transform(K, KPad - K)),
361 
362  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
363  a_grid_desc_m_k,
368 
369  return a_grid_desc_ak0_m_ak1;
370  }
371  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
372  GemmSpec == GemmSpecialization::MNPadding)
373  {
374  // pad M, but not K
375  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
376  a_grid_desc_mraw_kraw,
378  make_right_pad_transform(M, MPad - M)),
381 
382  return a_grid_desc_ak0_m_ak1;
383  }
384  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
385  GemmSpec == GemmSpecialization::NKPadding)
386  {
387  // pad K, but not M
388  const auto a_grid_desc_m_k = transform_tensor_descriptor(
389  a_grid_desc_mraw_kraw,
393 
394  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
395  a_grid_desc_m_k,
400 
401  return a_grid_desc_ak0_m_ak1;
402  }
403  else
404  {
405  // not pad M or K
406  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
407  a_grid_desc_mraw_kraw,
412 
413  return a_grid_desc_ak0_m_ak1;
414  }
415  }
416 
417  __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
418  {
419  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
420  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
421  constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack / KGroup>{};
423  make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
424  make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
425  }
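// The preshuffled B descriptor above is 4-D with shape
//   (N0 / NWave, NWave, K0, WaveSize * KPack / KGroup)
// stored contiguously in that order (innermost stride 1): the innermost dimension holds one
// wave's worth of data, and each lane reads KPack / KGroup contiguous elements (see the
// thread offset of b_blockwise_copy in Run). N0 and K0 come from CalculateBN0Shuffled(N)
// and CalculateBK0Shuffled(K) respectively.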
426 
427  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
428  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
429  {
430  const auto b_grid_desc_nraw_kraw = [&]() {
432  {
433  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
434  }
436  {
437  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
438  }
439  }();
440 
442 
443  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
444  GemmSpec != GemmSpecialization::Default),
445  "pk_i4_t does not support padding");
446 
447  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
448  GemmSpec == GemmSpecialization::MNKPadding)
449  {
450  // pad both N and K
451  const auto b_grid_desc_n_k =
452  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
454  make_right_pad_transform(K, KPad - K)),
457 
458  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
459  b_grid_desc_n_k,
464 
465  return b_grid_desc_bk0_n_bk1;
466  }
467  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
468  GemmSpec == GemmSpecialization::MNPadding)
469  {
470  // pad N, but not K
471  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
472  b_grid_desc_nraw_kraw,
474  make_right_pad_transform(N, NPad - N)),
477 
478  return b_grid_desc_bk0_n_bk1;
479  }
480  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
481  GemmSpec == GemmSpecialization::MKPadding)
482  {
483  // pad K, but not N
484  const auto b_grid_desc_n_k = transform_tensor_descriptor(
485  b_grid_desc_nraw_kraw,
489 
490  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
491  b_grid_desc_n_k,
496 
497  return b_grid_desc_bk0_n_bk1;
498  }
499  else
500  {
501  // not pad N or K
502  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
503  b_grid_desc_nraw_kraw,
508 
509  return b_grid_desc_bk0_n_bk1;
510  }
511  }
512 
513  template <typename ABlockDesc_AK0_M_AK1>
514  __host__ __device__ static constexpr auto
515  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
516  {
517  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
518 
519  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
520  }
521 
522  template <typename BBlockDesc_BK0_N_BK1>
523  __host__ __device__ static constexpr auto
524  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
525  {
526  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
527  }
528 
529  template <typename ELayout>
530  __host__ __device__ static auto MakeCGridDescriptor_M_N(
531  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
532  {
533  const auto c_grid_desc_mraw_nraw = [&]() {
535  {
536  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
537  }
539  {
540  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
541  }
542  }();
543 
544  // pad M and N
545  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
547  make_right_pad_transform(N, NPad - N)),
550  }
551 
552  template <typename DLayout>
553  __host__ __device__ static auto
555  {
556  const auto c_grid_desc_mraw_nraw = [&]() {
558  {
559  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
560  }
562  {
563  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
564  }
565  }();
566 
567  // pad M and N
568  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
570  make_right_pad_transform(N, NPad - N)),
573  }
574 
575  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
576  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
577  {
578  return generate_tuple(
579  [&](auto i) {
580  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
581  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
582  },
584  }
585 
586  template <typename DsGridDesc>
588  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
589  {
590  return generate_tuple(
591  [&](auto i) {
593  ds_grid_desc_m_n[i], MBlock, NBlock);
594  },
596  }
597 
598  using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
599 
600  struct Problem
601  {
602  __host__ __device__ Problem(index_t NumTokens_,
603  index_t TopK_,
604  index_t M_,
605  index_t N_,
606  index_t K_,
607  index_t StrideA_,
608  index_t StrideB_,
609  std::array<index_t, NumDTensor> StrideDs_,
610  index_t StrideC_,
611  index_t KBatch_)
612  : NumTokens{NumTokens_},
613  TopK{TopK_},
614  M{M_},
615  N{N_},
616  K{K_},
617  StrideA{StrideA_},
618  StrideB{StrideB_},
619  StrideDs{StrideDs_},
620  StrideC{StrideC_},
621  KBatch{KBatch_},
624  KRead{CalculateKRead(K_, KBatch_)},
625  KPadded{CalculateKPadded(K_, KBatch_)},
626  AK0{CalculateAK0Padded(K_, KBatch_)},
627  BK0{CalculateBK0Padded(K_, KBatch_)},
628  MBlock{CalculateMBlock(M_)},
630  {
631  }
632 
633  __host__ void Print() const
634  {
635  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
636  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
637  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
638  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
639  << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
640  << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
641  << "NBlock: " << NBlock << "}" << std::endl;
642  }
643 
651  std::array<index_t, NumDTensor> StrideDs;
662  };
663 
664  // Argument
666  {
667  __host__ Argument(const index_t* p_sorted_token_ids_,
668  const index_t* p_sorted_expert_ids_,
669  const index_t* p_max_token_id_,
670  const ADataType* p_a_grid_,
671  const BDataType* p_b_grid_,
672  std::array<const void*, NumDTensor> p_ds_grid_,
673  CDataType* p_c_grid_,
674  index_t NumTokens_,
675  index_t TopK_,
676  index_t M_,
677  index_t N_,
678  index_t K_,
679  index_t StrideA_,
680  index_t StrideB_,
681  std::array<index_t, NumDTensor> StrideDs_,
682  index_t StrideC_,
683  const AScaleType* p_a_scale_grid_,
684  const BScaleType* p_b_scale_grid_,
685  index_t k_batch_,
686  AElementwiseOperation a_element_op_,
687  BElementwiseOperation b_element_op_,
688  CElementwiseOperation c_element_op_)
689  : Problem{NumTokens_,
690  TopK_,
691  M_,
692  N_,
693  K_,
694  StrideA_,
695  StrideB_,
696  StrideDs_,
697  StrideC_,
698  k_batch_},
699  p_sorted_token_ids{p_sorted_token_ids_},
700  p_sorted_expert_ids{p_sorted_expert_ids_},
701  p_max_token_id{p_max_token_id_},
702  p_a_grid{p_a_grid_},
703  p_b_grid{p_b_grid_},
704  p_ds_grid{},
705  p_c_grid{p_c_grid_},
706  p_a_scale_grid{p_a_scale_grid_},
707  p_b_scale_grid{p_b_scale_grid_},
708  a_element_op{a_element_op_},
709  b_element_op{b_element_op_},
710  c_element_op{c_element_op_}
711  {
712 
713  // populate pointer, desc for Ds
714  static_for<0, NumDTensor, 1>{}([&](auto i) {
715  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
716 
717  // D pointer
718  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
719  });
720  }
721 
725  const ADataType* p_a_grid;
726  const BDataType* p_b_grid;
728  CDataType* p_c_grid;
729 
732 
733  const AElementwiseOperation a_element_op;
734  const BElementwiseOperation b_element_op;
735  const CElementwiseOperation c_element_op;
736  };
737 
739  {
740  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
741  {
742  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
743  {
744  a_k_split_offset = k_id * karg.KRead / APackedSize;
745  }
746  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
747  {
748  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
749  }
750 
751  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
752  {
753  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
754  }
755  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
756  {
757  // KPack * NLane * KLane * K0 * N0
758  b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
759  }
760 
761  if(k_id < karg.KBatch - 1)
762  {
763  karg.K = karg.KRead;
764  }
765  else
766  {
767  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
768  }
769  }
770 
773  };
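// Worked example for SplitKBatchOffset (illustrative: row-major A, column-major/preshuffled B,
// KBatch = 2, KRead = 192, K = 384):
//   k_id = 0: a_k_split_offset = 0,                  karg.K = KRead         = 192
//   k_id = 1: a_k_split_offset = 192 / APackedSize,  karg.K = 384 - 192 * 1 = 192
//   b_k_split_offset = k_id * 192 * NLane / BPackedSize
// Each split-K batch shifts its A/B base pointers to the slice of K it owns and shrinks karg.K
// accordingly; the last batch also absorbs the remainder when K does not divide evenly.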
774 
775  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
776  {
777  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
778  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
779  // A matrix in LDS memory, dst of blockwise copy
780  if constexpr(ABlockLdsExtraM)
781  {
785  }
786  // The xor tensor transformation requires additional, unnecessary VGPR usage and can cause
787  // register spills in some cases.
789  {
790  constexpr auto a_lds_block_desc =
793 
794  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
795  a_lds_block_desc,
801 
802  return a_lds_block_desc_permuted;
803  }
804  else // ColumnMajor A
805  {
806  // The kfold and mpair dimensions are not always required.
807  // More dimensions in merge_transform increase the difficulty of generating immarg offsets
808  // for the compiler.
809  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
810  constexpr auto M1 = MPerBlock / M0;
811 
812  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
813  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
814  constexpr auto KThreadRead = WaveSize / MPerXdl;
815  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
816 
817  constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
818  ? 1
819  : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
820  constexpr auto KThreadReadPerm =
821  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
822  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
823  : KThreadRead;
824 
825  // 1<=mpair<=n0
826  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
827  ? 1
828  : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
829  ? M0
830  : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
831 
832  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
836  Number<kfold * M0 / mpair>{},
837  Number<mpair>{},
838  AK1Number));
839 
840  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
841  a_lds_block_desc,
842  make_tuple(
846  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
849  make_tuple(
851  make_tuple(
853 
854  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
855  a_lds_block_desc_permuted,
856  make_tuple(
864  Sequence<1>{},
865  Sequence<2>{},
866  Sequence<3>{},
867  Sequence<4>{},
868  Sequence<5>{}),
870  Sequence<2>{},
871  Sequence<0, 3>{},
872  Sequence<4, 5>{},
873  Sequence<6>{},
874  Sequence<7>{}));
875 
876  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
877  a_lds_block_desc_unmerged,
880  Number<KThreadWrite / kfold / KThreadReadPerm>{},
881  Number<kfold>{},
888 
889  return a_lds_block_desc_ak0_m_ak1;
890  }
891  }
892 
893  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
894  {
895  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
898  }
899 
901  {
902  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
903 
904  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
906  make_tuple(I1,
908  I1,
910 
911  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
912  }
913 
916  BlkGemmPipelineVer,
917  BlkGemmPipeSched,
918  BlockSize,
919  ADataType,
920  BDataType,
921  ComputeTypeA,
922  AccDataType,
929  ABlockTransferSrcScalarPerVector,
930  BBlockTransferSrcScalarPerVector,
931  MPerBlock,
932  NPerBlock,
933  KPerBlock,
934  ScaleBlockM,
935  ScaleBlockN,
936  ScaleBlockK,
937  MPerXdl,
938  NPerXdl,
939  MXdlPerWave,
940  NXdlPerWave,
941  KPack,
942  IsInputGemm>())>;
943 
944  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
945  {
946  // LDS allocation for A and B: be careful of alignment
947  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
948  // lds max alignment
949  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
950 
951  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
952  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
953 
954  // LDS allocation for C shuffle in LDS
955  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
957 
958  constexpr auto c_block_size =
959  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
960 
961  return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize,
962  c_block_size * sizeof(CShuffleDataType));
963  }
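// Worked example for GetSharedMemoryNumberOfByte (illustrative, ignoring ABlockLdsExtraM
// padding): with KPerBlock = 64, AK1Value = 8, MPerBlock = 128, LDSTypeA = half_t and
// APackedSize = 1, the A tile needs AK0 * MPerBlock * AK1 = 8 * 128 * 8 = 8192 elements,
// i.e. 16 KiB of LDS. The C-shuffle buffer later reuses the same LDS region (both are placed
// at p_shared in Run), which is why the function returns the maximum of the two sizes rather
// than their sum.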
964 
966 
967  // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
968  __host__ static constexpr bool CheckValidity(const Argument& karg)
969  {
970  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
971  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
972  "Invalid tuning param!");
973 
974  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
979  {
980  if(!(karg.M % MPerBlock == 0))
981  {
982 #if DEBUG_LOG
983  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
984  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
985  << std::endl;
986 
987 #endif // DEBUG_LOG
988  return false;
989  }
990  }
991 
992  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
997  {
998  if(!(karg.N % NPerBlock == 0))
999  {
1000 #if DEBUG_LOG
1001  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1002  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1003  << std::endl;
1004 
1005 #endif // DEBUG_LOG
1006  return false;
1007  }
1008  }
1009 
1010  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1014  {
1015 
1016  auto K_t = karg.KBatch * KPerBlock;
1017  if(!(karg.K % K_t == 0))
1018  {
1019 #if DEBUG_LOG
1020  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1021  << karg.K << " " << __FILE__ << ":" << __LINE__
1022  << ", in function: " << __func__ << std::endl;
1023 
1024 #endif // DEBUG_LOG
1025  return false;
1026  }
1027  }
1028  else
1029  {
1030  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1031  auto K_t = karg.KBatch * KReadVec;
1032  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1033  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1034  {
1035  return false;
1036  }
1037  }
1038 
1040  {
1041  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1042  {
1043 #if DEBUG_LOG
1044  std::cout << "Arg K (" << karg.K
1045  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1046  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1047  << __LINE__ << ", in function: " << __func__ << std::endl;
1048 
1049 #endif // DEBUG_LOG
1050  return false;
1051  }
1052  }
1053  else
1054  {
1055  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1056  {
1057 #if DEBUG_LOG
1058  std::cout << "Arg M (" << karg.M
1059  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1060  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1061  << __LINE__ << ", in function: " << __func__ << std::endl;
1062 
1063 #endif // DEBUG_LOG
1064  return false;
1065  }
1066  }
1067 
1069  {
1070  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1071  {
1072 #if DEBUG_LOG
1073  std::cout << "Arg N (" << karg.N
1074  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1075  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1076  << __LINE__ << ", in function: " << __func__ << std::endl;
1077 
1078 #endif // DEBUG_LOG
1079  return false;
1080  }
1081  }
1082  else
1083  {
1084  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1085  {
1086 #if DEBUG_LOG
1087  std::cout << "Arg K (" << karg.K
1088  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1089  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1090  << __LINE__ << ", in function: " << __func__ << std::endl;
1091 
1092 #endif // DEBUG_LOG
1093  return false;
1094  }
1095  }
1096 
1098  {
1100  {
1101 #if DEBUG_LOG
1102  std::cout << "Arg N (" << karg.N
1103  << ") value is not a multiple of "
1104  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1105  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1106  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1107 
1108 #endif // DEBUG_LOG
1109  return false;
1110  }
1111  }
1112  else
1113  {
1115  {
1116 #if DEBUG_LOG
1117  std::cout << "Arg M (" << karg.M
1118  << ") value is not a multiple of "
1119  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1120  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1121  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1122 
1123 #endif // DEBUG_LOG
1124  return false;
1125  }
1126  }
1127 
1128  // check gridwise gemm pipeline
1129 #if 0
1130  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1131 
1132  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1133  {
1134  return false;
1135  }
1136 #endif
1137  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1138  return true;
1139  }
1140 
1141  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1142  {
1143  const index_t num_loop = K / KPerBlock;
1144 
1145  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1146  }
1147 
1148  __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1149  {
1150  const index_t num_loop = K / KPerBlock;
1151 
1152  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1153  }
1154 
1155  template <typename CGridDesc>
1157  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1158  {
1159  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1160  c_grid_desc_m_n,
1165 
1166  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1167  }
1168 
1169  // return block_id to C matrix tile idx (m0, n0) mapping
1170  // if arch = gfx942
1171  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1172  // NPerBlock>;
1173 
1174  template <bool HasMainKBlockLoop,
1175  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1176  TailNumber TailNum = TailNumber::Odd>
1177  __device__ static void Run(const index_t* p_sorted_token_ids,
1178  const index_t* p_sorted_expert_ids,
1179  const index_t* p_max_token_id,
1180  const ADataType* p_a_grid,
1181  const BDataType* p_b_grid,
1182  DsGridPointer& p_ds_grid,
1183  CDataType* p_c_grid,
1184  const AScaleType* p_a_scale_grid,
1185  const BScaleType* p_b_scale_grid,
1186  void* p_shared,
1187  const Problem& problem,
1188  AElementwiseOperation a_element_op,
1189  BElementwiseOperation b_element_op,
1190  CElementwiseOperation c_element_op)
1191  {
1192  ignore = b_element_op;
1193  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1194  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1195  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1196  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1197  problem.MPadded,
1198  problem.K,
1199  problem.KPadded,
1200  problem.StrideA,
1201  problem.AK0);
1202  const auto b_grid_desc_bpreshuffled =
1203  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1204  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1205  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1206  problem.MPadded,
1207  problem.N,
1208  problem.NPadded,
1209  problem.StrideC);
1210 
1211  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1212  make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens
1213  : problem.NumTokens * problem.TopK,
1214  ScaleBlockM),
1215  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1216  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1217  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1218  make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
1219  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1220  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
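// Layout of the scale tensors described above (a reading of the descriptors, for orientation):
//   A scale: (ceil(tokens / ScaleBlockM), ceil(K / ScaleBlockK)), row-major, one scale per
//            ScaleBlockM x ScaleBlockK tile of A.
//   B scale: (ceil(N / ScaleBlockN), ceil(K / ScaleBlockK)), row-major, stored per expert;
//            the per-expert base offset expert_id * expert_scale_stride is applied when the
//            global scale buffer is created below.
// Example: K = 4096, ScaleBlockK = 128 -> 32 scale entries along K per row.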
1221 
1222  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1224  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1225  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1226  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1227  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1228  if(expert_block_id * MPerBlock >= max_token_id)
1229  return;
1230  const index_t expert_id =
1231  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1232  const auto block_mn = [&]() -> std::pair<int, int> {
1233  if constexpr(NSwizzle)
1234  {
1235  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1236  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1237  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1238  const index_t expert_swizzle =
1239  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1240  const index_t bid_new = blockIdx.x - prefix_block;
1241  const index_t nid = __builtin_amdgcn_readfirstlane(
1242  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1243  const index_t mid =
1244  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1245  return {nid, mid};
1246  }
1247  else
1248  {
1249  return {blockIdx.x, blockIdx.y};
1250  }
1251  }();
1252  const index_t block_n_id = block_mn.first;
1253  const index_t block_m_id = block_mn.second;
1254  const index_t token0 =
1255  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1256 
1257  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1258  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1259  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1260  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1261  constexpr auto AKThreads = AK0Threads * AK1Threads;
1262  constexpr auto AMRepeats = MPerBlock / AMThreads;
1263  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1264 
1265  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1266  return;
1268  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1269  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1270  index_t token_offset = fused_token & 0xffffff;
1271  if constexpr(!IsInputGemm)
1272  {
1273  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1274  }
1275  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1276  });
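// Interpretation of the sorted-token encoding used above (consistent with the masking here and
// the scatter logic in the epilogue, though not documented in this header): each entry of
// p_sorted_token_ids appears to pack the token index in its low 24 bits and the top-k slot in
// its high 8 bits. For the gate/up (input) GEMM the gather row is the token index itself; for
// the down GEMM it becomes token * TopK + slot, e.g. fused_token = 0x02000015 -> token 21,
// slot 2 -> row 21 * TopK + 2.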
1277  const index_t expert_stride =
1278  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1279  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1280  math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
1281  math::integer_divide_ceil(problem.K, ScaleBlockK));
1282 
1283  // N0, K0, Blocksize*KPack
1284  const index_t n_block_data_idx_on_grid =
1285  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1286 
1287  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1288  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1289  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1290  p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
1291  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1292 
1293  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1294  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1295  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1296  p_b_scale_grid + expert_id * expert_scale_stride,
1297  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1298 
1299  // A matrix in LDS memory, dst of blockwise copy
1300  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1301 
1302  // B matrix in LDS memory, dst of blockwise copy
1303  // dummy
1304  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1305  // A matrix blockwise copy
1306  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1308  AElementwiseOperation,
1312  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1313  ABlockTransferThreadClusterArrangeOrder,
1314  ADataType,
1315  LDSTypeA,
1316  decltype(a_grid_desc_ak0_m_ak1),
1317  decltype(a_block_desc_ak0_m_ak1),
1318  ABlockTransferSrcAccessOrder,
1320  ABlockTransferSrcVectorDim,
1321  2,
1322  ABlockTransferSrcScalarPerVector,
1323  ABlockTransferDstScalarPerVector_AK1,
1324  1,
1325  1,
1326  AThreadTransferSrcResetCoordinateAfterRun,
1327  true,
1328  IndexType,
1329  1,
1330  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1331  make_multi_index(0, 0, 0),
1332  a_element_op,
1333  a_block_desc_ak0_m_ak1,
1334  make_multi_index(0, 0, 0),
1336  gather_offsets);
1337 
1338  // Thread-wise copy
1339  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1340  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1341  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1342 
1343  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1344  BDataType,
1345  BDataType,
1346  decltype(b_grid_desc_bpreshuffled),
1347  decltype(b_block_desc_bk0_n_bk1),
1350  3,
1351  BBlockTransferSrcScalarPerVector,
1352  BThreadTransferSrcResetCoordinateAfterRun,
1353  true>(b_grid_desc_bpreshuffled,
1354  make_multi_index(n_block_data_idx_on_grid,
1356  0,
1357  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1358 
1359  // LDS allocation for A and B: be careful of alignment
1360  // Cast after lds
1361  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1362  static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1363 
1364  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1365  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1366 
1367  // Blockwise GEMM pipeline
1368  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1369  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1370  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1371  decltype(c_thread_buf) c_thread_buf_up;
1372 
1373  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1374  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1375  KPerBlock);
1376 
1377  constexpr index_t ScaleSliceSizeM = MXdlPerWave;
1378  constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
1379  constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
1380 
1381  // ScaleSliceSizeK is last dimension in A/B scale for vector memory access
1382  // ScaleSliceSizeK is first dimension in C scale for packed math
1383  constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
1385 
1386  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1387  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1388  constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
1389  auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
1390  (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl;
1391 
1392  constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
1394 
1395  constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
1397 
1398  // get each thread's offset in the scale tensor
1399  // A scale
1400  const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
1401 
1402  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
1403  return;
1404  StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
1405  static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
1406  const index_t fused_token =
1407  p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
1408  index_t token_offset = fused_token & 0xffffff;
1409  if constexpr(!IsInputGemm)
1410  {
1411  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1412  }
1413  scale_gather_offsets(m0) =
1414  token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK);
1415  });
1416 
1417  auto a_scale_thread_copy =
1419  AScaleType,
1420  decltype(a_scale_grid_desc_am_ak),
1421  decltype(a_scale_thread_desc),
1424  1,
1425  ScaleSliceSizeK,
1426  1,
1427  false,
1428  MXdlPerWave>(
1429  a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets);
1430 
1431  auto b_scale_thread_copy =
1433  BScaleType,
1434  decltype(b_scale_grid_desc_bn_ak),
1435  decltype(b_scale_thread_desc),
1438  1,
1439  ScaleSliceSizeK,
1440  1,
1441  false>(
1442  b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1443 
1444  // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
1445  constexpr auto a_scale_thread_slice_copy_step =
1446  make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK));
1447  constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
1448 
1449  constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
1450  if constexpr(IsInputGemm)
1451  {
1452  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1453  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1454  p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
1455  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1456  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1457  BDataType,
1458  BDataType,
1459  decltype(b_grid_desc_bpreshuffled),
1460  decltype(b_block_desc_bk0_n_bk1),
1463  3,
1464  BBlockTransferSrcScalarPerVector,
1465  BThreadTransferSrcResetCoordinateAfterRun,
1466  true>(b_grid_desc_bpreshuffled,
1467  make_multi_index(n_block_data_idx_on_grid,
1469  0,
1470  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1471  const BScaleType* p_b_scale_grid_up =
1472  p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
1473  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1474  p_b_scale_grid_up + expert_id * expert_scale_stride,
1475  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1476  auto b_scale_thread_copy_up =
1478  BScaleType,
1479  decltype(b_scale_grid_desc_bn_ak),
1480  decltype(b_scale_thread_desc),
1483  1,
1484  ScaleSliceSizeK,
1485  1,
1486  false>(
1487  b_scale_grid_desc_bn_ak,
1488  make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1489 
1490  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1491  a_grid_desc_ak0_m_ak1,
1492  a_block_desc_ak0_m_ak1,
1493  a_blockwise_copy,
1494  a_grid_buf,
1495  a_block_buf,
1496  a_block_slice_copy_step,
1497 
1498  b_grid_desc_bpreshuffled,
1499  b_block_desc_bk0_n_bk1,
1500  b_blockwise_copy,
1501  b_blockwise_copy_up,
1502  b_grid_buf,
1503  b_grid_buf_up,
1504  b_block_buf,
1505  b_block_slice_copy_step,
1506 
1507  c_scale_thread_desc,
1508  c_thread_buf,
1509  c_thread_buf_up,
1510 
1511  a_scale_grid_desc_am_ak,
1512  a_scale_thread_desc,
1513  a_scale_thread_copy,
1514  a_scale_grid_buf,
1515  a_scale_thread_slice_copy_step,
1516 
1517  b_scale_grid_desc_bn_ak,
1518  b_scale_thread_desc,
1519  b_scale_thread_copy,
1520  b_scale_thread_copy_up,
1521  b_scale_grid_buf,
1522  b_scale_grid_buf_up,
1523  b_scale_thread_slice_copy_step,
1524 
1525  num_k_block_main_loop);
1526  }
1527  else
1528  {
1529  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1530  a_grid_desc_ak0_m_ak1,
1531  a_block_desc_ak0_m_ak1,
1532  a_blockwise_copy,
1533  a_grid_buf,
1534  a_block_buf,
1535  a_block_slice_copy_step,
1536 
1537  b_grid_desc_bpreshuffled,
1538  b_block_desc_bk0_n_bk1,
1539  b_blockwise_copy,
1540  b_grid_buf,
1541  b_block_buf,
1542  b_block_slice_copy_step,
1543 
1544  c_scale_thread_desc,
1545  c_thread_buf,
1546 
1547  a_scale_grid_desc_am_ak,
1548  a_scale_thread_desc,
1549  a_scale_thread_copy,
1550  a_scale_grid_buf,
1551  a_scale_thread_slice_copy_step,
1552 
1553  b_scale_grid_desc_bn_ak,
1554  b_scale_thread_desc,
1555  b_scale_thread_copy,
1556  b_scale_grid_buf,
1557  b_scale_thread_slice_copy_step,
1558 
1559  num_k_block_main_loop);
1560  }
1561 
1562  // shuffle C and write out
1563  {
1564  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1565  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1566  "wrong!");
1567 
1568  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1569 
1570  // transposed XDL
1571  // TODO: hacky, fix it!
1572  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
1573  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1574 
1575  // TODO: hacky, fix it!
1576  // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
1577  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
1578  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1579 
1580  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
1581  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
1582  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
1583  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
1584  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
1585  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
1586  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
1587  constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
1588 
1589  static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
1590  static_assert(M0 * M1 * M2 == MPerBlock);
1591  static_assert(N4 == 4 || N4 == 8);
1592  const index_t m1 = get_warp_local_1d_id() / NWave;
1593  const index_t m2 = threadIdx.x % get_warp_size() % M2;
1594 
1595  float topk_weight;
1596  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1597  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1598  if constexpr(MulRoutedWeight)
1599  {
1600  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
1601  topk_weight = p_ds_grid[I0][m_pos];
1602  }
1603  static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
1604  static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
1605  constexpr index_t c_offset =
1606  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1607  make_tuple(m0, n0, n2 * N4 + n4));
1608  constexpr auto cidx = Number<c_offset>{};
1609  if constexpr(IsInputGemm) // gate-and-up (gu) fusion, elementwise
1610  {
1611  if constexpr(ActivationOperation == Activation::silu_and_mul)
1612  {
1613  float gate = c_thread_buf[cidx];
1614  float up = c_thread_buf_up[cidx];
1615  if constexpr(MulRoutedWeight)
1616  {
1617  gate = gate * topk_weight;
1618  up = up * topk_weight;
1619  }
1621  {
1622  gate *= 16;
1623  up *= 16;
1624  }
1626  c_thread_buf(cidx) = gate * up;
1627  }
1628  else if(ActivationOperation == Activation::gelu_and_mul)
1629  {
1630  float gate = c_thread_buf[cidx];
1631  float up = c_thread_buf_up[cidx];
1632  if constexpr(MulRoutedWeight)
1633  {
1634  gate = gate * topk_weight;
1635  up = up * topk_weight;
1636  }
1638  {
1639  gate *= 16;
1640  up *= 16;
1641  }
1643  c_thread_buf(cidx) = gate * up;
1644  }
1645  }
1646  else
1647  {
1648  if constexpr(MulRoutedWeight)
1649  {
1650  c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
1651  }
1652  }
1653  });
1654  });
1655  });
1656  });
1657 
1658  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1660 
1661  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1662  static_cast<CShuffleDataType*>(p_shared),
1663  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1664 
1665  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
1666  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1667  make_tuple(
1670  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1671  M1, // M1 = MWave
1672  M2)), // M2 = MPerXdl
1675  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1676  N1, // N1 = NWave
1677  N2, // N2 * N3 * N4 = NPerXdl
1678  N3,
1679  N4))),
1681  make_tuple(
1683 
1684  // calculate origin of thread output tensor on global memory
1685  // blockwise GEMM c matrix starting index
1686  const auto c_thread_mtx_on_block =
1687  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1688 
1689  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1690  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1691 
1692  const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
1696  make_tuple(Sequence<0>{}));
1697 
1698  const auto m_thread_data_on_block_idx =
1699  m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
1700  make_multi_index(m_thread_data_on_block));
1701 
1702  const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
1704  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
1706  make_tuple(Sequence<0>{}));
1707 
1708  const auto n_thread_data_on_block_idx =
1709  n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
1710  make_multi_index(n_thread_data_on_block));
1711 
1712  // shuffle: threadwise copy C from VGPR to LDS
1713  auto c_thread_copy_vgpr_to_lds =
1715  CShuffleDataType,
1716  decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1717  decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1719  Sequence<CShuffleMXdlPerWavePerShuffle,
1720  CShuffleNXdlPerWavePerShuffle,
1721  I1,
1722  I1,
1723  I1,
1724  N2,
1725  I1,
1726  N4>,
1728  7,
1729  1,
1731  1,
1732  true>{
1733  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1734  make_multi_index(0,
1735  0,
1736  m_thread_data_on_block_idx[I1],
1737  n_thread_data_on_block_idx[I1],
1738  m_thread_data_on_block_idx[I2],
1739  n_thread_data_on_block_idx[I2],
1740  n_thread_data_on_block_idx[I3],
1741  n_thread_data_on_block_idx[I4]),
1743 
1744  using EDataType = CDataType;
1745 
1746  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1747  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1748 
1749  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1751  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1752 
1753  const auto ds_grid_buf = generate_tuple(
1754  [&](auto i) {
1755  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
1756  const DDataType* ptr_ = p_ds_grid[i];
1757  // hack logic here to support different kinds of strides; TODO: fix it.
1758  // ascale t, 1; bscale E, N, 1, move ptr to E
1759  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1760  ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
1761  },
1762  Number<NumDTensor>{});
1763 
1764  // tuple of reference to C/Ds tensor descriptors
1765  const auto c_ds_desc_refs = concat_tuple_of_reference(
1766  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1767  generate_tie([&](auto i) -> const auto& // return type should be reference
1768  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1769  Number<NumDTensor>{}));
1770 
1771  // tuple of reference to C/Ds tensor descriptors
1772  const auto c_ds_buf_refs = concat_tuple_of_reference(
1773  tie(c_shuffle_block_buf),
1774  generate_tie([&](auto i) -> const auto& // return type should be reference
1775  { return ds_grid_buf[i]; },
1776  Number<NumDTensor>{}));
1777 
1778  // tuple of starting index of C/Ds blockwise copy
1779  const auto idx_c_ds_block_begin =
1782  [&](auto) {
1783  return make_multi_index(block_m_id, 0, block_n_id, 0);
1784  // return make_multi_index(block_work_idx[I0], 0,
1785  // block_work_idx[I1], 0);
1786  },
1787  Number<NumDTensor>{}));
1788 
1789  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1790  c_grid_desc_mblock_mperblock_nblock_nperblock;
1791 
1792  using CDEBlockTransferCluster =
1793  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1794  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1795  constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix
1796  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1798  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1800  decltype(c_ds_desc_refs),
1801  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1802  CElementwiseOperation,
1803  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1804  // support arbitrary type
1805  Sequence<1,
1806  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1807  1,
1808  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1809  CDEBlockTransferCluster,
1810  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1811  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1812  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1813  3, // index_t SrcVectorDim,
1814  3, // index_t DstVectorDim,
1815  CDEShuffleBlockTransferScalarPerVectors,
1820  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1821  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1822  IndexType,
1823  1, // ScatterDim
1824  true, // OutputScatter: false, only use scatter weights
1825  scatter_weight_idx // ScatterWeightIdx: ascale
1826  >{c_ds_desc_refs,
1827  idx_c_ds_block_begin,
1828  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1829  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1830  c_element_op};
1831 
1832  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1833  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1834  // space filling curve for threadwise C in VGPR
1835  constexpr auto sfc_c_vgpr =
1838  Sequence<CShuffleMXdlPerWavePerShuffle,
1839  CShuffleNXdlPerWavePerShuffle,
1840  1,
1841  1,
1842  1,
1843  N2,
1844  1,
1845  N4>>{};
1846 
1847  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1848 
1849  // space filling curve for shuffled blockwise C/D/E
1850  constexpr auto sfc_cde_block =
1853  Sequence<1,
1854  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1855  1,
1856  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1857 
1858  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1859  constexpr auto EMThreads =
1860  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1861  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1862  constexpr auto ENThreads =
1863  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1864  static_for<0, num_access, 1>{}([&](auto access_id) {
1865  // make sure it's safe to write to LDS
1866  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
1867 
1868  auto dstidx = sfc_cde_block.GetIndex(access_id);
1869  const index_t c_token_pos =
1870  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1871  static_for<0, EMRepeats, 1>{}([&](auto m0) {
1872  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1873  index_t token_offset = fused_token & 0xffffff;
1874  if constexpr(IsInputGemm)
1875  {
1876  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1877  }
1878  scatter_offsets(m0) = token_offset * problem.N;
1879  });
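            // Scatter offsets: bits [23:0] of a sorted-token entry give the token id and
            // bits [31:24] the top-k slot; for the gate/up GEMM each (token, slot) pair
            // owns its own output row, hence the TopK expansion.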
1880 
1881  block_sync_lds();
1882 
1883  // each thread write its data from VGPR to LDS
1884  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1885  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1886  c_thread_buf,
1887  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1888  c_shuffle_block_buf);
1889 
1890  // make sure it's safe to read from LDS
1891  block_sync_lds();
1892 
1893  // each block copy its data from LDS to global
1894  cde_block_copy_lds_and_global.Run(
1895  c_ds_desc_refs,
1896  c_ds_buf_refs,
1897  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1898  tie(c_grid_buf),
1899  scatter_offsets);
1900 
1901  if constexpr(access_id < num_access - 1)
1902  {
1903  constexpr auto cde_lds_and_global_step =
1904  sfc_cde_block.GetForwardStep(access_id);
1905 
1906  // move on Ds
1907  static_for<0, NumDTensor, 1>{}([&](auto i) {
1908  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1909  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1910  });
1911 
1912  // move on E
1913  cde_block_copy_lds_and_global.MoveDstSliceWindow(
1914  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1915  I0,
1916  cde_lds_and_global_step);
1917  }
1918  });
1919  }
1920  }
1921 
1922  template <bool HasMainKBlockLoop,
1923  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1924  TailNumber TailNum = TailNumber::Odd>
1925  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1926  const index_t* p_sorted_expert_ids,
1927  const index_t* p_max_token_id,
1928  const ADataType* p_a_grid,
1929  const BDataType* p_b_grid,
1930  DsGridPointer& p_ds_grid,
1931  CDataType* p_c_grid,
1932  const AScaleType* p_a_scale_grid,
1933  const BScaleType* p_b_scale_grid,
1934  void* p_shared,
1935  void* p_shared1,
1936  const Problem& problem,
1937  AElementwiseOperation a_element_op,
1938  BElementwiseOperation b_element_op,
1939  CElementwiseOperation c_element_op)
1940  {
1941  ignore = b_element_op;
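        // Same flow as Run, but the A K-tiles ping-pong between the two LDS allocations
        // (p_shared / p_shared1) so the global->LDS copy of the next tile can overlap the
        // MFMA work on the current one; B is streamed from its preshuffled layout straight
        // into VGPR buffers and never goes through LDS.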
1942  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1943  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1944  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1945  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1946  problem.MPadded,
1947  problem.K,
1948  problem.KPadded,
1949  problem.StrideA,
1950  problem.AK0);
1951  const auto b_grid_desc_bpreshuffled =
1952  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1953  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1954  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1955  problem.MPadded,
1956  problem.N,
1957  problem.NPadded,
1958  problem.StrideC);
1959 
1960  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1961  make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens
1962  : problem.NumTokens * problem.TopK,
1963  ScaleBlockM),
1964  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1965  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1966  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1967  make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
1968  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1969  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
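        // Block-scale descriptors: one scale per (ScaleBlockM x ScaleBlockK) tile of A and
        // per (ScaleBlockN x ScaleBlockK) tile of B, laid out row-major with the K-block
        // index as the unit-stride dimension.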
1970  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1971  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1972  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1973  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1974  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1975  if(expert_block_id * MPerBlock >= max_token_id)
1976  return;
1977  const index_t expert_id =
1978  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1979  const auto block_mn = [&]() -> std::pair<int, int> {
1980  if constexpr(NSwizzle)
1981  {
1982  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1983  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1984  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1985  const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
1986  const index_t bid_new = blockIdx.x - prefix_block;
1987  const index_t nid = __builtin_amdgcn_readfirstlane(
1988  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1989  const index_t mid =
1990  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1991  return {nid, mid};
1992  }
1993  else
1994  {
1995  return {blockIdx.x, blockIdx.y};
1996  }
1997  }();
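        // With NSwizzle enabled, consecutive block ids visit 8 N-blocks for every M-block
        // of the current expert before moving to the next 8 N-blocks (p_max_token_id[1 + e]
        // holds the per-expert M-block prefix counts), presumably to improve reuse of the
        // expert's B tiles; otherwise blockIdx.x/.y map directly to the N/M block ids.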
1998  const index_t block_n_id = block_mn.first;
1999  const index_t block_m_id = block_mn.second;
2000 
2001  const index_t token0 =
2002  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2003 
2004  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2005  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2006  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2007  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2008  constexpr auto AKThreads = AK0Threads * AK1Threads;
2009  constexpr auto AMRepeats = MPerBlock / AMThreads;
2010  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
2011 
2012  if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
2013  token0 >= problem.NumTokens)
2014  return;
2015  StaticallyIndexedArray<IndexType, AMRepeats>
2016  gather_offsets; //= p_sorted_token_ids[token_pos];
2017  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2018  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
2019  index_t token_offset = fused_token & 0xffffff;
2020  if constexpr(!IsInputGemm)
2021  {
2022  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2023  }
2024  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2025  });
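        // Each entry of p_sorted_token_ids packs the source row: bits [23:0] hold the token
        // id and bits [31:24] the top-k slot. For the down GEMM the A rows come from the
        // TopK-expanded activations, so the row index is expanded before multiplying by the
        // row stride problem.K.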
2026  const index_t expert_stride =
2027  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2028  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2029  math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
2030  math::integer_divide_ceil(problem.K, ScaleBlockK));
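        // Per-expert strides into the (preshuffled) weights and their block scales; the
        // factor 2 for the input GEMM reflects the gate and up projection weights being
        // stored back to back for each expert.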
2031  // N0, K0, Blocksize*KPack
2032  const index_t n_block_data_idx_on_grid =
2033  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
2034 
2035  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2036  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2037  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2038  p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
2039  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2040 
2041  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2042  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2043  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2044  p_b_scale_grid + expert_id * expert_scale_stride,
2045  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2046 
2047  // A matrix in LDS memory, dst of blockwise copy
2048  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2049 
2050  // B matrix in LDS memory, dst of blockwise copy
2051  // dummy
2052  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2053  // A matrix blockwise copy
2054  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
2056  AElementwiseOperation,
2060  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2061  ABlockTransferThreadClusterArrangeOrder,
2062  ADataType,
2063  LDSTypeA,
2064  decltype(a_grid_desc_ak0_m_ak1),
2065  decltype(a_block_desc_ak0_m_ak1),
2066  ABlockTransferSrcAccessOrder,
2068  ABlockTransferSrcVectorDim,
2069  2,
2070  ABlockTransferSrcScalarPerVector,
2071  ABlockTransferDstScalarPerVector_AK1,
2072  1,
2073  1,
2074  AThreadTransferSrcResetCoordinateAfterRun,
2075  true,
2076  IndexType,
2077  1,
2078  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
2079  make_multi_index(0, 0, 0),
2080  a_element_op,
2081  a_block_desc_ak0_m_ak1,
2082  make_multi_index(0, 0, 0),
2084  gather_offsets);
2085 
2086  // Thread-wise copy
2087  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2088  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2089  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2090  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2091  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2092  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
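        // B is double-buffered in VGPRs (ping/pong); the two LDS allocations are dedicated
        // to A tiles, and p_shared is reused later for the C shuffle.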
2093 
2094  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
2095  BDataType,
2096  BDataType,
2097  decltype(b_grid_desc_bpreshuffled),
2098  decltype(b_block_desc_bk0_n_bk1),
2101  3,
2102  BBlockTransferSrcScalarPerVector,
2103  BThreadTransferSrcResetCoordinateAfterRun,
2104  true>(b_grid_desc_bpreshuffled,
2105  make_multi_index(n_block_data_idx_on_grid,
2107  0,
2108  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2109 
2110  // LDS allocation for A and B: be careful of alignment
2111  // Cast after lds
2112  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2113  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2114  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2115  static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2116  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2117 
2118  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2119  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
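        // Per main-loop step the A window advances KPerBlock/AK1 along AK0 and the B window
        // advances KRepeat along the K0 dimension of the preshuffled layout.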
2120 
2121  // Blockwise GEMM pipeline
2122  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2123  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2124  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2125  decltype(c_thread_buf) c_thread_buf_up;
2126 
2127  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2128  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2129  KPerBlock);
2130 
2131  // scale
2132  constexpr index_t ScaleSliceSizeM = MXdlPerWave;
2133  constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
2134  constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
2135 
2136  // ScaleSliceSizeK is last dimension in A/B scale for vector memory access
2137  // ScaleSliceSizeK is first dimension in C scale for packed math
2138  constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
2140 
2141  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
2142  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
2143  constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
2144  auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
2145  (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl;
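        // a_thread_offset selects the row of the M tile whose per-token A scale this lane
        // loads: its lane position within the XDL tile plus its wave's M offset.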
2146 
2147  constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
2149 
2150  constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
2152 
2153  // get each thread's offset in the scale tensor
2154  // A scale
2155  const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
2156 
2157  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2158  return;
2159  StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
2160  static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
2161  const index_t fused_token =
2162  p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
2163  index_t token_offset = fused_token & 0xffffff;
2164  if constexpr(!IsInputGemm)
2165  {
2166  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2167  }
2168  scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
2169  math::integer_divide_ceil(problem.K, ScaleBlockK);
2170  });
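        // Per-token A scales are gathered with the same 24-bit token id / 8-bit top-k
        // decoding as the activations; each offset addresses a row of
        // ceil(K / ScaleBlockK) scales.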
2171 
2172  auto a_scale_thread_copy =
2174  AScaleType,
2175  decltype(a_scale_grid_desc_am_ak),
2176  decltype(a_scale_thread_desc),
2179  1,
2180  ScaleSliceSizeK,
2181  1,
2182  false,
2183  MXdlPerWave>(
2184  a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets);
2185 
2186  auto b_scale_thread_copy =
2188  BScaleType,
2189  decltype(b_scale_grid_desc_bn_ak),
2190  decltype(b_scale_thread_desc),
2193  1,
2194  ScaleSliceSizeK,
2195  1,
2196  false>(
2197  b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
2198 
2199  // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
2200  constexpr auto a_scale_thread_slice_copy_step =
2201  make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK));
2202  constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
2203 
2204  constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
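        // NumKBlockPerScale: number of KPerBlock tiles covered by a single K-direction
        // scale block (1 when ScaleBlockK <= KPerBlock); the pipeline presumably advances
        // the scale windows only once per that many K steps.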
2205  if constexpr(IsInputGemm)
2206  {
2207  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2208  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2209  p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
2210  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2211  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2212  BDataType,
2213  BDataType,
2214  decltype(b_grid_desc_bpreshuffled),
2215  decltype(b_block_desc_bk0_n_bk1),
2218  3,
2219  BBlockTransferSrcScalarPerVector,
2220  BThreadTransferSrcResetCoordinateAfterRun,
2221  true>(b_grid_desc_bpreshuffled,
2222  make_multi_index(n_block_data_idx_on_grid,
2224  0,
2225  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2226  const BScaleType* p_b_scale_grid_up =
2227  p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
2228  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2229  p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
2230  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2231  auto b_scale_thread_copy_up =
2233  BScaleType,
2234  decltype(b_scale_grid_desc_bn_ak),
2235  decltype(b_scale_thread_desc),
2238  1,
2239  ScaleSliceSizeK,
2240  1,
2241  false>(
2242  b_scale_grid_desc_bn_ak,
2243  make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
2244 
2245  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2246  a_grid_desc_ak0_m_ak1,
2247  a_block_desc_ak0_m_ak1,
2248  a_blockwise_copy,
2249  a_grid_buf,
2250  a_block_bufs,
2251  a_block_slice_copy_step,
2252  b_grid_desc_bpreshuffled,
2253  b_block_desc_bk0_n_bk1,
2254  b_blockwise_copy,
2255  b_blockwise_copy_up,
2256  b_grid_buf,
2257  b_grid_buf_up,
2258  b_block_bufs,
2259  b_block_slice_copy_step,
2260  c_scale_thread_desc,
2261  c_thread_buf,
2262  c_thread_buf_up,
2263  a_scale_grid_desc_am_ak,
2264  a_scale_thread_desc,
2265  a_scale_thread_copy,
2266  a_scale_grid_buf,
2267  a_scale_thread_slice_copy_step,
2268  b_scale_grid_desc_bn_ak,
2269  b_scale_thread_desc,
2270  b_scale_thread_copy,
2271  b_scale_thread_copy_up,
2272  b_scale_grid_buf,
2273  b_scale_grid_buf_up,
2274  b_scale_thread_slice_copy_step,
2275  num_k_block_main_loop);
2276  }
2277  else
2278  {
2279  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2280  a_grid_desc_ak0_m_ak1,
2281  a_block_desc_ak0_m_ak1,
2282  a_blockwise_copy,
2283  a_grid_buf,
2284  a_block_bufs,
2285  a_block_slice_copy_step,
2286  b_grid_desc_bpreshuffled,
2287  b_block_desc_bk0_n_bk1,
2288  b_blockwise_copy,
2289  b_grid_buf,
2290  b_block_bufs,
2291  b_block_slice_copy_step,
2292  c_scale_thread_desc,
2293  c_thread_buf,
2294  a_scale_grid_desc_am_ak,
2295  a_scale_thread_desc,
2296  a_scale_thread_copy,
2297  a_scale_grid_buf,
2298  a_scale_thread_slice_copy_step,
2299  b_scale_grid_desc_bn_ak,
2300  b_scale_thread_desc,
2301  b_scale_thread_copy,
2302  b_scale_grid_buf,
2303  b_scale_thread_slice_copy_step,
2304  num_k_block_main_loop);
2305  }
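            // Gate/up fusion: for the input GEMM the pipeline consumes two B streams
            // against the same A tiles, the gate weights/scales and the up weights/scales
            // (offset by half the per-expert stride), accumulating into c_thread_buf and
            // c_thread_buf_up respectively; the down GEMM uses the single-B variant.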
2306 
2307  // shuffle C and write out
2308  {
2309 
2310  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2311  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2312  "wrong!");
2313 
2314  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2315 
2316  // transposed XDL
2317  // TODO: hacky, fix it!
2318  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
2319  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2320 
2321  // TODO: hacky, fix it!
2322  // only used to get lengths
2323  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
2324  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2325 
2326  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
2327  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
2328  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
2329  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
2330  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
2331  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
2332  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
2333  constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
2334 
2335  static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
2336  static_assert(M0 * M1 * M2 == MPerBlock);
2337  static_assert(N4 == 4 || N4 == 8);
2338  const index_t m1 = get_warp_local_1d_id() / NWave;
2339  const index_t m2 = threadIdx.x % get_warp_size() % M2;
2340 
2341  float topk_weight;
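            // Fused epilogue on the accumulators: for the gate/up (input) GEMM the gate
            // accumulator is passed through the selected activation (SiLU or GELU) and
            // multiplied elementwise with the up accumulator; when MulRoutedWeight is set
            // the accumulators are additionally scaled by the routed top-k weight read
            // from D0 (p_ds_grid[I0]).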
2342  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2343  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2344  if constexpr(MulRoutedWeight)
2345  {
2346  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
2347  topk_weight = p_ds_grid[I0][m_pos];
2348  }
2349  static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
2350  static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
2351  constexpr index_t c_offset =
2352  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2353  make_tuple(m0, n0, n2 * N4 + n4));
2354  constexpr auto cidx = Number<c_offset>{};
2355  if constexpr(IsInputGemm) // gu fusion, elementwise
2356  {
2357  if constexpr(ActivationOperation == Activation::silu_and_mul)
2358  {
2359  float gate = c_thread_buf[cidx];
2360  float up = c_thread_buf_up[cidx];
2361  if constexpr(MulRoutedWeight)
2362  {
2363  gate = gate * topk_weight;
2364  up = up * topk_weight;
2365  }
2367  {
2368  gate *= 16;
2369  up *= 16;
2370  }
2372  c_thread_buf(cidx) = gate * up;
2373  }
2374  else if(ActivationOperation == Activation::gelu_and_mul)
2375  {
2376  float gate = c_thread_buf[cidx];
2377  float up = c_thread_buf_up[cidx];
2378  if constexpr(MulRoutedWeight)
2379  {
2380  gate = gate * topk_weight;
2381  up = up * topk_weight;
2382  }
2384  {
2385  gate *= 16;
2386  up *= 16;
2387  }
2389  c_thread_buf(cidx) = gate * up;
2390  }
2391  }
2392  else
2393  {
2394  if constexpr(MulRoutedWeight)
2395  {
2396  c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
2397  }
2398  }
2399 
2400  });
2401  });
2402  });
2403  });
2404 
2405  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2406  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2407 
2408  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2409  static_cast<CShuffleDataType*>(p_shared),
2410  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2411 
2412  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
2413  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2414  make_tuple(
2417  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2418  M1, // M1 = MWave
2419  M2)), // M2 = MPerXdl
2422  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2423  N1, // N1 = NWave
2424  N2, // N2 * N3 * N4 = NPerXdl
2425  N3,
2426  N4))),
2428  make_tuple(
2430 
2431  // calculate origin of thread output tensor on global memory
2432  // blockwise GEMM c matrix starting index
2433  const auto c_thread_mtx_on_block =
2434  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2435 
2436  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2437  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2438 
2439  const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
2443  make_tuple(Sequence<0>{}));
2444 
2445  const auto m_thread_data_on_block_idx =
2446  m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
2447  make_multi_index(m_thread_data_on_block));
2448 
2449  const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
2451  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
2453  make_tuple(Sequence<0>{}));
2454 
2455  const auto n_thread_data_on_block_idx =
2456  n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
2457  make_multi_index(n_thread_data_on_block));
2458 
2459  // shuffle: threadwise copy C from VGPR to LDS
2460  auto c_thread_copy_vgpr_to_lds =
2462  CShuffleDataType,
2463  decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2464  decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2466  Sequence<CShuffleMXdlPerWavePerShuffle,
2467  CShuffleNXdlPerWavePerShuffle,
2468  I1,
2469  I1,
2470  I1,
2471  N2,
2472  I1,
2473  N4>,
2475  7,
2476  1,
2478  1,
2479  true>{
2480  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2481  make_multi_index(0,
2482  0,
2483  m_thread_data_on_block_idx[I1],
2484  n_thread_data_on_block_idx[I1],
2485  m_thread_data_on_block_idx[I2],
2486  n_thread_data_on_block_idx[I2],
2487  n_thread_data_on_block_idx[I3],
2488  n_thread_data_on_block_idx[I4]),
2490 
2491  using EDataType = CDataType;
2492 
2493  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2494  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2495 
2496  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2497  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2498  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2499 
2500  const auto ds_grid_buf = generate_tuple(
2501  [&](auto i) {
2502  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2503  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2504  },
2505  Number<NumDTensor>{});
2506 
2507  // tuple of reference to C/Ds tensor descriptors
2508  const auto c_ds_desc_refs = concat_tuple_of_reference(
2509  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2510  generate_tie([&](auto i) -> const auto& // return type should be reference
2511  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2512  Number<NumDTensor>{}));
2513 
2514  // tuple of reference to C/Ds tensor descriptors
2515  const auto c_ds_buf_refs = concat_tuple_of_reference(
2516  tie(c_shuffle_block_buf),
2517  generate_tie([&](auto i) -> const auto& // return type should be reference
2518  { return ds_grid_buf[i]; },
2519  Number<NumDTensor>{}));
2520 
2521  // tuple of starting index of C/Ds blockwise copy
2522  const auto idx_c_ds_block_begin =
2525  [&](auto) {
2526  return make_multi_index(block_m_id, 0, block_n_id, 0);
2527  // return make_multi_index(block_work_idx[I0], 0,
2528  // block_work_idx[I1], 0);
2529  },
2530  Number<NumDTensor>{}));
2531 
2532  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2533  c_grid_desc_mblock_mperblock_nblock_nperblock;
2534 
2535  using CDEBlockTransferCluster =
2536  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2537  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2538  constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // TODO: both branches are 1; the scatter weight index does not yet depend on IsInputGemm
2539  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2541  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2543  decltype(c_ds_desc_refs),
2544  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2545  CElementwiseOperation,
2546  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
2547  // support arbitrary type
2548  Sequence<1,
2549  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2550  1,
2551  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2552  CDEBlockTransferCluster,
2553  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2554  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2555  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2556  3, // index_t SrcVectorDim,
2557  3, // index_t DstVectorDim,
2558  CDEShuffleBlockTransferScalarPerVectors,
2563  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2564  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2565  IndexType,
2566  1, // ScatterDim
2567  true, // OutputScatter: scatter rows of the output using scatter_offsets
2568  scatter_weight_idx // ScatterWeightIdx: ascale
2569  >{c_ds_desc_refs,
2570  idx_c_ds_block_begin,
2571  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2572  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2573  c_element_op};
2574 
2575  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2576  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2577  // space filling curve for threadwise C in VGPR
2578  constexpr auto sfc_c_vgpr =
2581  Sequence<CShuffleMXdlPerWavePerShuffle,
2582  CShuffleNXdlPerWavePerShuffle,
2583  1,
2584  1,
2585  1,
2586  N2,
2587  1,
2588  N4>>{};
2589 
2590  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2591 
2592  // space filling curve for shuffled blockwise C/D/E
2593  constexpr auto sfc_cde_block =
2596  Sequence<1,
2597  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2598  1,
2599  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2600 
2601  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2602  constexpr auto EMThreads =
2603  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2604  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2605  constexpr auto ENThreads =
2606  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2607  static_for<0, num_access, 1>{}([&](auto access_id) {
2608  // make sure it's safe to write to LDS
2609  StaticallyIndexedArray<IndexType, EMRepeats>
2610  scatter_offsets; //= p_sorted_token_ids[c_token_pos];
2611 
2612  auto dstidx = sfc_cde_block.GetIndex(access_id);
2613  const index_t c_token_pos =
2614  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2615  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2616  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2617  index_t token_offset = fused_token & 0xffffff;
2618  if constexpr(IsInputGemm)
2619  {
2620  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2621  }
2622  scatter_offsets(m0) = token_offset * problem.N;
2623  });
2624 
2625  block_sync_lds();
2626 
2627  // each thread write its data from VGPR to LDS
2628  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2629  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2630  c_thread_buf,
2631  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2632  c_shuffle_block_buf);
2633 
2634  // make sure it's safe to read from LDS
2635  block_sync_lds();
2636 
2637  // each block copy its data from LDS to global
2638  cde_block_copy_lds_and_global.Run(
2639  c_ds_desc_refs,
2640  c_ds_buf_refs,
2641  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2642  tie(c_grid_buf),
2643  scatter_offsets);
2644 
2645  if constexpr(access_id < num_access - 1)
2646  {
2647  constexpr auto cde_lds_and_global_step =
2648  sfc_cde_block.GetForwardStep(access_id);
2649 
2650  // move on Ds
2651  static_for<0, NumDTensor, 1>{}([&](auto i) {
2652  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2653  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2654  });
2655 
2656  // move on E
2657  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2658  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2659  I0,
2660  cde_lds_and_global_step);
2661  }
2662  });
2663  }
2664  }
2665 };
2666 
2667 } // namespace ck