Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp Source File
gridwise_moe_gemm_blockscale.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
16 
18 
19 #define DEBUG_LOG 0
20 
21 namespace ck {
22 
23 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
24 // pipelines in the same kernel function. Blockers:
25 // 1. Two separate declarations of __shared__ pointers are the key to making sure data accesses
26 // operate on two distinct LDS chunks.
27 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
28 // not reuse the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
29 
30 enum Activation
31 {
32  gelu_and_mul = 0,
33  silu_and_mul = 1
34 };
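// [Editorial sketch, not part of the original file] The Activation values select the fused
// epilogue used when the gate and up GEMM results are combined (see the IsInputGemm && !IsSplitK
// branch in Run): an activation is applied to the "gate" result and multiplied with the "up"
// result. A scalar reference of that contract, assuming the usual SiLU definition, the tanh-based
// GELU approximation, and that <cmath>-style expf/tanhf are available:
inline float reference_act_and_mul(Activation act, float gate, float up)
{
    if(act == Activation::silu_and_mul)
    {
        return (gate / (1.0f + expf(-gate))) * up; // SiLU(x) = x * sigmoid(x)
    }
    // gelu_and_mul, tanh approximation of GELU
    const float gelu =
        0.5f * gate * (1.0f + tanhf(0.7978845608f * (gate + 0.044715f * gate * gate * gate)));
    return gelu * up;
}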
35 
36 template <typename GridwiseGemm,
37  bool HasMainKBlockLoop,
38  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
39  index_t MinimumOccupancy = 1,
40  TailNumber TailNum = TailNumber::Even>
41 __global__ void
42 #if CK_USE_LAUNCH_BOUNDS
43 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
44 #endif
45  // __attribute__((amdgpu_waves_per_eu(1, 1)))
46  kernel_moe_gemm(typename GridwiseGemm::Argument karg)
47 {
48 #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
49  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
50  {
51  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52 
53  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
54 
55  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
56  karg.p_sorted_token_ids,
57  karg.p_sorted_expert_ids,
58  karg.p_max_token_id,
59  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
60  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
61  karg.p_ds_grid,
62  karg.p_c_grid,
63  karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
64  karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
65  p_shared,
66  karg,
67  karg.a_element_op,
68  karg.b_element_op,
69  karg.c_element_op);
70  }
71 #else
72  ignore = karg;
73 #endif // end of #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
74 }
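// [Editorial sketch, not part of the original file] Minimal host-side launch of kernel_moe_gemm,
// shown only to illustrate how the pieces fit together: the grid extents come from
// GridwiseGemm::CalculateGridSize(M, N, K, KBatch) and blockIdx.z selects the split-K slice
// consumed by SplitKBatchOffset. The block_size argument must equal the BlockSize template
// parameter of GridwiseGemm; HIP runtime declarations (dim3, hipStream_t) are assumed to come in
// through the includes stripped from this listing, and the HasMainKBlockLoop / TailNumber
// selection (see CalculateHasMainKBlockLoop / CalculateKBlockLoopTailNum) plus all error handling
// are omitted.
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
void launch_moe_gemm_sketch(const typename GridwiseGemm::Argument& karg,
                            index_t block_size,
                            hipStream_t stream)
{
    const auto [gdx, gdy, gdz] =
        GridwiseGemm::CalculateGridSize(karg.M, karg.N, karg.K, karg.KBatch);

    kernel_moe_gemm<GridwiseGemm, HasMainKBlockLoop, CGlobalMemoryDataOperation>
        <<<dim3(gdx, gdy, gdz), dim3(block_size), 0, stream>>>(karg);
}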
75 
76 template <typename GridwiseGemm,
77  bool HasMainKBlockLoop,
78  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
79  index_t MinimumOccupancy = 1,
80  TailNumber TailNum = TailNumber::Even>
81 __global__ void
82 #if CK_USE_LAUNCH_BOUNDS
83 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
84 #endif
85  // __attribute__((amdgpu_waves_per_eu(1, 1)))
86  kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
87 {
88 #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
89  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
90  {
91  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
92  __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
93 
94  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
95 
96  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
97  karg.p_sorted_token_ids,
98  karg.p_sorted_expert_ids,
99  karg.p_max_token_id,
100  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
101  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
102  karg.p_ds_grid,
103  karg.p_c_grid,
104  karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
105  karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
106  p_shared,
107  p_shared1,
108  karg,
109  karg.a_element_op,
110  karg.b_element_op,
111  karg.c_element_op);
112  }
113 #else
114  ignore = karg;
115 #endif // end of #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
116 }
117 
118 template <typename ALayout,
119  typename BLayout,
120  typename DsLayout,
121  typename CLayout,
122  typename ADataType,
123  typename BDataType,
124  typename AccDataType,
125  typename CShuffleDataType,
126  typename DsDataType,
127  typename CDataType,
128  typename AElementwiseOperation,
129  typename BElementwiseOperation,
130  typename CElementwiseOperation,
132  index_t BlockSize,
133  index_t ScaleBlockM,
134  index_t ScaleBlockN,
135  index_t ScaleBlockK,
136  index_t MPerBlock,
137  index_t NPerBlock,
138  index_t KPerBlock,
139  index_t AK1Value,
140  index_t BK1Value,
141  index_t MPerXdl,
142  index_t NPerXdl,
143  index_t MXdlPerWave,
144  index_t NXdlPerWave,
145  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
146  typename ABlockTransferThreadClusterArrangeOrder,
147  typename ABlockTransferSrcAccessOrder,
148  index_t ABlockTransferSrcVectorDim,
149  index_t ABlockTransferSrcScalarPerVector,
150  index_t ABlockTransferDstScalarPerVector_AK1,
151  bool AThreadTransferSrcResetCoordinateAfterRun,
152  index_t ABlockLdsExtraM,
153  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
154  typename BBlockTransferThreadClusterArrangeOrder,
155  typename BBlockTransferSrcAccessOrder,
156  index_t BBlockTransferSrcVectorDim,
157  index_t BBlockTransferSrcScalarPerVector,
158  index_t BBlockTransferDstScalarPerVector_BK1,
159  bool BThreadTransferSrcResetCoordinateAfterRun,
160  index_t BBlockLdsExtraN,
161  index_t CShuffleMXdlPerWavePerShuffle,
162  index_t CShuffleNXdlPerWavePerShuffle,
163  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
164  typename CDEShuffleBlockTransferScalarPerVectors,
167  index_t ActivationOperation = 0,
168  bool NSwizzle = false,
169  bool IsInputGemm = true,
170  bool IsSplitK = false,
171  bool MulRoutedWeight = true,
172  typename IndexType = index_t,
173  typename ComputeTypeA = CDataType,
174  typename ComputeTypeB = ComputeTypeA,
175  typename LDSTypeA = ADataType,
176  typename LDSTypeB = BDataType>
178 {
179  using AScaleType = float;
180  using BScaleType = float;
181 
182  static constexpr auto I0 = Number<0>{};
183  static constexpr auto I1 = Number<1>{};
184  static constexpr auto I2 = Number<2>{};
185  static constexpr auto I3 = Number<3>{};
186  static constexpr auto I4 = Number<4>{};
187  static constexpr auto I5 = Number<5>{};
188  static constexpr auto I6 = Number<6>{};
189  static constexpr auto I7 = Number<7>{};
190 
192  CDEShuffleBlockTransferScalarPerVectors{}[I0];
193  // K1 should be Number<...>
194  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
195  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
196  static constexpr auto AK1Number = Number<AK1Value>{};
197  static constexpr auto BK1Number = Number<BK1Value>{};
198  static constexpr auto BlockSizeNumber = Number<BlockSize>{};
199 
200  static constexpr index_t NumDTensor = DsDataType::Size();
201 
203  static constexpr index_t KPack =
205  static constexpr index_t KGroup = []() {
207 // On gfx950, we have an mfma instruction that requires 32 f8 elements as input,
208 // split into 2 groups of 16 f8 elements.
209 // The 2 groups are not contiguous in the B preshuffled layout,
210 // and we do not want them to be contiguous in the B preshuffled layout,
211 // because a memory instruction can only read 16 f8 elements at a time.
212  return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
213  else
214  return 1;
215  }();
216  static constexpr index_t KLane =
218  static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
219  static constexpr index_t NLane = NPerXdl;
220  static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
221  // static constexpr index_t NumTokens = 1;
222  static constexpr index_t SortedTileSize = MPerBlock;
223 
224  static constexpr auto MakeDsGridPointer()
225  {
226  return generate_tuple(
227  [&](auto i) {
228  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
229 
230  return static_cast<const DDataType*>(nullptr);
231  },
233  }
234 
235  using DsGridPointer = decltype(MakeDsGridPointer());
236 
238 
239  static constexpr index_t APackedSize = []() {
241  return 2;
242  else
243  return 1;
244  }();
245 
246  static constexpr index_t BPackedSize = []() {
248  return 2;
249  else
250  return 1;
251  }();
252 
253  __host__ static auto CalculateGridSize(index_t M, index_t N, index_t K, index_t KBatch)
254  {
255  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
256  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
257  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
258  const index_t gridy = NSwizzle ? 1 : mblock;
259  const index_t gridz = KBatch == 1 ? 1 : math::integer_divide_ceil(K, KPerBlock * KBatch);
260 
261  return std::make_tuple(gridx, gridy, gridz);
262  }
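// [Editorial note, not part of the original file] Worked example of the mapping above, assuming
// MPerBlock = 128 and NPerBlock = 128: M = 512, N = 1024 gives mblock = 4 and nblock = 8, so with
// NSwizzle == false the launch grid is (gridx, gridy, gridz) = (8, 4, 1), while NSwizzle == true
// folds both tile dimensions into gridx, giving (32, 1, 1). When KBatch > 1, gridz adds the
// split-K dimension that the kernels consume through blockIdx.z.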
263 
264  __host__ __device__ static auto CalculateMPadded(index_t M)
265  {
266  return math::integer_least_multiple(M, MPerBlock);
267  }
268 
269  __host__ __device__ static auto CalculateNPadded(index_t N)
270  {
271  return math::integer_least_multiple(N, NPerBlock);
272  }
273 
274  __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
275  {
276  return math::integer_divide_ceil(N, NLane);
277  }
278  __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
279  {
281  }
282 
283  __host__ __device__ static auto CalculateKPadded(index_t K)
284  {
285  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
286  }
287 
288  __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
289  {
290  // auto K_t = K_Batch * KPerBlock;
291  // return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
292  return K_Batch == 1 ? K / AK1Value : K_Batch * KPerBlock / AK1Value;
293  }
294 
295  __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
296  {
297  // auto K_t = K_Batch * KPerBlock;
298  // return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
299  return K_Batch == 1 ? K / BK1Value : K_Batch * KPerBlock / BK1Value;
300  }
301 
302  __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
303  {
304  // auto K_t = K_Batch * KPerBlock;
305  // return (K + K_t - 1) / K_t * KPerBlock;
306  return K_Batch == 1 ? K : K_Batch * KPerBlock;
307  }
308 
309  __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
310  {
311  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
312  // auto K_t = K_Batch * KReadVec;
313  // return (K + K_t - 1) / K_t * KReadVec;
314  return K_Batch == 1 ? math::integer_divide_ceil(K, KReadVec) * KReadVec
315  : K_Batch * KPerBlock;
316  }
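// [Editorial note, not part of the original file] Worked numbers for the K helpers above,
// assuming KPerBlock = 128 and AK1Value = BK1Value = 16 (so KReadVec = lcm(16, 16) = 16): with
// K = 4096 and K_Batch = 1, CalculateKPadded(4096) = 4096, CalculateAK0Padded = 4096 / 16 = 256
// and CalculateKRead = 4096. With K_Batch = 2 the K_Batch branches are taken instead:
// CalculateAK0Padded = 2 * 128 / 16 = 16, CalculateKPadded = 2 * 128 = 256 and
// CalculateKRead = 256.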
317 
318  __host__ __device__ static auto CalculateMBlock(index_t M)
319  {
320  return math::integer_divide_ceil(M, MPerBlock);
321  }
322 
323  __host__ __device__ static auto CalculateNBlock(index_t N)
324  {
325  return math::integer_divide_ceil(N, NPerBlock);
326  }
327 
328  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
329  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
330  {
331  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
332  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
333 
335  TileDesc_K0_MN_K1{},
341  }
342 
343  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
344  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
345  {
346  const auto a_grid_desc_mraw_kraw = [&]() {
347  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
348  {
349  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
350  }
351  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
352  {
353  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
354  }
355  }();
356 
358 
359  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
360  GemmSpec == GemmSpecialization::MNKPadding)
361  {
362  // pad both M and K
363  const auto a_grid_desc_m_k =
364  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
366  make_right_pad_transform(K, KPad - K)),
369 
370  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
371  a_grid_desc_m_k,
376 
377  return a_grid_desc_ak0_m_ak1;
378  }
379  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
380  GemmSpec == GemmSpecialization::MNPadding)
381  {
382  // pad M, but not K
383  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
384  a_grid_desc_mraw_kraw,
386  make_right_pad_transform(M, MPad - M)),
389 
390  return a_grid_desc_ak0_m_ak1;
391  }
392  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
393  GemmSpec == GemmSpecialization::NKPadding)
394  {
395  // pad K, but not M
396  const auto a_grid_desc_m_k = transform_tensor_descriptor(
397  a_grid_desc_mraw_kraw,
401 
402  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
403  a_grid_desc_m_k,
408 
409  return a_grid_desc_ak0_m_ak1;
410  }
411  else
412  {
413  // not pad M or K
414  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
415  a_grid_desc_mraw_kraw,
420  return a_grid_desc_ak0_m_ak1;
421  }
422  }
423 
424  __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
425  {
426  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
427  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
428  constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack / KGroup>{};
430  make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
431  make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
432  }
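// [Editorial note, not part of the original file] The preshuffled B descriptor above is 4-D,
// (N0 / NWave, NWave, K0, WaveSize * KPack / KGroup), with the innermost WaveSize * KPack / KGroup
// elements contiguous in memory; each lane addresses its own KPack / KGroup chunk via the
// KPack / KGroup * (get_thread_local_1d_id() % WarpSize) starting offset used by b_blockwise_copy
// in Run, so B is streamed from global memory straight into VGPRs without an LDS staging pass.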
433 
434  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
435  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
436  {
437  const auto b_grid_desc_nraw_kraw = [&]() {
439  {
440  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
441  }
443  {
444  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
445  }
446  }();
447 
449 
450  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
451  GemmSpec != GemmSpecialization::Default),
452  "pk_i4_t does not support padding");
453 
454  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
455  GemmSpec == GemmSpecialization::MNKPadding)
456  {
457  // pad both N and K
458  const auto b_grid_desc_n_k =
459  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
461  make_right_pad_transform(K, KPad - K)),
464 
465  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
466  b_grid_desc_n_k,
471 
472  return b_grid_desc_bk0_n_bk1;
473  }
474  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
475  GemmSpec == GemmSpecialization::MNPadding)
476  {
477  // pad N, but not K
478  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
479  b_grid_desc_nraw_kraw,
481  make_right_pad_transform(N, NPad - N)),
484 
485  return b_grid_desc_bk0_n_bk1;
486  }
487  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
488  GemmSpec == GemmSpecialization::MKPadding)
489  {
490  // pad K, but not N
491  const auto b_grid_desc_n_k = transform_tensor_descriptor(
492  b_grid_desc_nraw_kraw,
496 
497  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
498  b_grid_desc_n_k,
503 
504  return b_grid_desc_bk0_n_bk1;
505  }
506  else
507  {
508  // not pad N or K
509  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
510  b_grid_desc_nraw_kraw,
515 
516  return b_grid_desc_bk0_n_bk1;
517  }
518  }
519 
520  template <typename ABlockDesc_AK0_M_AK1>
521  __host__ __device__ static constexpr auto
522  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
523  {
524  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
525 
526  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
527  }
528 
529  template <typename BBlockDesc_BK0_N_BK1>
530  __host__ __device__ static constexpr auto
531  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
532  {
533  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
534  }
535 
536  template <typename ELayout>
537  __host__ __device__ static auto MakeCGridDescriptor_M_N(
538  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
539  {
540  const auto c_grid_desc_mraw_nraw = [&]() {
542  {
543  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
544  }
546  {
547  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
548  }
549  }();
550 
551  // pad M and N
552  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
554  make_right_pad_transform(N, NPad - N)),
557  }
558 
559  template <typename DLayout>
560  __host__ __device__ static auto
562  {
563  const auto c_grid_desc_mraw_nraw = [&]() {
565  {
566  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
567  }
569  {
570  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
571  }
572  }();
573 
574  // pad M and N
575  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
577  make_right_pad_transform(N, NPad - N)),
580  }
581 
582  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
583  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
584  {
585  return generate_tuple(
586  [&](auto i) {
587  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
588  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
589  },
591  }
592 
593  template <typename DsGridDesc>
595  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
596  {
597  return generate_tuple(
598  [&](auto i) {
600  ds_grid_desc_m_n[i], MBlock, NBlock);
601  },
603  }
604 
605  using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
606 
607  struct Problem
608  {
609  __host__ __device__ Problem(index_t NumTokens_,
610  index_t TopK_,
611  index_t M_,
612  index_t N_,
613  index_t K_,
614  index_t StrideA_,
615  index_t StrideB_,
616  std::array<index_t, NumDTensor> StrideDs_,
617  index_t StrideC_,
618  index_t KBatch_)
619  : NumTokens{NumTokens_},
620  TopK{TopK_},
621  M{M_},
622  N{N_},
623  K{K_},
624  StrideA{StrideA_},
625  StrideB{StrideB_},
626  StrideDs{StrideDs_},
627  StrideC{StrideC_},
628  KBatch{KBatch_},
631  KRead{CalculateKRead(K_, KBatch_)},
632  KPadded{CalculateKPadded(K_, KBatch_)},
633  AK0{CalculateAK0Padded(K_, KBatch_)},
634  BK0{CalculateBK0Padded(K_, KBatch_)},
635  MBlock{CalculateMBlock(M_)},
637  {
638  }
639 
640  __host__ void Print() const
641  {
642  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
643  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
644  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
645  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
646  << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
647  << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
648  << "NBlock: " << NBlock << "}" << std::endl;
649  }
650 
658  std::array<index_t, NumDTensor> StrideDs;
669  };
670 
671  // Argument
673  {
674  __host__ Argument(const index_t* p_sorted_token_ids_,
675  const index_t* p_sorted_expert_ids_,
676  const index_t* p_max_token_id_,
677  const ADataType* p_a_grid_,
678  const BDataType* p_b_grid_,
679  std::array<const void*, NumDTensor> p_ds_grid_,
680  CDataType* p_c_grid_,
681  index_t NumTokens_,
682  index_t TopK_,
683  index_t M_,
684  index_t N_,
685  index_t K_,
686  index_t StrideA_,
687  index_t StrideB_,
688  std::array<index_t, NumDTensor> StrideDs_,
689  index_t StrideC_,
690  const AScaleType* p_a_scale_grid_,
691  const BScaleType* p_b_scale_grid_,
692  index_t k_batch_,
693  AElementwiseOperation a_element_op_,
694  BElementwiseOperation b_element_op_,
695  CElementwiseOperation c_element_op_)
696  : Problem{NumTokens_,
697  TopK_,
698  M_,
699  N_,
700  K_,
701  StrideA_,
702  StrideB_,
703  StrideDs_,
704  StrideC_,
705  k_batch_},
706  p_sorted_token_ids{p_sorted_token_ids_},
707  p_sorted_expert_ids{p_sorted_expert_ids_},
708  p_max_token_id{p_max_token_id_},
709  p_a_grid{p_a_grid_},
710  p_b_grid{p_b_grid_},
711  p_ds_grid{},
712  p_c_grid{p_c_grid_},
713  p_a_scale_grid{p_a_scale_grid_},
714  p_b_scale_grid{p_b_scale_grid_},
715  a_element_op{a_element_op_},
716  b_element_op{b_element_op_},
717  c_element_op{c_element_op_}
718  {
719 
720  // populate pointer, desc for Ds
721  static_for<0, NumDTensor, 1>{}([&](auto i) {
722  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
723 
724  // D pointer
725  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
726  });
727  }
728 
732  const ADataType* p_a_grid;
733  const BDataType* p_b_grid;
735  CDataType* p_c_grid;
736 
739 
740  const AElementwiseOperation a_element_op;
741  const BElementwiseOperation b_element_op;
742  const CElementwiseOperation c_element_op;
743  };
744 
746  {
747  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
748  {
749  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
750  {
751  a_k_split_offset = k_id * karg.KRead / APackedSize;
753  }
754  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
755  {
756  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
758  }
759 
760  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
761  {
762  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
764  }
765  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
766  {
767  // KPack * NLane * KLane * K0 * N0
768  b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
769  bscale_k_split_offset = k_id * karg.KRead / ScaleBlockK;
770  }
771 
772  // if(k_id < karg.KBatch - 1)
773  // {
774  // karg.K = karg.KRead;
775  // }
776  // else
777  // {
778  // karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
779  // }
780  }
781 
786  };
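// [Editorial note, not part of the original file] Example of the split-K offsets above for
// row-major A and column-major (preshuffled) B with KRead = 256, k_id = 1, APackedSize =
// BPackedSize = 1 and ScaleBlockK = 128: the block at blockIdx.z == 1 advances the A base pointer
// by k_id * KRead = 256 elements, the B base pointer by k_id * KRead * NLane elements into the
// expert's preshuffled slab, and the B-scale pointer by k_id * KRead / ScaleBlockK = 2 scale
// blocks along K.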
787 
788  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
789  {
790  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
791  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
792  // A matrix in LDS memory, dst of blockwise copy
793  if constexpr(ABlockLdsExtraM)
794  {
798  }
799 // The xor tensor transformation requires additional, unnecessary VGPR usage and can cause
800 // register spills in some cases.
802  {
803  constexpr auto a_lds_block_desc =
806 
807  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
808  a_lds_block_desc,
814 
815  return a_lds_block_desc_permuted;
816  }
817  else // ColumnMajor A
818  {
819 // The kfold and mpair dimensions are not always required.
820 // More dimensions in the merge_transform increase the difficulty of generating immarg
821 // (immediate-argument) offsets for the compiler.
822  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
823  constexpr auto M1 = MPerBlock / M0;
824 
825  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
826  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
827  constexpr auto KThreadRead = WaveSize / MPerXdl;
828  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
829 
830  constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
831  ? 1
832  : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
833  constexpr auto KThreadReadPerm =
834  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
835  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
836  : KThreadRead;
837 
838 // 1 <= mpair <= M0
839  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
840  ? 1
841  : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
842  ? M0
843  : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
844 
845  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
849  Number<kfold * M0 / mpair>{},
850  Number<mpair>{},
851  AK1Number));
852 
853  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
854  a_lds_block_desc,
855  make_tuple(
859  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
862  make_tuple(
864  make_tuple(
866 
867  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
868  a_lds_block_desc_permuted,
869  make_tuple(
877  Sequence<1>{},
878  Sequence<2>{},
879  Sequence<3>{},
880  Sequence<4>{},
881  Sequence<5>{}),
883  Sequence<2>{},
884  Sequence<0, 3>{},
885  Sequence<4, 5>{},
886  Sequence<6>{},
887  Sequence<7>{}));
888 
889  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
890  a_lds_block_desc_unmerged,
893  Number<KThreadWrite / kfold / KThreadReadPerm>{},
894  Number<kfold>{},
901 
902  return a_lds_block_desc_ak0_m_ak1;
903  }
904  }
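// [Editorial note, not part of the original file] The permuted and unmerged descriptors above
// reshuffle how the A tile is stored in LDS (the KThreadReadPerm, kfold and mpair factors
// interleave K and M sub-dimensions). Judging from the surrounding comments, the intent is to
// avoid LDS bank conflicts on the xdlops read pattern without paying the extra VGPR cost of an
// xor-based swizzle, which is why the simpler ABlockLdsExtraM padding path is kept as an
// alternative.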
905 
906  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
907  {
908  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
911  }
912 
914  {
915  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
916 
917  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
919  make_tuple(I1,
921  I1,
923 
924  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
925  }
926 
929  BlkGemmPipelineVer,
930  BlkGemmPipeSched,
931  BlockSize,
932  ADataType,
933  BDataType,
934  ComputeTypeA,
935  AccDataType,
942  ABlockTransferSrcScalarPerVector,
943  BBlockTransferSrcScalarPerVector,
944  MPerBlock,
945  NPerBlock,
946  KPerBlock,
947  ScaleBlockM,
948  ScaleBlockN,
949  ScaleBlockK,
950  MPerXdl,
951  NPerXdl,
952  MXdlPerWave,
953  NXdlPerWave,
954  KPack,
955  IsInputGemm && !IsSplitK > ())>;
956 
957  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
958  {
959  // LDS allocation for A and B: be careful of alignment
960  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
961  // lds max alignment
962  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
963 
964  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
965  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
966 
967  // LDS allocation for C shuffle in LDS
968  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
970 
971  constexpr auto c_block_size =
972  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
973 
974  return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize,
975  c_block_size * sizeof(CShuffleDataType));
976  }
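// [Editorial note, not part of the original file] Worked example for the LDS budget above,
// assuming MPerBlock = 128, KPerBlock = 128, AK1Value = 16, sizeof(LDSTypeA) = 1 (f8) and
// APackedSize = 1: the A tile occupies AK0 * MPerBlock * AK1 = 8 * 128 * 16 = 16384 elements,
// i.e. 16 KiB, and the function returns the maximum of that and the C-shuffle tile footprint
// (c_block_size * sizeof(CShuffleDataType)). A max() rather than a sum is enough because the same
// p_shared allocation is reused: A staging during the main loop, C shuffle in the epilogue.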
977 
979 
980 // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
981  __host__ static constexpr bool CheckValidity(const Argument& karg)
982  {
983  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
984  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
985  "Invalid tuning param!");
986 
987  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
992  {
993  if(!(karg.M % MPerBlock == 0))
994  {
995 #if DEBUG_LOG
996  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
997  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
998  << std::endl;
999 
1000 #endif // DEBUG_LOG
1001  return false;
1002  }
1003  }
1004 
1005  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1010  {
1011  if(!(karg.N % NPerBlock == 0))
1012  {
1013 #if DEBUG_LOG
1014  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1015  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1016  << std::endl;
1017 
1018 #endif // DEBUG_LOG
1019  return false;
1020  }
1021  }
1022 
1023  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1027  {
1028 
1029  auto K_t = karg.KBatch * KPerBlock;
1030  if(!(karg.K % K_t == 0))
1031  {
1032 #if DEBUG_LOG
1033  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1034  << karg.K << " " << __FILE__ << ":" << __LINE__
1035  << ", in function: " << __func__ << std::endl;
1036 
1037 #endif // DEBUG_LOG
1038  return false;
1039  }
1040  }
1041  else
1042  {
1043  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1044  auto K_t = karg.KBatch * KReadVec;
1045  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1046  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1047  {
1048  return false;
1049  }
1050  }
1051 
1053  {
1054  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1055  {
1056 #if DEBUG_LOG
1057  std::cout << "Arg K (" << karg.K
1058  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1059  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1060  << __LINE__ << ", in function: " << __func__ << std::endl;
1061 
1062 #endif // DEBUG_LOG
1063  return false;
1064  }
1065  }
1066  else
1067  {
1068  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1069  {
1070 #if DEBUG_LOG
1071  std::cout << "Arg M (" << karg.M
1072  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1073  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1074  << __LINE__ << ", in function: " << __func__ << std::endl;
1075 
1076 #endif // DEBUG_LOG
1077  return false;
1078  }
1079  }
1080 
1082  {
1083  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1084  {
1085 #if DEBUG_LOG
1086  std::cout << "Arg N (" << karg.N
1087  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1088  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1089  << __LINE__ << ", in function: " << __func__ << std::endl;
1090 
1091 #endif // DEBUG_LOG
1092  return false;
1093  }
1094  }
1095  else
1096  {
1097  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1098  {
1099 #if DEBUG_LOG
1100  std::cout << "Arg K (" << karg.K
1101  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1102  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1103  << __LINE__ << ", in function: " << __func__ << std::endl;
1104 
1105 #endif // DEBUG_LOG
1106  return false;
1107  }
1108  }
1109 
1111  {
1113  {
1114 #if DEBUG_LOG
1115  std::cout << "Arg N (" << karg.N
1116  << ") value is not a multiple of "
1117  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1118  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1119  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1120 
1121 #endif // DEBUG_LOG
1122  return false;
1123  }
1124  }
1125  else
1126  {
1128  {
1129 #if DEBUG_LOG
1130  std::cout << "Arg M (" << karg.M
1131  << ") value is not a multiple of "
1132  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1133  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1134  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1135 
1136 #endif // DEBUG_LOG
1137  return false;
1138  }
1139  }
1140 
1141  // check gridwise gemm pipeline
1142 #if 0
1143  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1144 
1145  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1146  {
1147  return false;
1148  }
1149 #endif
1150  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1151  return true;
1152  }
1153 
1154  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1155  {
1156  const index_t num_loop = K / KPerBlock;
1157 
1158  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1159  }
1160 
1161  __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1162  {
1163  const index_t num_loop = K / KPerBlock;
1164 
1165  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1166  }
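// [Editorial note, not part of the original file] These two helpers are intended for the host:
// the device-level operation typically evaluates CalculateHasMainKBlockLoop and
// CalculateKBlockLoopTailNum on the K extent seen by each split-K slice and uses the results to
// choose the HasMainKBlockLoop / TailNum template arguments of kernel_moe_gemm or
// kernel_moe_gemm_2lds before launching.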
1167 
1168  template <typename CGridDesc>
1170  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1171  {
1172  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1173  c_grid_desc_m_n,
1178 
1179  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1180  }
1181 
1182  // return block_id to C matrix tile idx (m0, n0) mapping
1183  // if arch = gfx942
1184  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1185  // NPerBlock>;
1186 
1187  template <bool HasMainKBlockLoop,
1188  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1189  TailNumber TailNum = TailNumber::Odd>
1190  __device__ static void Run(const index_t* p_sorted_token_ids,
1191  const index_t* p_sorted_expert_ids,
1192  const index_t* p_max_token_id,
1193  const ADataType* p_a_grid,
1194  const BDataType* p_b_grid,
1195  DsGridPointer& p_ds_grid,
1196  CDataType* p_c_grid,
1197  const AScaleType* p_a_scale_grid,
1198  const BScaleType* p_b_scale_grid,
1199  void* p_shared,
1200  const Problem& problem,
1201  AElementwiseOperation a_element_op,
1202  BElementwiseOperation b_element_op,
1203  CElementwiseOperation c_element_op)
1204  {
1205  ignore = b_element_op;
1206  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1));
1207  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1208  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1209  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1210  problem.MPadded,
1211  problem.K,
1212  problem.KPadded,
1213  problem.StrideA,
1214  problem.AK0);
1215  const auto b_grid_desc_bpreshuffled =
1216  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1217  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1218  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1219  problem.MPadded,
1220  problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
1221  problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
1222  problem.StrideC);
1223 
1224  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1225  make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens
1226  : problem.NumTokens * problem.TopK,
1227  ScaleBlockM),
1228  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1229  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1230  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1231  make_tuple(math::integer_divide_ceil(problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
1232  ScaleBlockN),
1233  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1234  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1235 
1236  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1238  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1239  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1240  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1241  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1242  if(expert_block_id * MPerBlock >= max_token_id)
1243  return;
1244  const index_t expert_id =
1245  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1246  const auto block_mn = [&]() -> std::pair<int, int> {
1247  if constexpr(NSwizzle)
1248  {
1249  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1250  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1251  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1252  const index_t expert_swizzle =
1253  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1254  const index_t bid_new = blockIdx.x - prefix_block;
1255  const index_t nid = __builtin_amdgcn_readfirstlane(
1256  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1257  const index_t mid =
1258  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1259  return {nid, mid};
1260  }
1261  else
1262  {
1263  return {blockIdx.x, blockIdx.y};
1264  }
1265  }();
1266  const index_t block_n_id = block_mn.first;
1267  const index_t block_m_id = block_mn.second;
1268  const index_t token0 =
1269  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1270 
1271  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1272  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1273  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1274  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1275  constexpr auto AKThreads = AK0Threads * AK1Threads;
1276  constexpr auto AMRepeats = MPerBlock / AMThreads;
1277  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1278 
1279  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1280  return;
1282  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1283  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1284  index_t token_offset = fused_token & 0xffffff;
1285  if constexpr(!IsInputGemm)
1286  {
1287  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1288  }
1289  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1290  });
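// [Editorial note, not part of the original file] As the mask and shift above imply, each entry
// of p_sorted_token_ids packs the token index in its low 24 bits and the top-k slot in its high
// 8 bits. For example, fused_token = (2 << 24) | 17 names token 17, slot 2: the gate/up GEMM
// (IsInputGemm) gathers row 17 of the activation matrix, while the down GEMM gathers row
// 17 * TopK + 2 of the TopK-expanded intermediate; either way the gather offset is that row index
// times problem.K elements.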
1291  const index_t expert_stride =
1292  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1293  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1294  math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
1295  math::integer_divide_ceil(problem.K, ScaleBlockK));
1296 
1297  // N0, K0, Blocksize*KPack
1298  const index_t n_block_data_idx_on_grid =
1299  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1300 
1301  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1302  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1303  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1304  p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
1305  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1306 
1307  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1308  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1309  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1310  p_b_scale_grid + expert_id * expert_scale_stride,
1311  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1312 
1313  // A matrix in LDS memory, dst of blockwise copy
1314  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1315 
1316  // B matrix in LDS memory, dst of blockwise copy
1317  // dummy
1318  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1319  // A matrix blockwise copy
1320  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1322  AElementwiseOperation,
1326  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1327  ABlockTransferThreadClusterArrangeOrder,
1328  ADataType,
1329  LDSTypeA,
1330  decltype(a_grid_desc_ak0_m_ak1),
1331  decltype(a_block_desc_ak0_m_ak1),
1332  ABlockTransferSrcAccessOrder,
1334  ABlockTransferSrcVectorDim,
1335  2,
1336  ABlockTransferSrcScalarPerVector,
1337  ABlockTransferDstScalarPerVector_AK1,
1338  1,
1339  1,
1340  AThreadTransferSrcResetCoordinateAfterRun,
1341  true,
1342  IndexType,
1343  1,
1344  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1345  make_multi_index(0, 0, 0),
1346  a_element_op,
1347  a_block_desc_ak0_m_ak1,
1348  make_multi_index(0, 0, 0),
1350  gather_offsets);
1351 
1352  // Thread-wise copy
1353  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1354  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1355  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1356 
1357  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1358  BDataType,
1359  BDataType,
1360  decltype(b_grid_desc_bpreshuffled),
1361  decltype(b_block_desc_bk0_n_bk1),
1364  3,
1365  BBlockTransferSrcScalarPerVector,
1366  BThreadTransferSrcResetCoordinateAfterRun,
1367  true>(b_grid_desc_bpreshuffled,
1368  make_multi_index(n_block_data_idx_on_grid,
1370  0,
1371  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1372 
1373  // LDS allocation for A and B: be careful of alignment
1374  // Cast after lds
1375  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1376  static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1377 
1378  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1379  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1380 
1381  // Blockwise GEMM pipeline
1382  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1383  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1384  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1385  decltype(c_thread_buf) c_thread_buf_up;
1386 
1387  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1388  problem.KBatch == 1
1389  ? (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1390  KPerBlock
1391  : problem.KBatch);
1392  constexpr index_t ScaleSliceSizeM = MXdlPerWave;
1393  constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
1394  constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
1395 
1396 // ScaleSliceSizeK is the last dimension in the A/B scales, for vectorized memory access
1397 // ScaleSliceSizeK is the first dimension in the C scale, for packed math
1398  constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
1400 
1401  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1402  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1403  constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
1404  auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
1405  (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl;
1406 
1407  constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
1409 
1410  constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
1412 
1413  // get each thread's offset in the scale tensor
1414  // A scale
1415  const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
1416 
1417  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
1418  return;
1419  StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
1420  static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
1421  const index_t fused_token =
1422  p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
1423  index_t token_offset = fused_token & 0xffffff;
1424  if constexpr(!IsInputGemm)
1425  {
1426  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1427  }
1428  scale_gather_offsets(m0) =
1429  token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK);
1430  });
1431 
1432  auto a_scale_thread_copy =
1434  AScaleType,
1435  decltype(a_scale_grid_desc_am_ak),
1436  decltype(a_scale_thread_desc),
1439  1,
1440  ScaleSliceSizeK,
1441  1,
1442  false,
1443  MXdlPerWave>(
1444  a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets);
1445 
1446  auto b_scale_thread_copy =
1448  BScaleType,
1449  decltype(b_scale_grid_desc_bn_ak),
1450  decltype(b_scale_thread_desc),
1453  1,
1454  ScaleSliceSizeK,
1455  1,
1456  false>(
1457  b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1458 
1459  // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
1460  constexpr auto a_scale_thread_slice_copy_step =
1461  make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK));
1462  constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
1463 
1464  constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
1465  if constexpr(IsInputGemm && !IsSplitK)
1466  {
1467  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1468  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1469  p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
1470  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1471  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1472  BDataType,
1473  BDataType,
1474  decltype(b_grid_desc_bpreshuffled),
1475  decltype(b_block_desc_bk0_n_bk1),
1478  3,
1479  BBlockTransferSrcScalarPerVector,
1480  BThreadTransferSrcResetCoordinateAfterRun,
1481  true>(b_grid_desc_bpreshuffled,
1482  make_multi_index(n_block_data_idx_on_grid,
1484  0,
1485  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1486  const BScaleType* p_b_scale_grid_up =
1487  p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
1488  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1489  p_b_scale_grid_up + expert_id * expert_scale_stride,
1490  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1491  auto b_scale_thread_copy_up =
1493  BScaleType,
1494  decltype(b_scale_grid_desc_bn_ak),
1495  decltype(b_scale_thread_desc),
1498  1,
1499  ScaleSliceSizeK,
1500  1,
1501  false>(
1502  b_scale_grid_desc_bn_ak,
1503  make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1504 
1505  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1506  a_grid_desc_ak0_m_ak1,
1507  a_block_desc_ak0_m_ak1,
1508  a_blockwise_copy,
1509  a_grid_buf,
1510  a_block_buf,
1511  a_block_slice_copy_step,
1512 
1513  b_grid_desc_bpreshuffled,
1514  b_block_desc_bk0_n_bk1,
1515  b_blockwise_copy,
1516  b_blockwise_copy_up,
1517  b_grid_buf,
1518  b_grid_buf_up,
1519  b_block_buf,
1520  b_block_slice_copy_step,
1521 
1522  c_scale_thread_desc,
1523  c_thread_buf,
1524  c_thread_buf_up,
1525 
1526  a_scale_grid_desc_am_ak,
1527  a_scale_thread_desc,
1528  a_scale_thread_copy,
1529  a_scale_grid_buf,
1530  a_scale_thread_slice_copy_step,
1531 
1532  b_scale_grid_desc_bn_ak,
1533  b_scale_thread_desc,
1534  b_scale_thread_copy,
1535  b_scale_thread_copy_up,
1536  b_scale_grid_buf,
1537  b_scale_grid_buf_up,
1538  b_scale_thread_slice_copy_step,
1539 
1540  num_k_block_main_loop);
1541  }
1542  else
1543  {
1544  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1545  a_grid_desc_ak0_m_ak1,
1546  a_block_desc_ak0_m_ak1,
1547  a_blockwise_copy,
1548  a_grid_buf,
1549  a_block_buf,
1550  a_block_slice_copy_step,
1551 
1552  b_grid_desc_bpreshuffled,
1553  b_block_desc_bk0_n_bk1,
1554  b_blockwise_copy,
1555  b_grid_buf,
1556  b_block_buf,
1557  b_block_slice_copy_step,
1558 
1559  c_scale_thread_desc,
1560  c_thread_buf,
1561 
1562  a_scale_grid_desc_am_ak,
1563  a_scale_thread_desc,
1564  a_scale_thread_copy,
1565  a_scale_grid_buf,
1566  a_scale_thread_slice_copy_step,
1567 
1568  b_scale_grid_desc_bn_ak,
1569  b_scale_thread_desc,
1570  b_scale_thread_copy,
1571  b_scale_grid_buf,
1572  b_scale_thread_slice_copy_step,
1573 
1574  num_k_block_main_loop);
1575  }
1576 
1577  // shuffle C and write out
1578  {
1579  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1580  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1581  "wrong!");
1582 
1583  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1584 
1585  // transposed XDL
1586  // TODO: hacky, fix it!
1587  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
1588  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1589 
1590  // TODO: hacky, fix it!
1591  // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
1592  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
1593  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1594 
1595  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
1596  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
1597  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
1598  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
1599  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
1600  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
1601  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
1602  constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
1603 
1604  static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
1605  static_assert(M0 * M1 * M2 == MPerBlock);
1606  static_assert(N4 == 4 || N4 == 8);
1607  const index_t m1 = get_warp_local_1d_id() / NWave;
1608  const index_t m2 = threadIdx.x % get_warp_size() % M2;
1609 
1610  float topk_weight;
1611  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1612  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1613  if constexpr(MulRoutedWeight)
1614  {
1615  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
1616  topk_weight = p_ds_grid[I0][m_pos];
1617  }
1618  static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
1619  static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
1620  constexpr index_t c_offset =
1621  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1622  make_tuple(m0, n0, n2 * N4 + n4));
1623  constexpr auto cidx = Number<c_offset>{};
1624 if constexpr(IsInputGemm && !IsSplitK) // gate-up (gu) fusion, elementwise
1625  {
1626  if constexpr(ActivationOperation == Activation::silu_and_mul)
1627  {
1628  float gate = c_thread_buf[cidx];
1629  float up = c_thread_buf_up[cidx];
1630  if constexpr(MulRoutedWeight)
1631  {
1632  gate = gate * topk_weight;
1633  up = up * topk_weight;
1634  }
1636  {
1637  gate *= 16;
1638  up *= 16;
1639  }
1641  c_thread_buf(cidx) = gate * up;
1642  }
1643  else if(ActivationOperation == Activation::gelu_and_mul)
1644  {
1645  float gate = c_thread_buf[cidx];
1646  float up = c_thread_buf_up[cidx];
1647  if constexpr(MulRoutedWeight)
1648  {
1649  gate = gate * topk_weight;
1650  up = up * topk_weight;
1651  }
1653  {
1654  gate *= 16;
1655  up *= 16;
1656  }
1658  c_thread_buf(cidx) = gate * up;
1659  }
1660  }
1661  else
1662  {
1663  if constexpr(MulRoutedWeight)
1664  {
1665  c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
1666  }
1667  }
1668  });
1669  });
1670  });
1671  });
1672 
1673  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1675 
1676  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1677  static_cast<CShuffleDataType*>(p_shared),
1678  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1679 
1680  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
1681  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1682  make_tuple(
1685  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1686  M1, // M1 = MWave
1687  M2)), // M2 = MPerXdl
1690  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1691  N1, // N1 = NWave
1692  N2, // N2 * N3 * N4 = NPerXdl
1693  N3,
1694  N4))),
1696  make_tuple(
1698 
1699  // calculate origin of thread output tensor on global memory
1700  // blockwise GEMM c matrix starting index
1701  const auto c_thread_mtx_on_block =
1702  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1703 
1704  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1705  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1706 
1707  const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
1711  make_tuple(Sequence<0>{}));
1712 
1713  const auto m_thread_data_on_block_idx =
1714  m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
1715  make_multi_index(m_thread_data_on_block));
1716 
1717  const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
1719  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
1721  make_tuple(Sequence<0>{}));
1722 
1723  const auto n_thread_data_on_block_idx =
1724  n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
1725  make_multi_index(n_thread_data_on_block));
1726 
1727  // shuffle: threadwise copy C from VGPR to LDS
1728  auto c_thread_copy_vgpr_to_lds =
1730  CShuffleDataType,
1731  decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1732  decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1734  Sequence<CShuffleMXdlPerWavePerShuffle,
1735  CShuffleNXdlPerWavePerShuffle,
1736  I1,
1737  I1,
1738  I1,
1739  N2,
1740  I1,
1741  N4>,
1743  7,
1744  1,
1746  1,
1747  true>{
1748  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1749  make_multi_index(0,
1750  0,
1751  m_thread_data_on_block_idx[I1],
1752  n_thread_data_on_block_idx[I1],
1753  m_thread_data_on_block_idx[I2],
1754  n_thread_data_on_block_idx[I2],
1755  n_thread_data_on_block_idx[I3],
1756  n_thread_data_on_block_idx[I4]),
1758 
1759  using EDataType = CDataType;
1760 
1761  const auto ds_grid_desc_m_n =
1762  MakeDsGridDescriptor_M_N(problem.M,
1763  problem.MPadded,
1764  problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
1765  problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
1766  problem.StrideDs);
1767 
1768  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1770  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1771 
1772  const auto ds_grid_buf = generate_tuple(
1773  [&](auto i) {
1774  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
1775  const DDataType* ptr_ = p_ds_grid[i];
1776 // Hack logic here to support different kinds of strides. TODO: fix it.
1777 // a-scale: (t, 1); b-scale: (E, N, 1), move ptr to E
1778  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1779  ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
1780  },
1781  Number<NumDTensor>{});
1782 
1783  // tuple of reference to C/Ds tensor descriptors
1784  const auto c_ds_desc_refs = concat_tuple_of_reference(
1785  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1786  generate_tie([&](auto i) -> const auto& // return type should be reference
1787  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1788  Number<NumDTensor>{}));
1789 
1790  // tuple of reference to C/Ds tensor descriptors
1791  const auto c_ds_buf_refs = concat_tuple_of_reference(
1792  tie(c_shuffle_block_buf),
1793  generate_tie([&](auto i) -> const auto& // return type should be reference
1794  { return ds_grid_buf[i]; },
1795  Number<NumDTensor>{}));
1796 
1797  // tuple of starting index of C/Ds blockwise copy
1798  const auto idx_c_ds_block_begin =
1801  [&](auto) {
1802  return make_multi_index(block_m_id, 0, block_n_id, 0);
1803  // return make_multi_index(block_work_idx[I0], 0,
1804  // block_work_idx[I1], 0);
1805  },
1806  Number<NumDTensor>{}));
1807 
1808  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1809  c_grid_desc_mblock_mperblock_nblock_nperblock;
1810 
1811  using CDEBlockTransferCluster =
1812  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1813  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1814  constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix
1815  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1817  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1819  decltype(c_ds_desc_refs),
1820  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1821  CElementwiseOperation,
1822  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1823 // support arbitrary types
1824  Sequence<1,
1825  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1826  1,
1827  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1828  CDEBlockTransferCluster,
1829  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1830  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1831  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1832  3, // index_t SrcVectorDim,
1833  3, // index_t DstVectorDim,
1834  CDEShuffleBlockTransferScalarPerVectors,
1839  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1840  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1841  IndexType,
1842  1, // ScatterDim
1843  true, // OutputScatter: false, only use scatter weights
1844  scatter_weight_idx // ScatterWeightIdx: ascale
1845  >{c_ds_desc_refs,
1846  idx_c_ds_block_begin,
1847  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1848  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1849  c_element_op};
1850 
1851  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1852  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1853  // space filling curve for threadwise C in VGPR
1854  constexpr auto sfc_c_vgpr =
1857  Sequence<CShuffleMXdlPerWavePerShuffle,
1858  CShuffleNXdlPerWavePerShuffle,
1859  1,
1860  1,
1861  1,
1862  N2,
1863  1,
1864  N4>>{};
1865 
1866  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1867 
1868  // space filling curve for shuffled blockwise C/D/E
1869  constexpr auto sfc_cde_block =
1872  Sequence<1,
1873  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1874  1,
1875  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1876 
1877  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1878  constexpr auto EMThreads =
1879  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1880  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1881  constexpr auto ENThreads =
1882  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
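        // [Editorial sketch, not part of the original kernel] The three constants above
        // partition the C-shuffle tile among the block's threads: EMThreads/ENThreads are
        // the cluster extents along M and N, and EMRepeats is the number of sorted-token
        // rows each thread writes per access. A standalone rehearsal of the arithmetic,
        // assuming hypothetical cluster lengths (1, 32, 1, 8) and a 64-row shuffle slice:
        constexpr int em_threads_example = 1 * 32;                  // threads spanning M
        constexpr int en_threads_example = 1 * 8;                   // threads spanning N
        constexpr int em_repeats_example = 64 / em_threads_example; // 2 token rows per thread
        static_assert(em_repeats_example * em_threads_example == 64 && en_threads_example == 8,
                      "M extent must divide evenly across the cluster");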
1883  static_for<0, num_access, 1>{}([&](auto access_id) {
1884  // make sure it's safe to write to LDS
1885  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
1886 
1887  auto dstidx = sfc_cde_block.GetIndex(access_id);
1888  const index_t c_token_pos =
1889  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1890  static_for<0, EMRepeats, 1>{}([&](auto m0) {
1891  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1892  index_t token_offset = fused_token & 0xffffff;
1893  if constexpr(IsInputGemm)
1894  {
1895  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1896  }
1897  scatter_offsets(m0) =
1898  token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1);
1899  });
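            // [Editorial sketch, not part of the original kernel] The loop above decodes the
            // fused token ids produced by the MoE sorting step: the low 24 bits carry the token
            // row and the high 8 bits carry the top-k slot; for the input (gate/up) GEMM the
            // destination row is expanded to row * TopK + slot. A minimal standalone rehearsal
            // of the same bit arithmetic, using plain ints and hypothetical values:
            auto decode_fused_token_example = [](unsigned fused, int top_k, bool is_input_gemm) {
                int row  = static_cast<int>(fused & 0xffffff); // low 24 bits: token row
                int slot = static_cast<int>(fused >> 24);      // high 8 bits: top-k slot
                return is_input_gemm ? row * top_k + slot : row;
            };
            // e.g. fused = (2u << 24) | 7u with TopK = 4 scatters to output row 7 * 4 + 2 = 30.
            (void)decode_fused_token_example;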
1900 
1901  block_sync_lds();
1902 
1903  // each thread write its data from VGPR to LDS
1904  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1905  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1906  c_thread_buf,
1907  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1908  c_shuffle_block_buf);
1909 
1910  // make sure it's safe to read from LDS
1911  block_sync_lds();
1912 
1913  // each block copy its data from LDS to global
1914  cde_block_copy_lds_and_global.Run(
1915  c_ds_desc_refs,
1916  c_ds_buf_refs,
1917  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1918  tie(c_grid_buf),
1919  scatter_offsets);
1920 
1921  if constexpr(access_id < num_access - 1)
1922  {
1923  constexpr auto cde_lds_and_global_step =
1924  sfc_cde_block.GetForwardStep(access_id);
1925 
1926  // move on Ds
1927  static_for<0, NumDTensor, 1>{}([&](auto i) {
1928  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1929  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1930  });
1931 
1932  // move on E
1933  cde_block_copy_lds_and_global.MoveDstSliceWindow(
1934  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1935  I0,
1936  cde_lds_and_global_step);
1937  }
1938  });
1939  }
1940  }
1941 
1942  template <bool HasMainKBlockLoop,
1943  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1944  TailNumber TailNum = TailNumber::Odd>
1945  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1946  const index_t* p_sorted_expert_ids,
1947  const index_t* p_max_token_id,
1948  const ADataType* p_a_grid,
1949  const BDataType* p_b_grid,
1950  DsGridPointer& p_ds_grid,
1951  CDataType* p_c_grid,
1952  const AScaleType* p_a_scale_grid,
1953  const BScaleType* p_b_scale_grid,
1954  void* p_shared,
1955  void* p_shared1,
1956  const Problem& problem,
1957  AElementwiseOperation a_element_op,
1958  BElementwiseOperation b_element_op,
1959  CElementwiseOperation c_element_op)
1960  {
1961  ignore = b_element_op;
1962  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1963  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1964  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1965  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1966  problem.MPadded,
1967  problem.K,
1968  problem.KPadded,
1969  problem.StrideA,
1970  problem.AK0);
1971  const auto b_grid_desc_bpreshuffled =
1972  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1973  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1974  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1975  problem.MPadded,
1976  problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
1977  problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
1978  problem.StrideC);
1979 
1980  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1981  make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens
1982  : problem.NumTokens * problem.TopK,
1983  ScaleBlockM),
1984  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1985  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1986  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1987  make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
1988  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1989  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
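        // [Editorial sketch, not part of the original kernel] Both scale descriptors are plain
        // row-major views over ceil-divided shapes: A-scale is (ceil(rows / ScaleBlockM),
        // ceil(K / ScaleBlockK)) and B-scale is (ceil(N / ScaleBlockN), ceil(K / ScaleBlockK)),
        // each with a row stride equal to the number of K scale blocks. With hypothetical sizes
        // K = 4096, N = 7168 and 128x128 scale blocks:
        constexpr auto ceil_div_example = [](int a, int b) { return (a + b - 1) / b; };
        constexpr int k_scale_blocks_example = ceil_div_example(4096, 128); // 32 scales along K
        constexpr int n_scale_blocks_example = ceil_div_example(7168, 128); // 56 scales along N
        // B-scale element (n_blk, k_blk) lives at n_blk * k_scale_blocks + k_blk, e.g. (5, 3) -> 163.
        static_assert(5 * k_scale_blocks_example + 3 == 163 && n_scale_blocks_example == 56,
                      "row-major block-scale indexing");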
1990  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1991  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1992  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1993  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1994  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1995  if(expert_block_id * MPerBlock >= max_token_id)
1996  return;
1997  const index_t expert_id =
1998  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1999  const auto block_mn = [&]() -> std::pair<int, int> {
2000  if constexpr(NSwizzle)
2001  {
2002  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2003  const index_t prefix_block = ecnt_prefix * problem.NBlock;
2004  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2005  const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
2006  const index_t bid_new = blockIdx.x - prefix_block;
2007  const index_t nid = __builtin_amdgcn_readfirstlane(
2008  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2009  const index_t mid =
2010  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2011  return {nid, mid};
2012  }
2013  else
2014  {
2015  return {blockIdx.x, blockIdx.y};
2016  }
2017  }();
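        // [Editorial sketch, not part of the original kernel] With NSwizzle enabled, the flat
        // blockIdx.x is remapped so that runs of 8 consecutive N blocks belonging to one expert
        // are scheduled together; the quotient chain above recovers the (n, m) block pair.
        // Rehearsing the same index math with hypothetical values:
        constexpr int swizzle_prefix_example = 4;  // M blocks launched before this expert
        constexpr int swizzle_ecnt_example   = 2;  // M blocks owned by this expert
        constexpr int swizzle_bid_example    = 13; // blockIdx.x minus the expert's block prefix
        constexpr int swizzle_nid_example =
            swizzle_bid_example % 8 + swizzle_bid_example / (8 * swizzle_ecnt_example) * 8;
        constexpr int swizzle_mid_example =
            swizzle_prefix_example + swizzle_bid_example / 8 % swizzle_ecnt_example;
        static_assert(swizzle_nid_example == 5 && swizzle_mid_example == 5,
                      "swizzled (n, m) block ids for this example");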
2018  const index_t block_n_id = block_mn.first;
2019  const index_t block_m_id = block_mn.second;
2020 
2021  const index_t token0 =
2022  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2023 
2024  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2025  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2026  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2027  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2028  constexpr auto AKThreads = AK0Threads * AK1Threads;
2029  constexpr auto AMRepeats = MPerBlock / AMThreads;
2030  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
2031 
2032  if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
2033  token0 >= problem.NumTokens)
2034  return;
2035  StaticallyIndexedArray<IndexType, AMRepeats>
2036  gather_offsets; //= p_sorted_token_ids[token_pos];
2037  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2038  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
2039  index_t token_offset = fused_token & 0xffffff;
2040  if constexpr(!IsInputGemm)
2041  {
2042  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2043  }
2044  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2045  });
2046  const index_t expert_stride =
2047  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2048  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2049  math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
2050  math::integer_divide_ceil(problem.K, ScaleBlockK));
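        // [Editorial sketch, not part of the original kernel] Every expert owns one contiguous
        // slab of B weights and one slab of B scales, so the kernel advances the B pointers by
        // expert_id * stride. With hypothetical sizes N = 7168, K = 4096, 128x128 scale blocks
        // and a gate+up fused input GEMM (the factor of 2):
        constexpr long long expert_stride_example       = 7168LL * 4096LL * 2;               // B elements per expert
        constexpr long long expert_scale_stride_example = (7168 / 128) * 2LL * (4096 / 128); // scales per expert
        static_assert(expert_scale_stride_example == 3584, "per-expert scale count in this example");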
2051  // N0, K0, Blocksize*KPack
2052  const index_t n_block_data_idx_on_grid =
2053  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
2054 
2055  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2056  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2057  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2058  p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
2059  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2060 
2061  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2062  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2063  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2064  p_b_scale_grid + expert_id * expert_scale_stride,
2065  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2066 
2067  // A matrix in LDS memory, dst of blockwise copy
2068  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2069 
2070  // B matrix in LDS memory, dst of blockwise copy
2071  // dummy
2072  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2073  // A matrix blockwise copy
2074  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
2076  AElementwiseOperation,
2080  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2081  ABlockTransferThreadClusterArrangeOrder,
2082  ADataType,
2083  LDSTypeA,
2084  decltype(a_grid_desc_ak0_m_ak1),
2085  decltype(a_block_desc_ak0_m_ak1),
2086  ABlockTransferSrcAccessOrder,
2088  ABlockTransferSrcVectorDim,
2089  2,
2090  ABlockTransferSrcScalarPerVector,
2091  ABlockTransferDstScalarPerVector_AK1,
2092  1,
2093  1,
2094  AThreadTransferSrcResetCoordinateAfterRun,
2095  true,
2096  IndexType,
2097  1,
2098  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
2099  make_multi_index(0, 0, 0),
2100  a_element_op,
2101  a_block_desc_ak0_m_ak1,
2102  make_multi_index(0, 0, 0),
2104  gather_offsets);
2105 
2106  // Thread-wise copy
2107  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2108  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2109  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2110  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2111  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2112  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2113 
2114  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
2115  BDataType,
2116  BDataType,
2117  decltype(b_grid_desc_bpreshuffled),
2118  decltype(b_block_desc_bk0_n_bk1),
2121  3,
2122  BBlockTransferSrcScalarPerVector,
2123  BThreadTransferSrcResetCoordinateAfterRun,
2124  true>(b_grid_desc_bpreshuffled,
2125  make_multi_index(n_block_data_idx_on_grid,
2127  0,
2128  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2129 
2130  // LDS allocation for A and B: be careful of alignment
2131  // The raw LDS pointer is cast to the element type after allocation
2132  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2133  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2134  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2135  static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2136  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
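        // [Editorial sketch, not part of the original kernel] p_shared / p_shared1 back a classic
        // ping-pong scheme: while the GEMM pipeline consumes the A tile held in one LDS buffer,
        // the blockwise copy prefetches the next K tile into the other, and the roles swap every
        // main-loop iteration. The buffer-selection rule reduces to the iteration parity:
        constexpr auto lds_read_buffer_example = [](int k_tile) { return k_tile & 1; };
        static_assert(lds_read_buffer_example(0) == 0 && lds_read_buffer_example(1) == 1 &&
                          lds_read_buffer_example(2) == 0,
                      "even iterations read buffer 0, odd iterations read buffer 1");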
2137 
2138  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2139  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
2140 
2141  // Blockwise GEMM pipeline
2142  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2143  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2144  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2145  decltype(c_thread_buf) c_thread_buf_up;
2146 
2147  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2148  problem.KBatch == 1
2149  ? (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2150  KPerBlock
2151  : problem.KBatch);
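        // [Editorial sketch, not part of the original kernel] For KBatch == 1 the main-loop count
        // is the full (padded) K extent divided by KPerBlock, since AK0 * AK1 == K; otherwise the
        // count is taken from problem.KBatch. With hypothetical K = 4096 and KPerBlock = 128:
        constexpr int num_k_loop_example = 4096 / 128; // 32 main-loop iterations
        static_assert(num_k_loop_example == 32, "K tiles per block in this example");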
2152 
2153  // scale
2154  constexpr index_t ScaleSliceSizeM = MXdlPerWave;
2155  constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
2156  constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
2157 
2158  // ScaleSliceSizeK is the last dimension of the A/B scale for vectorized memory access
2159  // ScaleSliceSizeK is the first dimension of the C scale for packed math
2160  constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
2162 
2163  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
2164  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
2165  constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
2166  auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
2167  (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl;
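        // [Editorial sketch, not part of the original kernel] a_thread_offset picks the M row of
        // the scale tile this lane is responsible for: the lane position inside the XDL tile plus
        // the wave's M offset. Rehearsing the arithmetic for a hypothetical 256-thread block with
        // 128x128 tiles, 32x32 XDL and 2x2 XDL-per-wave:
        constexpr int mwaves_example    = 128 / (2 * 32);                            // 2
        constexpr int nwaves_example    = 128 / (2 * 32);                            // 2
        constexpr int wave_size_example = 256 / (mwaves_example * nwaves_example);   // 64
        constexpr int tid_example       = 100;
        constexpr int a_offset_example =
            tid_example % 32 + (tid_example / wave_size_example) / nwaves_example * 32; // 4
        static_assert(a_offset_example == 4, "thread 100 maps to scale row 4 in this configuration");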
2168 
2169  constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
2171 
2172  constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
2174 
2175  // get each thread's offset in the scale tensor
2176  // A scale
2177  const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
2178 
2179  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2180  return;
2181  StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
2182  static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
2183  const index_t fused_token =
2184  p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
2185  index_t token_offset = fused_token & 0xffffff;
2186  if constexpr(!IsInputGemm)
2187  {
2188  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2189  }
2190  scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
2191  math::integer_divide_ceil(problem.K, ScaleBlockK);
2192  });
2193 
2194  auto a_scale_thread_copy =
2196  AScaleType,
2197  decltype(a_scale_grid_desc_am_ak),
2198  decltype(a_scale_thread_desc),
2201  1,
2202  ScaleSliceSizeK,
2203  1,
2204  false,
2205  MXdlPerWave>(
2206  a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets);
2207 
2208  auto b_scale_thread_copy =
2210  BScaleType,
2211  decltype(b_scale_grid_desc_bn_ak),
2212  decltype(b_scale_thread_desc),
2215  1,
2216  ScaleSliceSizeK,
2217  1,
2218  false>(
2219  b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
2220 
2221  // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
2222  constexpr auto a_scale_thread_slice_copy_step =
2223  make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK));
2224  constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
2225 
2226  constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
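        // [Editorial sketch, not part of the original kernel] NumKBlockPerScale tells the pipeline
        // how many K tiles share one scale value along K: when the scale block is wider than the
        // K tile, a fresh scale is loaded only every ScaleBlockK / KPerBlock iterations. With
        // hypothetical ScaleBlockK = 128 and KPerBlock = 64:
        constexpr int k_blocks_per_scale_example = (128 + 64 - 1) / 64; // 2 K tiles per scale
        static_assert(k_blocks_per_scale_example == 2, "one new scale every two K tiles");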
2227  if constexpr(IsInputGemm && !IsSplitK)
2228  {
2229  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2230  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2231  p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
2232  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2233  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2234  BDataType,
2235  BDataType,
2236  decltype(b_grid_desc_bpreshuffled),
2237  decltype(b_block_desc_bk0_n_bk1),
2240  3,
2241  BBlockTransferSrcScalarPerVector,
2242  BThreadTransferSrcResetCoordinateAfterRun,
2243  true>(b_grid_desc_bpreshuffled,
2244  make_multi_index(n_block_data_idx_on_grid,
2246  0,
2247  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2248  const BScaleType* p_b_scale_grid_up =
2249  p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
2250  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2251  p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
2252  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2253  auto b_scale_thread_copy_up =
2255  BScaleType,
2256  decltype(b_scale_grid_desc_bn_ak),
2257  decltype(b_scale_thread_desc),
2260  1,
2261  ScaleSliceSizeK,
2262  1,
2263  false>(
2264  b_scale_grid_desc_bn_ak,
2265  make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
2266 
2267  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2268  a_grid_desc_ak0_m_ak1,
2269  a_block_desc_ak0_m_ak1,
2270  a_blockwise_copy,
2271  a_grid_buf,
2272  a_block_bufs,
2273  a_block_slice_copy_step,
2274  b_grid_desc_bpreshuffled,
2275  b_block_desc_bk0_n_bk1,
2276  b_blockwise_copy,
2277  b_blockwise_copy_up,
2278  b_grid_buf,
2279  b_grid_buf_up,
2280  b_block_bufs,
2281  b_block_slice_copy_step,
2282  c_scale_thread_desc,
2283  c_thread_buf,
2284  c_thread_buf_up,
2285  a_scale_grid_desc_am_ak,
2286  a_scale_thread_desc,
2287  a_scale_thread_copy,
2288  a_scale_grid_buf,
2289  a_scale_thread_slice_copy_step,
2290  b_scale_grid_desc_bn_ak,
2291  b_scale_thread_desc,
2292  b_scale_thread_copy,
2293  b_scale_thread_copy_up,
2294  b_scale_grid_buf,
2295  b_scale_grid_buf_up,
2296  b_scale_thread_slice_copy_step,
2297  num_k_block_main_loop);
2298  }
2299  else
2300  {
2301  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2302  a_grid_desc_ak0_m_ak1,
2303  a_block_desc_ak0_m_ak1,
2304  a_blockwise_copy,
2305  a_grid_buf,
2306  a_block_bufs,
2307  a_block_slice_copy_step,
2308  b_grid_desc_bpreshuffled,
2309  b_block_desc_bk0_n_bk1,
2310  b_blockwise_copy,
2311  b_grid_buf,
2312  b_block_bufs,
2313  b_block_slice_copy_step,
2314  c_scale_thread_desc,
2315  c_thread_buf,
2316  a_scale_grid_desc_am_ak,
2317  a_scale_thread_desc,
2318  a_scale_thread_copy,
2319  a_scale_grid_buf,
2320  a_scale_thread_slice_copy_step,
2321  b_scale_grid_desc_bn_ak,
2322  b_scale_thread_desc,
2323  b_scale_thread_copy,
2324  b_scale_grid_buf,
2325  b_scale_thread_slice_copy_step,
2326  num_k_block_main_loop);
2327  }
2328 
2329  // shuffle C and write out
2330  {
2331 
2332  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2333  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2334  "wrong!");
2335 
2336  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2337 
2338  // transposed XDL
2339  // TODO: hacky, fix it!
2340  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
2341  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2342 
2343  // TODO: hacky, fix it!
2344  // only used to get lengths
2345  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
2346  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2347 
2348  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
2349  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
2350  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
2351  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
2352  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
2353  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
2354  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
2355  constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
2356 
2357  static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
2358  static_assert(M0 * M1 * M2 == MPerBlock);
2359  static_assert(N4 == 4 || N4 == 8);
2360  const index_t m1 = get_warp_local_1d_id() / NWave;
2361  const index_t m2 = threadIdx.x % get_warp_size() % M2;
2362 
2363  float topk_weight;
2364  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2365  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2366  if constexpr(MulRoutedWeight)
2367  {
2368  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
2369  topk_weight = p_ds_grid[I0][m_pos];
2370  }
2371  static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
2372  static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
2373  constexpr index_t c_offset =
2374  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2375  make_tuple(m0, n0, n2 * N4 + n4));
2376  constexpr auto cidx = Number<c_offset>{};
2377  if constexpr(IsInputGemm && !IsSplitK) // gate-up (gu) fusion, element-wise
2378  {
2379  if constexpr(ActivationOperation == Activation::silu_and_mul)
2380  {
2381  float gate = c_thread_buf[cidx];
2382  float up = c_thread_buf_up[cidx];
2383  if constexpr(MulRoutedWeight)
2384  {
2385  gate = gate * topk_weight;
2386  up = up * topk_weight;
2387  }
2389  {
2390  gate *= 16;
2391  up *= 16;
2392  }
2394  c_thread_buf(cidx) = gate * up;
2395  }
2396  else if(ActivationOperation == Activation::gelu_and_mul)
2397  {
2398  float gate = c_thread_buf[cidx];
2399  float up = c_thread_buf_up[cidx];
2400  if constexpr(MulRoutedWeight)
2401  {
2402  gate = gate * topk_weight;
2403  up = up * topk_weight;
2404  }
2406  {
2407  gate *= 16;
2408  up *= 16;
2409  }
2411  c_thread_buf(cidx) = gate * up;
2412  }
2413  }
2414  else
2415  {
2416  if constexpr(MulRoutedWeight)
2417  {
2418  c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
2419  }
2420  }
2421 
2422  });
2423  });
2424  });
2425  });
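            // [Editorial sketch, not part of the original kernel] The fused epilogue above combines
            // the gate and up GEMM results element-wise before the shuffle/write-out, in the form
            // out = act(gate) * up, with the routed top-k weight applied to both operands when
            // MulRoutedWeight is set. A scalar reference of the same structure, assuming the
            // standard SiLU definition x / (1 + exp(-x)) and expf from <cmath> (an illustration,
            // not the exact CK operator):
            auto act_and_mul_reference = [](float gate, float up, float topk_w, auto&& act) {
                gate *= topk_w; // MulRoutedWeight path scales both operands
                up *= topk_w;
                return act(gate) * up;
            };
            const float silu_and_mul_example = act_and_mul_reference(
                0.5f, 2.0f, 1.0f, [](float x) { return x / (1.0f + expf(-x)); });
            (void)silu_and_mul_example;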
2426 
2427  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2428  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2429 
2430  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2431  static_cast<CShuffleDataType*>(p_shared),
2432  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2433 
2434  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
2435  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2436  make_tuple(
2439  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2440  M1, // M1 = MWave
2441  M2)), // M2 = MPerXdl
2444  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2445  N1, // N1 = NWave
2446  N2, // N2 * N3 * N4 = NPerXdl
2447  N3,
2448  N4))),
2450  make_tuple(
2452 
2453  // calculate origin of thread output tensor on global memory
2454  // blockwise GEMM c matrix starting index
2455  const auto c_thread_mtx_on_block =
2456  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2457 
2458  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2459  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2460 
2461  const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
2465  make_tuple(Sequence<0>{}));
2466 
2467  const auto m_thread_data_on_block_idx =
2468  m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
2469  make_multi_index(m_thread_data_on_block));
2470 
2471  const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
2473  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
2475  make_tuple(Sequence<0>{}));
2476 
2477  const auto n_thread_data_on_block_idx =
2478  n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
2479  make_multi_index(n_thread_data_on_block));
2480 
2481  // shuffle: threadwise copy C from VGPR to LDS
2482  auto c_thread_copy_vgpr_to_lds =
2484  CShuffleDataType,
2485  decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2486  decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2488  Sequence<CShuffleMXdlPerWavePerShuffle,
2489  CShuffleNXdlPerWavePerShuffle,
2490  I1,
2491  I1,
2492  I1,
2493  N2,
2494  I1,
2495  N4>,
2497  7,
2498  1,
2500  1,
2501  true>{
2502  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2503  make_multi_index(0,
2504  0,
2505  m_thread_data_on_block_idx[I1],
2506  n_thread_data_on_block_idx[I1],
2507  m_thread_data_on_block_idx[I2],
2508  n_thread_data_on_block_idx[I2],
2509  n_thread_data_on_block_idx[I3],
2510  n_thread_data_on_block_idx[I4]),
2512 
2513  using EDataType = CDataType;
2514 
2515  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2516  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2517 
2518  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2519  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2520  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2521 
2522  const auto ds_grid_buf = generate_tuple(
2523  [&](auto i) {
2524  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2525  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2526  },
2527  Number<NumDTensor>{});
2528 
2529  // tuple of reference to C/Ds tensor descriptors
2530  const auto c_ds_desc_refs = concat_tuple_of_reference(
2531  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2532  generate_tie([&](auto i) -> const auto& // return type should be reference
2533  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2534  Number<NumDTensor>{}));
2535 
2536  // tuple of reference to C/Ds tensor buffers
2537  const auto c_ds_buf_refs = concat_tuple_of_reference(
2538  tie(c_shuffle_block_buf),
2539  generate_tie([&](auto i) -> const auto& // return type should be reference
2540  { return ds_grid_buf[i]; },
2541  Number<NumDTensor>{}));
2542 
2543  // tuple of starting index of C/Ds blockwise copy
2544  const auto idx_c_ds_block_begin =
2547  [&](auto) {
2548  return make_multi_index(block_m_id, 0, block_n_id, 0);
2549  // return make_multi_index(block_work_idx[I0], 0,
2550  // block_work_idx[I1], 0);
2551  },
2552  Number<NumDTensor>{}));
2553 
2554  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2555  c_grid_desc_mblock_mperblock_nblock_nperblock;
2556 
2557  using CDEBlockTransferCluster =
2558  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2559  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2560  constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // FIXME: both branches are currently 1; revisit this hack
2561  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2563  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2565  decltype(c_ds_desc_refs),
2566  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2567  CElementwiseOperation,
2568  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
2569  // support arbitrary type
2570  Sequence<1,
2571  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2572  1,
2573  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2574  CDEBlockTransferCluster,
2575  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2576  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2577  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2578  3, // index_t SrcVectorDim,
2579  3, // index_t DstVectorDim,
2580  CDEShuffleBlockTransferScalarPerVectors,
2585  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2586  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2587  IndexType,
2588  1, // ScatterDim
2589  true, // OutputScatter (when false, only the scatter weights are used)
2590  scatter_weight_idx // ScatterWeightIdx: ascale
2591  >{c_ds_desc_refs,
2592  idx_c_ds_block_begin,
2593  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2594  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2595  c_element_op};
2596 
2597  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2598  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2599  // space filling curve for threadwise C in VGPR
2600  constexpr auto sfc_c_vgpr =
2603  Sequence<CShuffleMXdlPerWavePerShuffle,
2604  CShuffleNXdlPerWavePerShuffle,
2605  1,
2606  1,
2607  1,
2608  N2,
2609  1,
2610  N4>>{};
2611 
2612  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2613 
2614  // space filling curve for shuffled blockwise C/D/E
2615  constexpr auto sfc_cde_block =
2618  Sequence<1,
2619  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2620  1,
2621  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2622 
2623  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2624  constexpr auto EMThreads =
2625  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2626  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2627  constexpr auto ENThreads =
2628  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2629  static_for<0, num_access, 1>{}([&](auto access_id) {
2630  // make sure it's safe to write to LDS
2631  StaticallyIndexedArray<IndexType, EMRepeats>
2632  scatter_offsets; //= p_sorted_token_ids[c_token_pos];
2633 
2634  auto dstidx = sfc_cde_block.GetIndex(access_id);
2635  const index_t c_token_pos =
2636  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2637  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2638  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2639  index_t token_offset = fused_token & 0xffffff;
2640  if constexpr(IsInputGemm)
2641  {
2642  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2643  }
2644  scatter_offsets(m0) =
2645  token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1);
2646  });
2647 
2648  block_sync_lds();
2649 
2650  // each thread write its data from VGPR to LDS
2651  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2652  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2653  c_thread_buf,
2654  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2655  c_shuffle_block_buf);
2656 
2657  // make sure it's safe to read from LDS
2658  block_sync_lds();
2659 
2660  // each block copy its data from LDS to global
2661  cde_block_copy_lds_and_global.Run(
2662  c_ds_desc_refs,
2663  c_ds_buf_refs,
2664  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2665  tie(c_grid_buf),
2666  scatter_offsets);
2667 
2668  if constexpr(access_id < num_access - 1)
2669  {
2670  constexpr auto cde_lds_and_global_step =
2671  sfc_cde_block.GetForwardStep(access_id);
2672 
2673  // move on Ds
2674  static_for<0, NumDTensor, 1>{}([&](auto i) {
2675  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2676  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2677  });
2678 
2679  // move on E
2680  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2681  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2682  I0,
2683  cde_lds_and_global_step);
2684  }
2685  });
2686  }
2687  }
2688 };
2689 
2690 } // namespace ck