Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp Source File

gridwise_moe_mx_gemm_bpreshuffle.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
16 
20 
21 #define DEBUG_LOG 0
22 
23 namespace ck {
24 
25 // Currently we do not have an elegant way to put the single-LDS-buffer and the double-LDS-buffer
26 // pipelines in the same kernel function. Blockers:
27 // 1. Two separate declarations of the __shared__ pointer are required to guarantee that data
28 // accesses operate on two distinct LDS chunks.
29 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C could
30 // not reuse the same LDS buffer if __shared__ were declared inside the block GEMM pipeline.
31 
32 enum Activation
33 {
34  gelu_and_mul = 0,
35  silu_and_mul = 1
36 };
37 
38 template <typename GridwiseGemm,
39  bool HasMainKBlockLoop,
40  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
41  index_t MinimumOccupancy = 1,
42  TailNumber TailNum = TailNumber::Even>
43 __global__ void
44 #if CK_USE_LAUNCH_BOUNDS
45 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
46 #endif
47  // __attribute__((amdgpu_waves_per_eu(1, 1)))
48  kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
49 {
50 #if defined(__gfx9__)
51  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
52  {
53  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
54 
55  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
56 
57  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
58  karg.p_sorted_token_ids,
59  karg.p_sorted_expert_ids,
60  karg.p_max_token_id,
61  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
 62  karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
 63  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
 64  karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
65  karg.p_ds_grid,
66  karg.p_c_grid,
67  p_shared,
68  karg,
69  karg.a_element_op,
70  karg.b_element_op,
71  karg.c_element_op);
72  }
73 #else
74  ignore = karg;
75 #endif // end of if (defined(__gfx9__))
76 }
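// Grid-level entry point for the single-LDS-buffer MoE mx-GEMM pipeline: each workgroup derives
// its split-K offsets for A/B and their scales from blockIdx.z, carves one __shared__ arena sized
// by GetSharedMemoryNumberOfByte(), and forwards the sorted token/expert metadata plus the shifted
// pointers to GridwiseGemm::Run. A rough host-side launch sketch (illustrative only; the grid and
// stream names are assumptions, and in practice the CK device op performs this wiring):
//
//   const auto [gx, gy, gz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
//   kernel_moe_mxgemm<GridwiseGemm, /*HasMainKBlockLoop=*/true, InMemoryDataOperationEnum::Set>
//       <<<dim3(gx, gy, karg.KBatch), /*BlockSize of the gridwise gemm*/ 256, 0, stream>>>(karg);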
77 
78 template <typename GridwiseGemm,
79  bool HasMainKBlockLoop,
80  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
81  index_t MinimumOccupancy = 1,
82  TailNumber TailNum = TailNumber::Even>
83 __global__ void
84 #if CK_USE_LAUNCH_BOUNDS
85 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
86 #endif
87  // __attribute__((amdgpu_waves_per_eu(1, 1)))
88  kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
89 {
90 #if defined(__gfx9__)
91  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
92  {
93  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
94  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
95 
96  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
97 
98  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
99  karg.p_sorted_token_ids,
100  karg.p_sorted_expert_ids,
101  karg.p_max_token_id,
102  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
103  karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
104  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
105  karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
106  karg.p_ds_grid,
107  karg.p_c_grid,
108  p_shared_0,
109  p_shared_1,
110  karg,
111  karg.a_element_op,
112  karg.b_element_op,
113  karg.c_element_op);
114  }
115 #else
116  ignore = karg;
117 #endif // end of if (defined(__gfx9__))
118 }
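// Variant of kernel_moe_mxgemm that drives the double-LDS-buffer (ping-pong) pipeline: two
// independent __shared__ arenas allow global->LDS copies into one buffer to overlap with MFMA math
// reading the other, at the cost of twice the LDS footprint per workgroup.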
119 
120 template <typename ALayout,
121  typename BLayout,
122  typename DsLayout,
123  typename CLayout,
124  typename ADataType,
125  typename AScaleDataType,
126  typename BDataType,
127  typename BScaleDataType,
128  typename AccDataType,
129  typename CShuffleDataType,
130  typename DsDataType,
131  typename CDataType,
132  typename AElementwiseOperation,
133  typename BElementwiseOperation,
134  typename CElementwiseOperation,
 135  tensor_operation::device::GemmSpecialization GemmSpec,
 136  index_t ScaleBlockSize, // Scaling block size
137  index_t BlockSize, // Thread block size
138  index_t MPerBlock,
139  index_t NPerBlock,
140  index_t KPerBlock,
141  index_t AK1Value,
142  index_t BK1Value,
143  index_t MPerXdl,
144  index_t NPerXdl,
145  index_t MXdlPerWave,
146  index_t NXdlPerWave,
147  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
148  typename ABlockTransferThreadClusterArrangeOrder,
149  typename ABlockTransferSrcAccessOrder,
150  index_t ABlockTransferSrcVectorDim,
151  index_t ABlockTransferSrcScalarPerVector,
152  index_t ABlockTransferDstScalarPerVector_AK1,
153  bool AThreadTransferSrcResetCoordinateAfterRun,
154  index_t ABlockLdsExtraM,
155  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
156  typename BBlockTransferThreadClusterArrangeOrder,
157  typename BBlockTransferSrcAccessOrder,
158  index_t BBlockTransferSrcVectorDim,
159  index_t BBlockTransferSrcScalarPerVector,
160  index_t BBlockTransferDstScalarPerVector_BK1,
161  bool BThreadTransferSrcResetCoordinateAfterRun,
162  index_t BBlockLdsExtraN,
163  index_t CShuffleMXdlPerWavePerShuffle,
164  index_t CShuffleNXdlPerWavePerShuffle,
165  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
166  typename CDEShuffleBlockTransferScalarPerVectors,
169  index_t ActivationOperation = 0,
170  bool NSwizzle = false,
171  bool IsInputGemm = true,
172  bool MulRoutedWeight = true,
173  typename IndexType = index_t,
174  typename ComputeTypeA = ADataType,
175  typename ComputeTypeB = BDataType>
177 {
178  using LDSTypeA = ADataType;
179  using LDSTypeB = BDataType;
180 
181  static constexpr auto I0 = Number<0>{};
182  static constexpr auto I1 = Number<1>{};
183  static constexpr auto I2 = Number<2>{};
184  static constexpr auto I3 = Number<3>{};
185  static constexpr auto I4 = Number<4>{};
186  static constexpr auto I5 = Number<5>{};
187  static constexpr auto I6 = Number<6>{};
188  static constexpr auto I7 = Number<7>{};
189  static constexpr auto I8 = Number<8>{};
190  static constexpr auto I9 = Number<9>{};
191 
 192  static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
 193  CDEShuffleBlockTransferScalarPerVectors{}[I0];
194  // K1 should be Number<...>
195  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
196  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
197  static constexpr auto AK1Number = Number<AK1Value>{};
198  static constexpr auto BK1Number = Number<BK1Value>{};
199 
200  static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
201  static constexpr bool is_single_rate_mfma = false;
202  static constexpr auto is_scale_mfma = true;
203 
204  static constexpr index_t NumDTensor = DsDataType::Size();
205 
206  static constexpr auto MXdlPack = 2;
207  static constexpr auto NXdlPack = 2;
208  static constexpr auto KXdlPack = 2;
209 
210  //> KPack is at least the k_per_blk of selected mfma
211  //
212  // Should be a multiple of k_per_blk.
213  // TODO: Move this to blockwise pipeline base
214  // KPack in packed data types for pk A/B
215 
216  static constexpr index_t APackedSize = packed_size_v<ADataType>;
217  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
218 
219  using mfma_selector = MfmaSelector<ComputeTypeA,
220  MPerXdl,
221  NPerXdl,
222  ComputeTypeB,
224  is_scale_mfma>;
225  static constexpr index_t KPack =
227 
228  static constexpr index_t NLane = NPerXdl;
229  static constexpr index_t KLane = 64 / NLane;
230  static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
231  static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
232 
233  // static constexpr index_t NumTokens = 1;
234  static constexpr index_t SortedTileSize = MPerBlock;
235 
237  static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
238  static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
239  static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
240  "A scale pack data type too large!");
241  static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
242  "B scale pack data type too large!");
243 
244  static constexpr auto MakeDsGridPointer()
245  {
246  return generate_tuple(
247  [&](auto i) {
248  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
249 
250  return static_cast<const DDataType*>(nullptr);
251  },
253  }
254 
255  using DsGridPointer = decltype(MakeDsGridPointer());
256 
258 
259  __host__ static auto CalculateGridSize(index_t M, index_t N)
260  {
261  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
262  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
263  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
264  const index_t gridy = NSwizzle ? 1 : mblock;
265 
266  return std::make_tuple(gridx, gridy, 1);
267  }
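// When NSwizzle is enabled the 2-D (nblock, mblock) tile space is flattened into gridDim.x and
// re-expanded per expert inside Run's block_mn lambda, interleaving N-blocks in groups of 8 across
// that expert's M-blocks. Illustrative numbers: M = 512, N = 1024 with MPerBlock = NPerBlock = 128
// give mblock = 4 and nblock = 8, so the grid is (8, 4, 1) without the swizzle and (32, 1, 1) with
// it (split-K then adds the z dimension).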
268 
269  __host__ static auto CalculateMPadded(index_t M)
270  {
271  return math::integer_least_multiple(M, MPerBlock);
272  }
273 
274  __host__ static auto CalculateNPadded(index_t N)
275  {
276  return math::integer_least_multiple(N, NPerBlock);
277  }
278 
279  __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
280  {
281  return math::integer_divide_ceil(N, NLane);
282  }
283  __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
284  {
286  }
287 
288  __host__ static auto CalculateKPadded(index_t K)
289  {
290  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
291  }
292 
293  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
294  {
295  auto K_t = K_Batch * KPerBlock;
296  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
297  }
298 
299  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
300  {
301  auto K_t = K_Batch * KPerBlock;
302  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
303  }
304 
305  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
306  {
307  auto K_t = K_Batch * KPerBlock;
308  return (K + K_t - 1) / K_t * KPerBlock;
309  }
310 
311  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
312  {
313  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
314  auto K_t = K_Batch * KReadVec;
315  return (K + K_t - 1) / K_t * KReadVec;
316  }
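// KRead is the K extent a single split-K batch reads from global memory: K is rounded up so that
// every batch covers a whole multiple of lcm(AK1, BK1) elements, keeping the A/B vector loads
// aligned. SplitKBatchOffset below uses KRead (scaled by the appropriate stride) to shift the
// per-batch base pointers.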
317 
318  __host__ static auto CalculateMBlock(index_t M)
319  {
320  return math::integer_divide_ceil(M, MPerBlock);
321  }
322 
323  __host__ static auto CalculateNBlock(index_t N)
324  {
325  return math::integer_divide_ceil(N, NPerBlock);
326  }
327 
328  template <index_t MNXdlPerWave,
329  index_t MNWaves,
330  index_t MNXdlPack,
331  index_t MNPerXdl,
332  bool IsXor,
333  typename TileDesc_K0_MN_K1>
334  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
335  {
336  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
337  constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
338  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
339 
340  if constexpr(IsXor)
341  {
342  constexpr auto permuted_desc = transform_tensor_descriptor(
343  TileDesc_K0_MN_K1{},
348 
350  permuted_desc,
351  make_tuple(
354  Number<MNWaves>{},
356  Number<MNPerXdl>{}))),
359  }
360  else
361  {
363  TileDesc_K0_MN_K1{},
364  make_tuple(
367  Number<MNWaves>{},
369  Number<MNPerXdl>{}))),
372  }
373  }
374 
375  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
376  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
377  {
378  const auto a_grid_desc_mraw_kraw = [&]() {
379  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
380  {
381  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
382  }
383  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
384  {
385  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
386  }
387  }();
388 
390 
391  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
392  GemmSpec == GemmSpecialization::MNKPadding)
393  {
394  // pad both M and K
395  const auto a_grid_desc_m_k =
396  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
398  make_right_pad_transform(K, KPad - K)),
401 
402  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
403  a_grid_desc_m_k,
408 
409  return a_grid_desc_ak0_m_ak1;
410  }
411  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
412  GemmSpec == GemmSpecialization::MNPadding)
413  {
414  // pad M, but not K
415  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
416  a_grid_desc_mraw_kraw,
418  make_right_pad_transform(M, MPad - M)),
421 
422  return a_grid_desc_ak0_m_ak1;
423  }
424  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
425  GemmSpec == GemmSpecialization::NKPadding)
426  {
427  // pad K, but not M
428  const auto a_grid_desc_m_k = transform_tensor_descriptor(
429  a_grid_desc_mraw_kraw,
433 
434  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
435  a_grid_desc_m_k,
440 
441  return a_grid_desc_ak0_m_ak1;
442  }
443  else
444  {
445  // not pad M or K
446  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
447  a_grid_desc_mraw_kraw,
448  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
452 
453  const auto a_grid_desc_permuted = transform_tensor_descriptor(
454  a_grid_desc_ak0_m_ak1,
457  make_pass_through_transform(AK1Value)),
460 
461  const auto a_grid_desc = transform_tensor_descriptor(
462  a_grid_desc_permuted,
463  make_tuple(
466  make_pass_through_transform(AK1Value)),
469 
470  return a_grid_desc;
471  }
472  }
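// The branches above implement the GemmSpecialization variants for A: M and/or K are right-padded
// to MPad/KPad when the chosen specialization requires it. In the unpadded default branch the K
// axis is additionally unmerged into (K / KPerBlock, AK0, AK1) and passed through two further
// permuting transforms before the descriptor is handed to the A blockwise copy.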
473 
474  __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
475  {
476  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
477  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
478  constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack>{};
480  make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
481  }
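// B is consumed in a pre-shuffled layout instead of a plain (N, K) matrix: the descriptor is 5-D
// with shape (N0 / NWave / NXdlPack, NWave, NXdlPack, K0, WaveSize * KPack), so the innermost
// dimension holds one wave's worth of KPack-wide MFMA operands and each lane's slice sits
// contiguously in memory. B therefore bypasses LDS in this pipeline and is staged in VGPRs.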
482 
483  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
484  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
485  {
486  const auto b_grid_desc_nraw_kraw = [&]() {
 487  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
 488  {
489  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
490  }
 491  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
 492  {
493  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
494  }
495  }();
496 
498 
499  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
500  GemmSpec != GemmSpecialization::Default),
501  "pk_i4_t does not support padding");
502  static_assert(!(is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t> &&
503  GemmSpec != GemmSpecialization::Default),
504  "f4x2_pk_t does not support padding");
505 
506  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
507  GemmSpec == GemmSpecialization::MNKPadding)
508  {
509  // pad both N and K
510  const auto b_grid_desc_n_k =
511  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
513  make_right_pad_transform(K, KPad - K)),
516 
517  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
518  b_grid_desc_n_k,
523 
524  return b_grid_desc_bk0_n_bk1;
525  }
526  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
527  GemmSpec == GemmSpecialization::MNPadding)
528  {
529  // pad N, but not K
530  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
531  b_grid_desc_nraw_kraw,
533  make_right_pad_transform(N, NPad - N)),
536 
537  return b_grid_desc_bk0_n_bk1;
538  }
539  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
540  GemmSpec == GemmSpecialization::MKPadding)
541  {
542  // pad K, but not N
543  const auto b_grid_desc_n_k = transform_tensor_descriptor(
544  b_grid_desc_nraw_kraw,
548 
549  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
550  b_grid_desc_n_k,
555 
556  return b_grid_desc_bk0_n_bk1;
557  }
558  else
559  {
560  // not pad N or K
561  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
562  b_grid_desc_nraw_kraw,
563  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
567 
568  const auto b_grid_desc_permuted = transform_tensor_descriptor(
569  b_grid_desc_bk0_n_bk1,
572  make_pass_through_transform(BK1Value)),
575 
576  const auto b_grid_desc = transform_tensor_descriptor(
577  b_grid_desc_permuted,
578  make_tuple(
581  make_pass_through_transform(BK1Value)),
584 
585  return b_grid_desc;
586  }
587  }
588 
589  template <typename ABlockDesc_AK0_M_AK1>
590  __host__ __device__ static constexpr auto
591  MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
592  {
593  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
594 
595  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl, true>(
596  ABlockDesc_AK0_M_AK1{});
597  }
598 
599  template <typename BBlockDesc_BK0_N_BK1>
600  __host__ __device__ static constexpr auto
601  MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
602  {
603  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
604 
605  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl, false>(
606  BBlockDesc_BK0_N_BK1{});
607  }
608 
609  template <typename ELayout>
610  __host__ __device__ static auto MakeCGridDescriptor_M_N(
611  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
612  {
613  const auto c_grid_desc_mraw_nraw = [&]() {
 614  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ELayout>)
 615  {
616  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
617  }
 618  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ELayout>)
 619  {
620  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
621  }
622  }();
623 
624  // pad M and N
625  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
627  make_right_pad_transform(N, NPad - N)),
630  }
631 
632  template <typename DLayout>
633  __host__ __device__ static auto
 634  MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
 635  {
636  const auto c_grid_desc_mraw_nraw = [&]() {
 637  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, DLayout>)
 638  {
639  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
640  }
 641  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, DLayout>)
 642  {
643  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
644  }
645  }();
646 
647  // pad M and N
648  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
650  make_right_pad_transform(N, NPad - N)),
653  }
654 
655  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
656  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
657  {
658  return generate_tuple(
659  [&](auto i) {
660  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
661  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
662  },
664  }
665 
666  template <typename DsGridDesc>
668  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
669  {
670  return generate_tuple(
671  [&](auto i) {
673  ds_grid_desc_m_n[i], MBlock, NBlock);
674  },
676  }
677 
678  struct Problem
679  {
680  __host__ Problem(index_t NumTokens_,
681  index_t TopK_,
682  index_t M_,
683  index_t N_,
684  index_t K_,
685  index_t StrideA_,
686  index_t StrideScaleA_,
687  index_t StrideB_,
688  index_t StrideScaleB_,
689  std::array<index_t, NumDTensor> StrideDs_,
690  index_t StrideC_,
691  index_t KBatch_)
692  : NumTokens{NumTokens_},
693  TopK{TopK_},
694  M{M_},
695  N{N_},
696  K{K_},
697  StrideA{StrideA_},
698  StrideScaleA{StrideScaleA_},
699  StrideB{StrideB_},
700  StrideScaleB{StrideScaleB_},
701  StrideDs{StrideDs_},
702  StrideC{StrideC_},
 703  KBatch{KBatch_},
 704  MPadded{CalculateMPadded(M_)},
 705  NPadded{CalculateNPadded(N_)},
706  KRead{CalculateKRead(K_, KBatch_)},
707  KPadded{CalculateKPadded(K_, KBatch_)},
708  AK0{CalculateAK0Padded(K_, KBatch_)},
709  BK0{CalculateBK0Padded(K_, KBatch_)},
 710  MBlock{CalculateMBlock(M_)},
 711  NBlock{CalculateNBlock(N_)}
712  {
713  }
714 
715  __host__ void Print() const
716  {
717  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
718  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
719  << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
720  << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
721  << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
722  << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
723  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
724  << ", " << "NBlock: " << NBlock << "}" << std::endl;
725  }
726 
 727  index_t NumTokens;
 728  index_t TopK;
 729  index_t M;
 730  index_t N;
 731  index_t K;
 732  index_t StrideA;
 733  index_t StrideScaleA;
 734  index_t StrideB;
 735  index_t StrideScaleB;
 736  std::array<index_t, NumDTensor> StrideDs;
 737  index_t StrideC;
 738  index_t KBatch;
 739  index_t MPadded;
 740  index_t NPadded;
 741  index_t KRead;
 742  index_t KPadded;
 743  index_t AK0;
 744  index_t BK0;
 745  index_t MBlock;
 746  index_t NBlock;
747  };
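// Problem bundles the raw GEMM sizes and strides with the derived quantities (padded M/N/K, the
// per-split-K read extent KRead, AK0/BK0, and the M/N block counts) so the kernel reads them
// directly instead of recomputing them per thread.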
748 
749  // Argument
751  {
752  __host__ Argument(const index_t* p_sorted_token_ids_,
753  const index_t* p_sorted_expert_ids_,
754  const index_t* p_max_token_id_,
755  const ADataType* p_a_grid_,
756  const AScaleDataType* p_a_scale_grid_,
757  const BDataType* p_b_grid_,
758  const BScaleDataType* p_b_scale_grid_,
759  std::array<const void*, NumDTensor> p_ds_grid_,
760  CDataType* p_c_grid_,
761  index_t NumTokens_,
762  index_t TopK_,
763  index_t M_,
764  index_t N_,
765  index_t K_,
766  index_t StrideA_,
767  index_t StrideScaleA_,
768  index_t StrideB_,
769  index_t StrideScaleB_,
770  std::array<index_t, NumDTensor> StrideDs_,
771  index_t StrideC_,
772  index_t k_batch_,
773  AElementwiseOperation a_element_op_,
774  BElementwiseOperation b_element_op_,
775  CElementwiseOperation c_element_op_)
776  : Problem{NumTokens_,
777  TopK_,
778  M_,
779  N_,
780  K_ / APackedSize,
781  StrideA_ / APackedSize,
782  StrideScaleA_,
783  StrideB_ / BPackedSize,
784  StrideScaleB_,
785  StrideDs_,
786  StrideC_,
787  k_batch_},
788  p_sorted_token_ids{p_sorted_token_ids_},
789  p_sorted_expert_ids{p_sorted_expert_ids_},
790  p_max_token_id{p_max_token_id_},
791  p_a_grid{p_a_grid_},
792  p_a_scale_grid{p_a_scale_grid_},
793  p_b_grid{p_b_grid_},
794  p_b_scale_grid{p_b_scale_grid_},
795  p_ds_grid{},
796  p_c_grid{p_c_grid_},
797  a_element_op{a_element_op_},
798  b_element_op{b_element_op_},
799  c_element_op{c_element_op_}
800  {
801 
802  // populate pointer, desc for Ds
803  static_for<0, NumDTensor, 1>{}([&](auto i) {
804  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
805 
806  // D pointer
807  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
808  });
809  }
810 
 811  const index_t* p_sorted_token_ids;
 812  const index_t* p_sorted_expert_ids;
 813  const index_t* p_max_token_id;
 814  const ADataType* p_a_grid;
815  const AScaleDataType* p_a_scale_grid;
816  const BDataType* p_b_grid;
817  const BScaleDataType* p_b_scale_grid;
 818  DsGridPointer p_ds_grid;
 819  CDataType* p_c_grid;
820 
821  const AElementwiseOperation a_element_op;
822  const BElementwiseOperation b_element_op;
823  const CElementwiseOperation c_element_op;
824  };
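// Argument extends Problem with the device pointers and element-wise operators. K and the A/B
// strides are divided by APackedSize/BPackedSize on construction because, for packed element types
// (e.g. two 4-bit values per byte), the grid descriptors count packed elements rather than logical
// scalars. Illustrative construction (every value on the right-hand side is caller-provided data):
//
//   typename GridwiseGemm::Argument karg{
//       p_sorted_token_ids, p_sorted_expert_ids, p_max_token_id,
//       p_a, p_a_scale, p_b, p_b_scale, {/*Ds pointers*/}, p_c,
//       num_tokens, topk, M, N, K,
//       stride_a, stride_scale_a, stride_b, stride_scale_b,
//       {/*StrideDs*/}, stride_c, /*k_batch=*/1,
//       AElementwiseOperation{}, BElementwiseOperation{}, CElementwiseOperation{}};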
825 
 826  struct SplitKBatchOffset
 827  {
828  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
829  {
830  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
831  {
832  a_k_split_offset = k_id * karg.KRead;
833  }
834  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
835  {
836  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
837  }
838 
839  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
840  {
841  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
842  }
843  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
844  {
845  // KPack * NLane * KLane * K0 * N0
846  b_k_split_offset = k_id * karg.KRead * NPerXdl;
847  }
848 
849  // Calculate A scale offset
850  a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack *
851  MPerXdl / scale_pack_size_a;
852 
853  // Calculate B scale offset
854  b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack *
855  NPerXdl / scale_pack_size_b;
856 
857  if(k_id < karg.KBatch - 1)
858  {
859  karg.K = karg.KRead;
860  }
861  else
862  {
863  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
864  }
865  }
866 
 867  index_t a_k_split_offset;
 868  index_t b_k_split_offset;
 869  index_t a_scale_k_split_offset;
 870  index_t b_scale_k_split_offset;
 871  };
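// SplitKBatchOffset converts a blockIdx.z value into element offsets for this k-batch: A advances
// along K according to its layout (by KRead, or KRead * StrideA for column-major), B advances by
// KRead * StrideB in the row-major case or by KRead * NPerXdl packed elements in the pre-shuffled
// column-major case, and the A/B scale offsets advance in scale-block units
// (KRead / (ScaleBlockSize / PackedSize)) adjusted by the XdlPack tiling and the scale packing
// factor. The last batch also trims karg.K to the remainder so it does not read past the end.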
872 
873  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
874  {
875  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
876  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
877 
878  // A matrix in LDS memory, dst of blockwise copy
879  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
880  {
881  // contiguous in LDS
885  }
 886  // The xor tensor transformation requires more (unnecessary) VGPRs and can cause register
 887  // spills in some cases.
889  {
890  constexpr auto a_lds_block_desc =
893 
894  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
895  a_lds_block_desc,
901 
902  return a_lds_block_desc_permuted;
903  }
904  else // ColumnMajor A
905  {
 906  // The kfold and mpair dimensions are not always required.
 907  // More dimensions in the merge_transform make it harder for the compiler to generate
 908  // immediate-argument offsets.
909  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
910  constexpr auto M1 = MPerBlock / M0;
911 
912  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
913  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
914  constexpr auto KThreadRead = WaveSize / MPerXdl;
915  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
916 
917  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
918  ? 1
919  : 128 / (AK1Number * M0 * sizeof(ADataType));
920  constexpr auto KThreadReadPerm =
921  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
922  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
923  : KThreadRead;
924 
925  // 1<=mpair<=n0
926  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
927  ? 1
928  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
929  ? M0
930  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
931 
932  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
936  Number<kfold * M0 / mpair>{},
937  Number<mpair>{},
938  AK1Number));
939 
940  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
941  a_lds_block_desc,
942  make_tuple(
946  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
949  make_tuple(
951  make_tuple(
953 
954  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
955  a_lds_block_desc_permuted,
956  make_tuple(
964  Sequence<1>{},
965  Sequence<2>{},
966  Sequence<3>{},
967  Sequence<4>{},
968  Sequence<5>{}),
970  Sequence<2>{},
971  Sequence<0, 3>{},
972  Sequence<4, 5>{},
973  Sequence<6>{},
974  Sequence<7>{}));
975 
976  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
977  a_lds_block_desc_unmerged,
980  Number<KThreadWrite / kfold / KThreadReadPerm>{},
981  Number<kfold>{},
988 
989  return a_lds_block_desc_ak0_m_ak1;
990  }
991  }
992 
993  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
994  {
995  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
997  I1,
999  Number<KRepeat>{},
1000  Number<BK1Value>{}));
1001  }
1002 
 1003  __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
 1004  {
1005  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1006 
1007  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1009  make_tuple(I1,
1011  I1,
1013 
1014  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1015  }
1016 
1019  BlkGemmPipelineVer,
1020  BlkGemmPipeSched,
1021  BlockSize,
1022  ScaleBlockSize,
1023  ADataType,
1024  AScaleDataType,
1025  BDataType,
1026  BScaleDataType,
1027  ComputeTypeA,
1028  AccDataType,
1035  ABlockTransferSrcScalarPerVector,
1036  BBlockTransferSrcScalarPerVector,
1037  MPerBlock,
1038  NPerBlock,
1039  KPerBlock,
1040  MPerXdl,
1041  NPerXdl,
1042  MXdlPerWave,
1043  NXdlPerWave,
1044  KPack,
1045  IsInputGemm>())>;
1046 
1047  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1048  {
1049  // LDS allocation for A and B: be careful of alignment
1050  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1051  // lds max alignment
1052  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1053 
1054  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1055  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1056 
1057  // LDS allocation for C shuffle in LDS
1058  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1060 
1061  constexpr auto c_block_size =
1062  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1063 
1064  return math::max(a_block_space_size_aligned * sizeof(ADataType),
1065  c_block_size * sizeof(CShuffleDataType));
1066  }
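// The LDS arena is time-shared between the two phases of the kernel: the main loop stores only the
// A block tile there (B lives in VGPRs via the pre-shuffled path), and the epilogue reuses the same
// memory for the C shuffle tile. The allocation is therefore the maximum of the two footprints, not
// their sum.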
1067 
1069 
1070  // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
1071  __host__ static constexpr bool CheckValidity(const Argument& karg)
1072  {
1073  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1074  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1075  "Invalid tuning param!");
1076 
1077  static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1078  "KPerBlock should be multiple of ScaleBlockSize");
1079 
1080  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1085  {
1086  if(!(karg.M % MPerBlock == 0))
1087  {
1088  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1089  {
1090  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1091  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1092  << std::endl;
1093  }
1094  return false;
1095  }
1096  }
1097 
1098  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1103  {
1104  if(!(karg.N % NPerBlock == 0))
1105  {
1106  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1107  {
1108  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1109  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1110  << std::endl;
1111  }
1112  return false;
1113  }
1114  }
1115 
1116  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1120  {
1121  auto K_t = karg.KBatch * KPerBlock;
1122  if(!(karg.K % K_t == 0))
1123  {
1124  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1125  {
1126  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1127  << karg.K << " " << __FILE__ << ":" << __LINE__
1128  << ", in function: " << __func__ << std::endl;
1129  }
1130  return false;
1131  }
1132  }
1133  else
1134  {
1135  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1136  auto K_t = karg.KBatch * KReadVec;
1137  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1138  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1139  {
1140  return false;
1141  }
1142  }
1143 
1145  {
1146  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1147  {
1148  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1149  {
1150  std::cout << "Arg K (" << karg.K
1151  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1152  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1153  << __LINE__ << ", in function: " << __func__ << std::endl;
1154  }
1155  return false;
1156  }
1157  }
1158  else
1159  {
1160  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1161  {
1162  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1163  {
1164  std::cout << "Arg M (" << karg.M
1165  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1166  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1167  << __LINE__ << ", in function: " << __func__ << std::endl;
1168  }
1169  return false;
1170  }
1171  }
1172 
1174  {
1175  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1176  {
1177  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1178  {
1179  std::cout << "Arg N (" << karg.N
1180  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1181  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1182  << __LINE__ << ", in function: " << __func__ << std::endl;
1183  }
1184  return false;
1185  }
1186  }
1187  else
1188  {
1189  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1190  {
1191  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1192  {
1193  std::cout << "Arg K (" << karg.K
1194  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1195  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1196  << __LINE__ << ", in function: " << __func__ << std::endl;
1197  }
1198  return false;
1199  }
1200  }
1201 
1203  {
1205  {
1206  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1207  {
1208  std::cout << "Arg N (" << karg.N
1209  << ") value is not a multiple of "
1210  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1212  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1213  << std::endl;
1214  }
1215  return false;
1216  }
1217  }
1218  else
1219  {
1221  {
1222  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1223  {
1224  std::cout << "Arg M (" << karg.M
1225  << ") value is not a multiple of "
1226  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1228  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1229  << std::endl;
1230 
1231  return false;
1232  }
1233  }
1234  }
1235 
1236  // check gridwise gemm pipeline
1237 #if 0
1238  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1239 
1240  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1241  {
1242  return false;
1243  }
1244 #endif
1245  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1246  return true;
1247  }
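// CheckValidity is the host-side gate called before launching the kernel: it rejects M/N/K that are
// not multiples of the block tile unless the matching padding specialization is enabled, makes sure
// split-K leaves work for every batch, and validates the vector widths of the A, B and C-shuffle
// transfers against the contiguous dimension. A typical caller pattern (hypothetical surrounding
// code):
//
//   if(!GridwiseGemm::CheckValidity(karg))
//   {
//       return false; // unsupported argument, skip this instance
//   }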
1248 
1249  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1250  {
1251  const index_t num_loop = K / KPerBlock;
1252 
1253  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1254  }
1255 
1256  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1257  {
1258  const index_t num_loop = K / KPerBlock;
1259 
1260  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1261  }
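// These two helpers let the host pick the kernel instantiation: the number of K iterations
// (K / KPerBlock) determines whether the pipeline has a hot loop (HasMainKBlockLoop) and which
// TailNumber variant cleans up the remainder; both are then baked in as template arguments of
// kernel_moe_mxgemm[_2lds].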
1262 
1263  template <typename CGridDesc>
1264  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1265  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1266  {
1267  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1268  c_grid_desc_m_n,
1273 
1274  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1275  }
1276 
1277  // return block_id to C matrix tile idx (m0, n0) mapping
1278  // if arch = gfx942
1279  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1280  // NPerBlock>;
1281 
1282 #if 0
1283  template <bool HasMainKBlockLoop,
1284  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1285  TailNumber TailNum = TailNumber::Odd>
1286  __device__ static void Run(const index_t* p_sorted_token_ids,
1287  const index_t* p_sorted_expert_ids,
1288  const index_t* p_max_token_id,
1289  const ADataType* p_a_grid,
1290  const AScaleDataType* p_a_scale_grid,
1291  const BDataType* p_b_grid,
1292  const BScaleDataType* p_b_scale_grid,
1293  DsGridPointer& p_ds_grid,
1294  CDataType* p_c_grid,
1295  void* p_shared,
1296  const Problem& problem,
1297  AElementwiseOperation a_element_op,
1298  BElementwiseOperation b_element_op,
1299  CElementwiseOperation c_element_op)
1300  {
1301  ignore = b_element_op;
1302  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1303  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1304  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1305  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1306  problem.MPadded,
1307  problem.K,
1308  problem.KPadded,
1309  problem.StrideA,
1310  problem.AK0);
1311  const auto b_grid_desc_bpreshuffled =
1312  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1313  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1314  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1315  problem.MPadded,
1316  problem.N,
1317  problem.NPadded,
1318  problem.StrideC);
1319 
1320  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
1321  make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock),
1322  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1323  (KXdlPack * 64 / MPerXdl),
1324  64 * KXdlPack * MXdlPack / scale_pack_size_a));
1325 
1326  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
1327  make_tuple(problem.N / (NXdlPack * NPerXdl),
1328  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1329  (KXdlPack * 64 / NPerXdl),
1330  64 * KXdlPack * NXdlPack / scale_pack_size_b));
1331 
1332  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1334  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1335  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1336  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1337  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1338  if(expert_block_id * MPerBlock >= max_token_id)
1339  return;
1340  const index_t expert_id =
1341  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1342 
1343  const auto block_mn = [&]() -> std::pair<int, int> {
1344  if constexpr(NSwizzle)
1345  {
1346  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1347  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1348  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1349  const index_t expert_swizzle =
1350  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1351  const index_t bid_new = blockIdx.x - prefix_block;
1352  const index_t nid = __builtin_amdgcn_readfirstlane(
1353  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1354  const index_t mid =
1355  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1356  return {nid, mid};
1357  }
1358  else
1359  {
1360  return {blockIdx.x, blockIdx.y};
1361  }
1362  }();
1363 
1364  const index_t block_n_id = block_mn.first;
1365  const index_t block_m_id = block_mn.second;
1366  const index_t token0 =
1367  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1368 
1369  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1370  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1371  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1372  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1373  constexpr auto AKThreads = AK0Threads * AK1Threads;
1374  constexpr auto AMRepeats = MPerBlock / AMThreads;
1375  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1376 
1377  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1378  return;
1379  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
1380  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1381  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1382  index_t token_offset = fused_token & 0xffffff;
1383  if constexpr(!IsInputGemm)
1384  {
1385  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1386  }
1387  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K / APackedSize;
1388  });
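// Each entry of p_sorted_token_ids packs a token index in its low 24 bits and the top-k slot in its
// high 8 bits. For the down-projection GEMM (IsInputGemm == false) the slot is folded back in as
// token * TopK + slot so that the rows produced per expert-slot by the first GEMM are re-gathered;
// the resulting row offsets (in packed A elements) feed the gather-enabled A blockwise copy below.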
1389  const index_t expert_stride =
1390  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1391  const index_t expert_scale_stride =
1392  __builtin_amdgcn_readfirstlane(problem.N * (IsInputGemm ? 2 : 1) *
1393  math::integer_divide_ceil(problem.K, ScaleBlockSize));
1394 
1395  // N0, K0, Blocksize*KPack
1396  const index_t n_block_data_idx_on_grid =
1397  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1398 
1399  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1400  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1401  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1402  p_b_grid + expert_id * expert_stride / BPackedSize,
1403  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1404 
1405  // A, B scale buffer
1406  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1407  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1408  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1409  p_b_scale_grid + expert_id * expert_scale_stride,
1410  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1411 
1412  // A matrix in LDS memory, dst of blockwise copy
1413  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1414 
1415  // B matrix in LDS memory, dst of blockwise copy
1416  // dummy
1417  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1418  // A matrix blockwise copy
1419  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1421  AElementwiseOperation,
1424  Sequence<AK0Number, MPerBlock, AK1Number>,
1425  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1426  ABlockTransferThreadClusterArrangeOrder,
1427  ADataType,
1428  LDSTypeA,
1429  decltype(a_grid_desc_ak0_m_ak1),
1430  decltype(a_block_desc_ak0_m_ak1),
1431  ABlockTransferSrcAccessOrder,
1432  Sequence<0, 1, 2>,
1433  ABlockTransferSrcVectorDim,
1434  2,
1435  ABlockTransferSrcScalarPerVector,
1436  ABlockTransferDstScalarPerVector_AK1,
1437  1,
1438  1,
1439  AThreadTransferSrcResetCoordinateAfterRun,
1440  true,
1441  IndexType,
1442  1,
1443  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1444  make_multi_index(0, 0, 0),
1445  a_element_op,
1446  a_block_desc_ak0_m_ak1,
1447  make_multi_index(0, 0, 0),
1449  gather_offsets);
1450 
1451  // Thread-wise copy
1452  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1453  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1454  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1455 
1456  auto b_blockwise_copy =
1457  ThreadwiseTensorSliceTransfer_v2<BDataType,
1458  BDataType,
1459  decltype(b_grid_desc_bpreshuffled),
1460  decltype(b_block_desc_bk0_n_bk1),
1461  Sequence<Number<NXdlPerWave / NXdlPack>{},
1462  I1,
1463  Number<NXdlPack>{},
1464  Number<KRepeat>{},
1465  Number<BK1Value>{}>,
1466  Sequence<1, 2, 0, 3>,
1467  4,
1468  BBlockTransferSrcScalarPerVector,
1469  BThreadTransferSrcResetCoordinateAfterRun,
1470  true>(
1471  b_grid_desc_bpreshuffled,
1472  make_multi_index(n_block_data_idx_on_grid,
1474  0,
1475  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1476 
1477  // LDS allocation for A and B: be careful of alignment
1478  // Cast after lds
1479  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1480  static_cast<LDSTypeA*>(p_shared),
1481  a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize);
1482 
1483  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1484  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1485 
1486  // Blockwise GEMM pipeline
1487  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1488  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1489  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1490  decltype(c_thread_buf) c_thread_buf_up;
1491 
1492  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
1493  float,
1494  c_thread_buf.num_of_v_,
1495  c_thread_buf.s_per_v,
1496  true>
1497  c_thread_buf_fp32;
1498 
1499  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1500  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1501  KPerBlock);
1502 
1503  // a and b scale processing
1504  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1505  const auto waveId_m = wave_idx[I0];
1506  const auto waveId_n = wave_idx[I1];
1507 
1508  static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
1509 
1510  auto thread_offset_shuffled =
1511  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1512 
1513  auto a_thread_offset_m = waveId_m;
1514 
1515  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1516  AScaleDataType,
1517  AScaleDataType,
1518  decltype(a_scale_grid_desc_am_ak),
1519  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1520  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1521  Sequence<0, 1, 2>, // DimAccessOrder
1522  2, // SrcVectorDim
1523  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1524  1, // SrcScalarStrideInVector
1525  true>(a_scale_grid_desc_am_ak,
1526  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1527  0,
1528  thread_offset_shuffled / scale_pack_size_a));
1529 
1530  // B scale load
1531  auto b_thread_offset_n = waveId_n;
1532 
1533  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1534  BScaleDataType,
1535  BScaleDataType,
1536  decltype(b_scale_grid_desc_bn_ak),
1537  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1538  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1539  Sequence<0, 1, 2>, // DimAccessOrder
1540  2, // SrcVectorDim
1541  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1542  1, // SrcScalarStrideInVector
1543  true>(b_scale_grid_desc_bn_ak,
1544  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1545  0,
1546  thread_offset_shuffled / scale_pack_size_b));
1547 
1548  if constexpr(IsInputGemm)
1549  {
1550  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1551  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1552  p_b_grid_up + expert_id * expert_stride / BPackedSize,
1553  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1554  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1555  BDataType,
1556  BDataType,
1557  decltype(b_grid_desc_bpreshuffled),
1558  decltype(b_block_desc_bk0_n_bk1),
1559  Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
1560  Sequence<1, 2, 0, 3>,
1561  3,
1562  BBlockTransferSrcScalarPerVector,
1563  BThreadTransferSrcResetCoordinateAfterRun,
1564  true>(b_grid_desc_bpreshuffled,
1565  make_multi_index(n_block_data_idx_on_grid,
1567  0,
1568  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1569  const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
1570  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1571  p_b_scale_grid_up + expert_id * expert_scale_stride,
1572  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1573  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1574  BScaleDataType,
1575  BScaleDataType,
1576  decltype(b_scale_grid_desc_bn_ak),
1577  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1578  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1579  Sequence<0, 1, 2>, // DimAccessOrder
1580  2, // SrcVectorDim
1581  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1582  1, // SrcScalarStrideInVector
1583  true>(
1584  b_scale_grid_desc_bn_ak,
1585  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1586  0,
1587  thread_offset_shuffled / scale_pack_size_b));
1588 
1589  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1590  a_grid_desc_ak0_m_ak1,
1591  a_block_desc_ak0_m_ak1,
1592  a_blockwise_copy,
1593  a_grid_buf,
1594  a_block_buf,
1595  a_block_slice_copy_step,
1596  b_grid_desc_bpreshuffled,
1597  b_block_desc_bk0_n_bk1,
1598  b_blockwise_copy,
1599  b_blockwise_copy_up,
1600  b_grid_buf,
1601  b_grid_buf_up,
1602  b_block_buf,
1603  b_block_slice_copy_step,
1604  c_thread_buf,
1605  c_thread_buf_up,
1606  a_scale_grid_desc_am_ak,
1607  a_scale_thread_copy,
1608  a_scale_grid_buf,
1609  b_scale_grid_desc_bn_ak,
1610  b_scale_thread_copy,
1611  b_scale_thread_copy_up,
1612  b_scale_grid_buf,
1613  b_scale_grid_buf_up,
1614  num_k_block_main_loop);
1615  }
1616  else
1617  {
1618  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1619  a_grid_desc_ak0_m_ak1,
1620  a_block_desc_ak0_m_ak1,
1621  a_blockwise_copy,
1622  a_grid_buf,
1623  a_block_buf,
1624  a_block_slice_copy_step,
1625  b_grid_desc_bpreshuffled,
1626  b_block_desc_bk0_n_bk1,
1627  b_blockwise_copy,
1628  b_grid_buf,
1629  b_block_buf,
1630  b_block_slice_copy_step,
1631  c_thread_buf,
1632  a_scale_grid_desc_am_ak,
1633  a_scale_thread_copy,
1634  a_scale_grid_buf,
1635  b_scale_grid_desc_bn_ak,
1636  b_scale_thread_copy,
1637  b_scale_grid_buf,
1638  num_k_block_main_loop);
1639  }
1640 
1641  // shuffle C and write out
1642  {
1643  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1644  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1645  "wrong!");
1646 
1647  // TODO: hacky, fix it!
1648  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1649  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1650 
1651  // TODO: hacky, fix it!
1652  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1653  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1654  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1655 
1656  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1657  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1658  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1659  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1660  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1661  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1662  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1663  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1664 
1665  // mul scales
1666  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
1667  static_assert(M4 == 4);
1668  const index_t m1 = get_warp_local_1d_id() / NWave;
1669  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
1670 
1671  vector_type<float, 4> topk_weights; // for gemm2 only
1672  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1673  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1674  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
1675  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1676  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1677  if constexpr(MulRoutedWeight)
1678  {
1679  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
1680  p_ds_grid[I2] + m_pos);
1681  }
1682  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
1683  constexpr index_t c_offset =
1684  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1685  make_tuple(m0, n0, m2 * M4 + m4));
1686  constexpr auto cidx = Number<c_offset>{};
1687 
1688  if constexpr(IsInputGemm) // gu fusion
1689  {
1690  if constexpr(ActivationOperation == Activation::silu_and_mul)
1691  {
1692  float gate = c_thread_buf[cidx];
1693  float up = c_thread_buf_up[cidx];
1694  if constexpr(MulRoutedWeight)
1695  {
1696  gate = gate * topk_weights.AsType<float>()[m4];
1697  up = up * topk_weights.AsType<float>()[m4];
1698  }
1699  tensor_operation::element_wise::Silu{}(gate, gate);
1700  c_thread_buf_fp32(cidx) = gate * up;
1701  }
1702  else if(ActivationOperation == Activation::gelu_and_mul)
1703  {
1704  float gate = c_thread_buf[cidx];
1705  float up = c_thread_buf_up[cidx];
1706  if constexpr(MulRoutedWeight)
1707  {
1708  gate = gate * topk_weights.AsType<float>()[m4];
1709  up = up * topk_weights.AsType<float>()[m4];
1710  }
1711  tensor_operation::element_wise::Gelu{}(gate, gate);
1712  c_thread_buf_fp32(cidx) = gate * up;
1713  }
1714  }
1715  else
1716  {
1717  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1718  if constexpr(MulRoutedWeight)
1719  {
1720  c_thread_buf_fp32(cidx) =
1721  topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
1722  }
1723  }
1724  });
1725  });
1726  });
1727  });
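// Epilogue fusion on the accumulators: for the gate/up GEMM (IsInputGemm) the two partial results
// are combined as act(gate) * up, with Silu or Gelu chosen by ActivationOperation and the routed
// top-k weights from p_ds_grid[I2] optionally applied first; for the down GEMM the accumulator is
// simply scaled by the top-k weight. The fused values land in c_thread_buf_fp32 and are then
// shuffled through LDS and scattered back to the token-major C layout below.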
1728 
1729  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1731 
1732  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1733  static_cast<CShuffleDataType*>(p_shared),
1734  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1735 
1736  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1737  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1738  make_tuple(
1741  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1742  M1, // M1 = MWave
1743  M2, // M2 * M3 * M4 = MPerXdl
1744  M3,
1745  M4)),
1748  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1749  N1, // N1 = NWave
1750  N2))), // N2 = NPerXdl
1751  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
1752  make_tuple(
1753  Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
1754 
1755  // calculate origin of thread output tensor on global memory
1756  // blockwise GEMM c matrix starting index
1757  const auto c_thread_mtx_on_block =
1758  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1759 
1760  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1761  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1762 
1763  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1765  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1766  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
1767  make_tuple(Sequence<0>{}));
1768 
1769  const auto m_thread_data_on_block_idx =
1770  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1771  make_multi_index(m_thread_data_on_block));
1772 
1773  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1776  make_tuple(Sequence<0, 1, 2>{}),
1777  make_tuple(Sequence<0>{}));
1778 
1779  const auto n_thread_data_on_block_idx =
1780  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1781  make_multi_index(n_thread_data_on_block));
1782 
1783  // shuffle: threadwise copy C from VGPR to LDS
1784  auto c_thread_copy_vgpr_to_lds =
1785  ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
1786  CShuffleDataType,
1787  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1788  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1790  Sequence<CShuffleMXdlPerWavePerShuffle,
1791  CShuffleNXdlPerWavePerShuffle,
1792  I1,
1793  I1,
1794  M2,
1795  I1,
1796  M4,
1797  I1>,
1798  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
1799  7,
1800  1,
1802  1,
1803  true>{
1804  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1805  make_multi_index(0,
1806  0,
1807  m_thread_data_on_block_idx[I1],
1808  n_thread_data_on_block_idx[I1],
1809  m_thread_data_on_block_idx[I2],
1810  m_thread_data_on_block_idx[I3],
1811  m_thread_data_on_block_idx[I4],
1812  n_thread_data_on_block_idx[I2]),
1814 
1815  using EDataType = CDataType;
1816 
1817  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1818  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1819 
1820  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1822  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1823 
1824  const auto ds_grid_buf = generate_tuple(
1825  [&](auto i) {
1826  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1827  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1828  },
1829  Number<NumDTensor>{});
1830 
1831  // tuple of reference to C/Ds tensor descriptors
1832  const auto c_ds_desc_refs = concat_tuple_of_reference(
1833  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1834  generate_tie([&](auto i) -> const auto& // return type should be reference
1835  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1836  Number<NumDTensor>{}));
1837 
1838  // tuple of reference to C/Ds tensor descriptors
1839  const auto c_ds_buf_refs = concat_tuple_of_reference(
1840  tie(c_shuffle_block_buf),
1841  generate_tie([&](auto i) -> const auto& // return type should be reference
1842  { return ds_grid_buf[i]; },
1843  Number<NumDTensor>{}));
1844 
1845  // tuple of starting index of C/Ds blockwise copy
1846  const auto idx_c_ds_block_begin =
1849  [&](auto) {
1850  return make_multi_index(block_m_id, 0, block_n_id, 0);
1851  // return make_multi_index(block_work_idx[I0], 0,
1852  // block_work_idx[I1], 0);
1853  },
1854  Number<NumDTensor>{}));
1855 
1856  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1857  c_grid_desc_mblock_mperblock_nblock_nperblock;
1858 
1859  using CDEBlockTransferCluster =
1860  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1861  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1862  constexpr index_t scatter_weight_idx = 1; // hack fix felix
1863  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1864  ThisThreadBlock,
1865  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1866  Tuple<EDataType>,
1867  decltype(c_ds_desc_refs),
1868  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1869  CElementwiseOperation,
1870  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1871  // support arbitrary type
1872  Sequence<1,
1873  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1874  1,
1875  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1876  CDEBlockTransferCluster,
1877  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1878  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1879  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1880  3, // index_t SrcVectorDim,
1881  3, // index_t DstVectorDim,
1882  CDEShuffleBlockTransferScalarPerVectors,
1884  sequence_merge_t<
1885  Sequence<true>,
1886  uniform_sequence_gen_t<NumDTensor,
1887  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1888  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1889  IndexType,
1890  1, // ScatterDim
1891  true, // OutputScatter: false, only use scatter weights
1892  scatter_weight_idx // ScatterWeightIdx: ascale
1893  >{c_ds_desc_refs,
1894  idx_c_ds_block_begin,
1895  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1896  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1897  c_element_op};
1898 
1899  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1900  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1901  constexpr auto sfc_c_vgpr =
1902  SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
1903  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
1904  Sequence<CShuffleMXdlPerWavePerShuffle,
1905  CShuffleNXdlPerWavePerShuffle,
1906  1,
1907  1,
1908  M2,
1909  1,
1910  M4,
1911  1>>{};
1912 
1913  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1914 
1915  // space filling curve for shuffled blockwise C/D/E
1916  constexpr auto sfc_cde_block =
1917  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
1918  Sequence<0, 2, 1, 3>,
1919  Sequence<1,
1920  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1921  1,
1922  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1923 
1924  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1925  constexpr auto EMThreads =
1926  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1927  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1928  constexpr auto ENThreads =
1929  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1930  static_for<0, num_access, 1>{}([&](auto access_id) {
1931  // make sure it's safe to write to LDS
1932  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
1933 
1934  auto dstidx = sfc_cde_block.GetIndex(access_id);
1935  const index_t c_token_pos =
1936  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1937  static_for<0, EMRepeats, 1>{}([&](auto m0) {
1938  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1939  IndexType token_offset = fused_token & 0xffffff;
1940  if constexpr(IsInputGemm)
1941  {
1942  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1943  }
1944  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
1945  });
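            // Each row of the shuffled C tile is scattered back to its owning token row:
            // the low 24 bits of the sorted token id select the token, and (for the input
            // GEMM) the high 8 bits select the top-k slot, giving a destination offset of
            // row * problem.N elements in the output.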
1946 
1947  block_sync_lds();
1948 
1949  // each thread write its data from VGPR to LDS
1950  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1951  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1952  c_thread_buf_fp32,
1953  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1954  c_shuffle_block_buf);
1955 
1956  // make sure it's safe to read from LDS
1957  block_sync_lds();
1958 
1959  // each block copy its data from LDS to global
1960  cde_block_copy_lds_and_global.Run(
1961  c_ds_desc_refs,
1962  c_ds_buf_refs,
1963  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1964  tie(c_grid_buf),
1965  scatter_offsets);
1966 
1967  if constexpr(access_id < num_access - 1)
1968  {
1969  constexpr auto cde_lds_and_global_step =
1970  sfc_cde_block.GetForwardStep(access_id);
1971 
1972  // move on Ds
1973  static_for<0, NumDTensor, 1>{}([&](auto i) {
1974  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1975  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1976  });
1977 
1978  // move on E
1979  cde_block_copy_lds_and_global.MoveDstSliceWindow(
1980  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1981  I0,
1982  cde_lds_and_global_step);
1983  }
1984  });
1985  }
1986  }
1987 #endif
1988 
1989  template <bool HasMainKBlockLoop,
1990  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1991  TailNumber TailNum = TailNumber::Odd>
1992  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1993  const index_t* p_sorted_expert_ids,
1994  const index_t* p_max_token_id,
1995  const ADataType* p_a_grid,
1996  const AScaleDataType* p_a_scale_grid,
1997  const BDataType* p_b_grid,
1998  const BScaleDataType* p_b_scale_grid,
1999  DsGridPointer& p_ds_grid,
2000  CDataType* p_c_grid,
2001  void* p_shared_0,
2002  void* p_shared_1,
2003  const Problem& problem,
2004  AElementwiseOperation a_element_op,
2005  BElementwiseOperation b_element_op,
2006  CElementwiseOperation c_element_op)
2007  {
2008  ignore = a_element_op;
2009  ignore = b_element_op;
2010  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
2011  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
2012  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2013  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2014  problem.MPadded,
2015  problem.K,
2016  problem.KPadded,
2017  problem.StrideA,
2018  problem.AK0);
2019  const auto b_grid_desc_bpreshuffled =
2020  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
2021  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2022  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2023  problem.MPadded,
2024  problem.N,
2025  problem.NPadded,
2026  problem.StrideC);
2027 
2028  // We pad M unconditionally for the scale tensor
2029  const auto Padded_Scale_M =
2030  math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
2031  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
2032  make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
2033  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2034  (KXdlPack * 64 / MPerXdl),
2036  make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2037  (ScaleBlockSize / APackedSize)) *
2038  MPerXdl * MXdlPack / scale_pack_size_a,
2040  1));
2041 
2042  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
2043  make_tuple(problem.N / (NXdlPack * NPerXdl),
2044  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2045  (KXdlPack * 64 / NPerXdl),
2047  make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2048  (ScaleBlockSize / BPackedSize)) *
2049  NPerXdl * NXdlPack / scale_pack_size_b,
2051  1));
2052 
2053  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2054  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2055  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2056 
2057  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2058  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2059  if(expert_block_id * MPerBlock >= max_token_id)
2060  return;
2061  const index_t expert_id =
2062  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2063  const auto block_mn = [&]() -> std::pair<int, int> {
2064  if constexpr(NSwizzle)
2065  {
2066  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2067  const index_t prefix_block = ecnt_prefix * problem.NBlock;
2068  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2069  const index_t expert_swizzle =
2070  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2071  const index_t bid_new = blockIdx.x - prefix_block;
2072  const index_t nid = __builtin_amdgcn_readfirstlane(
2073  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2074  const index_t mid =
2075  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2076  return {nid, mid};
2077  }
2078  else
2079  {
2080  return {blockIdx.x, blockIdx.y};
2081  }
2082  }();
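        // With NSwizzle enabled, the flat blockIdx.x is remapped so that groups of 8
        // N-blocks belonging to the same expert are visited together: bid_new is the
        // block index within this expert's range, nid walks the N dimension in strides
        // of 8, and mid walks the expert's sorted-token rows. Without NSwizzle the
        // (n, m) block ids come directly from (blockIdx.x, blockIdx.y).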
2083 
2084  const index_t block_n_id = block_mn.first;
2085  const index_t block_m_id = block_mn.second;
2086  const index_t token0 =
2087  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2088 
2089  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2090  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2091  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2092  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2093  constexpr auto AKThreads = AK0Threads * AK1Threads;
2094  constexpr auto AMRepeats = MPerBlock / AMThreads;
2095  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
2096 
2097  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2098  return;
2099  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
2100  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2101  const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
2102  index_t token_offset = fused_token & 0xffffff;
2103  if constexpr(!IsInputGemm)
2104  {
2105  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2106  }
2107  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2108  });
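        // gather_offsets holds, per row of the A tile, the element offset of that row's
        // token in A: the sorted token id packs the token index in its low 24 bits and
        // the top-k slot in its high 8 bits, and for the down-projection (!IsInputGemm)
        // each (token, slot) pair owns its own row of length problem.K.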
2109 
2110  const index_t expert_stride =
2111  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2112  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2113  problem.N * (IsInputGemm ? 2 : 1) *
2114  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
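        // Per-expert strides: each expert owns an N x K weight slab (doubled when the
        // gate and up projections are fused in the input GEMM) and a matching slab of
        // block scales, one scale per (ScaleBlockSize / BPackedSize) elements along K.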
2115 
2116  // N0, K0, Blocksize*KPack
2117  const index_t n_block_data_idx_on_grid =
2118  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);
2119 
2120  // Grid buffer creation
2121  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2122  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2123  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2124  p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2125 
2126  // A, B scale buffer
2127  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2128  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2129  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2130  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2131  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2132 
2133  // A matrix in LDS memory, dst of blockwise copy
2134  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2135 
2136  // B matrix in LDS memory, dst of blockwise copy
2137  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2138 
2139  // A matrix blockwise direct to LDS copy
2140  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
2141  ThisThreadBlock,
2142  Sequence<AK0Number, MPerBlock, AK1Number>,
2143  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2144  ABlockTransferThreadClusterArrangeOrder,
2145  ADataType,
2146  ADataType,
2147  decltype(a_grid_desc_ak0_m_ak1),
2148  decltype(a_block_desc_ak0_m_ak1),
2149  ABlockTransferSrcAccessOrder,
2150  ABlockTransferSrcVectorDim,
2151  2,
2152  ABlockTransferSrcScalarPerVector,
2153  IndexType,
2154  1>(a_grid_desc_ak0_m_ak1,
2155  make_multi_index(0, 0, 0),
2156  a_block_desc_ak0_m_ak1,
2157  make_multi_index(0, 0, 0),
2158  gather_offsets);
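        // A rows are gathered straight from scattered token rows in global memory into
        // LDS (direct-load copy), using the per-row gather_offsets computed above.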
2159 
2160  // Thread-wise copy
2161  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2162  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2163  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2164  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2165  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2166  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2167 
2168  auto b_blockwise_copy =
2169  ThreadwiseTensorSliceTransfer_v2<BDataType,
2170  BDataType,
2171  decltype(b_grid_desc_bpreshuffled),
2172  decltype(b_block_desc_bk0_n_bk1),
2173  Sequence<Number<NXdlPerWave / NXdlPack>{},
2174  I1,
2175  Number<NXdlPack>{},
2176  Number<KRepeat>{},
2177  Number<BK1Value>{}>,
2178  Sequence<0, 1, 2, 3, 4>,
2179  4,
2180  BBlockTransferSrcScalarPerVector,
2181  BThreadTransferSrcResetCoordinateAfterRun,
2182  true>(
2183  b_grid_desc_bpreshuffled,
2184  make_multi_index(n_block_data_idx_on_grid,
2186  0,
2187  0,
2188  KPack * (get_thread_local_1d_id() % WarpSize)));
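        // B is preshuffled, so each thread streams its KPack-contiguous chunks directly
        // into the VGPR ping/pong buffers (b_block_buf_ping/pong) without staging in LDS.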
2189 
2190  // LDS allocation for A and B: be careful of alignment
2191  // Cast after lds
2192  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2193  static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2194  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2195  static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2196  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2197 
2198  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2199  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
2200 
2201  // Blockwise GEMM pipeline
2202  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2203  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2204  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2205  decltype(c_thread_buf) c_thread_buf_up;
2206 
2208  float,
2209  c_thread_buf.num_of_v_,
2210  c_thread_buf.s_per_v,
2211  true>
2212  c_thread_buf_fp32;
2213 
2214  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2215  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2216  KPerBlock);
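        // num_k_block_main_loop = (AK0 * AK1) / KPerBlock, i.e. the number of KPerBlock
        // tiles the blockwise GEMM pipeline iterates over.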
2217 
2218  // a and b scale processing
2219  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2220  const auto waveId_m = wave_idx[I0];
2221  const auto waveId_n = wave_idx[I1];
2222 
2223  auto thread_offset_shuffled =
2224  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2225 
2226  auto a_thread_offset_m = waveId_m;
2227 
2228  // get each thread's offset in the scale tensor
2229  const index_t token_scale_pos = block_m_id * MPerBlock;
2230  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2231  return;
2232 
2233  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2234  AScaleDataType,
2235  AScaleDataType,
2236  decltype(a_scale_grid_desc_am_ak),
2237  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2238  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2239  Sequence<0, 1, 2>, // DimAccessOrder
2240  2, // SrcVectorDim
2241  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2242  1, // SrcScalarStrideInVector
2243  true>(a_scale_grid_desc_am_ak,
2244  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2245  0,
2246  thread_offset_shuffled / scale_pack_size_a));
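        // Each thread reads its A scales in packed groups of KXdlPack * MXdlPack values
        // (scale_pack_size_a per vector); the row index is the block's M tile reduced to
        // MXdlPack * MPerXdl granularity plus this wave's m index, and the packed lane
        // offset comes from thread_offset_shuffled.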
2247 
2248  // B scale load
2249  auto b_thread_offset_n = waveId_n;
2250 
2251  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2252  BScaleDataType,
2253  BScaleDataType,
2254  decltype(b_scale_grid_desc_bn_ak),
2255  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2256  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2257  Sequence<0, 1, 2>, // DimAccessOrder
2258  2, // SrcVectorDim
2259  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2260  1, // SrcScalarStrideInVector
2261  true>(b_scale_grid_desc_bn_ak,
2262  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2263  0,
2264  thread_offset_shuffled / scale_pack_size_b));
2265 
2266  if constexpr(IsInputGemm)
2267  {
2268  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
2269  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2270  p_b_grid_up + expert_id * expert_stride,
2271  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2272  auto b_blockwise_copy_up =
2273  ThreadwiseTensorSliceTransfer_v2<BDataType,
2274  BDataType,
2275  decltype(b_grid_desc_bpreshuffled),
2276  decltype(b_block_desc_bk0_n_bk1),
2277  Sequence<Number<NXdlPerWave / NXdlPack>{},
2278  I1,
2279  Number<NXdlPack>{},
2280  Number<KRepeat>{},
2281  Number<BK1Value>{}>,
2282  Sequence<0, 1, 2, 3, 4>,
2283  4,
2284  BBlockTransferSrcScalarPerVector,
2285  BThreadTransferSrcResetCoordinateAfterRun,
2286  true>(
2287  b_grid_desc_bpreshuffled,
2288  make_multi_index(n_block_data_idx_on_grid,
2290  0,
2291  0,
2292  KPack * (get_thread_local_1d_id() % WarpSize)));
2293  const BScaleDataType* p_b_scale_grid_up =
2294  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
2295  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2296  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
2297  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2298 
2299  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2300  BScaleDataType,
2301  BScaleDataType,
2302  decltype(b_scale_grid_desc_bn_ak),
2303  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2304  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2305  Sequence<0, 1, 2>, // DimAccessOrder
2306  2, // SrcVectorDim
2307  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2308  1, // SrcScalarStrideInVector
2309  true>(
2310  b_scale_grid_desc_bn_ak,
2311  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2312  0,
2313  thread_offset_shuffled / scale_pack_size_b));
2314 
2315  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2316  // A
2317  a_grid_desc_ak0_m_ak1,
2318  a_block_desc_ak0_m_ak1,
2319  a_blockwise_copy,
2320  a_grid_buf,
2321  a_block_bufs,
2322  a_block_slice_copy_step,
2323  // Gate and Up
2324  b_grid_desc_bpreshuffled,
2325  b_block_desc_bk0_n_bk1,
2326  b_blockwise_copy,
2327  b_blockwise_copy_up,
2328  b_grid_buf,
2329  b_grid_buf_up,
2330  b_block_bufs,
2331  b_block_slice_copy_step,
2332  // C
2333  c_thread_buf,
2334  c_thread_buf_up,
2335  // A scale
2336  a_scale_grid_desc_am_ak,
2337  a_scale_thread_copy,
2338  a_scale_grid_buf,
2339  // B scale
2340  b_scale_grid_desc_bn_ak,
2341  b_scale_thread_copy,
2342  b_scale_thread_copy_up,
2343  b_scale_grid_buf,
2344  b_scale_grid_buf_up,
2345  num_k_block_main_loop);
2346  }
2347  else
2348  {
2349  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2350  a_grid_desc_ak0_m_ak1, // A
2351  a_block_desc_ak0_m_ak1,
2352  a_blockwise_copy,
2353  a_grid_buf,
2354  a_block_bufs,
2355  a_block_slice_copy_step,
2356  b_grid_desc_bpreshuffled, // B
2357  b_block_desc_bk0_n_bk1,
2358  b_blockwise_copy,
2359  b_grid_buf,
2360  b_block_bufs,
2361  b_block_slice_copy_step,
2362  c_thread_buf, // C
2363  a_scale_grid_desc_am_ak, // A scale
2364  a_scale_thread_copy,
2365  a_scale_grid_buf,
2366  b_scale_grid_desc_bn_ak, // B scale
2367  b_scale_thread_copy,
2368  b_scale_grid_buf,
2369  num_k_block_main_loop);
2370  }
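        // For the input (gate/up) GEMM the pipeline runs both weight slabs and produces
        // two accumulators (c_thread_buf and c_thread_buf_up); otherwise a single
        // MX-scaled GEMM accumulates into c_thread_buf only.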
2371 
2372  // shuffle C and write out
2373  {
2374  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2375  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2376  "wrong!");
2377  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
2378  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
2379  "wrong!");
2380 
2381  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2382 
2383  // TODO: hacky, fix it!
2384  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2385  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2386 
2387  // TODO: hacky, fix it!
2388  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2389  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2390  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2391 
2392  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2393  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2394  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2395  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2396  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2397  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2398  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2399  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2400  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
2401  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
2402 
2403  // mul scales
2404 
2405  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
2406  static_assert(M5 == 4);
2407  const index_t m1 = get_warp_local_1d_id() / NWave;
2408  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
2409 
2410  vector_type<float, 4> topk_weights; // for gemm2 only
2411  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
2412  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
2413  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
2414  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
2415  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
2416  const index_t m_pos = block_m_id * MPerBlock +
2417  m0 * M2 * M1 * M3 * M4 * M5 +
2418  m1 * M2 * M3 * M4 * M5 +
2419  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
2420  if constexpr(MulRoutedWeight)
2421  {
2422  topk_weights =
2423  *c_style_pointer_cast<const vector_type<float, M5>*>(
2424  p_ds_grid[I2] + m_pos);
2425  }
2426  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
2427  constexpr index_t c_offset =
2428  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2429  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
2430  constexpr auto cidx = Number<c_offset>{};
2431 
2432  if constexpr(IsInputGemm) // gu fusion
2433  {
2434  if constexpr(ActivationOperation ==
2435  Activation::silu_and_mul)
2436  {
2437  float gate = c_thread_buf[cidx];
2438  float up = c_thread_buf_up[cidx];
2439  if constexpr(MulRoutedWeight)
2440  {
2441  gate = gate * topk_weights.AsType<float>()[m5];
2442  up = up * topk_weights.AsType<float>()[m5];
2443  }
2444  tensor_operation::element_wise::Silu{}(gate, gate);
2445  c_thread_buf_fp32(cidx) = gate * up;
2446  }
2447  else if(ActivationOperation == Activation::gelu_and_mul)
2448  {
2449  float gate = c_thread_buf[cidx];
2450  float up = c_thread_buf_up[cidx];
2451  if constexpr(MulRoutedWeight)
2452  {
2453  gate = gate * topk_weights.AsType<float>()[m5];
2454  up = up * topk_weights.AsType<float>()[m5];
2455  }
2456  tensor_operation::element_wise::Gelu{}(gate, gate);
2457  c_thread_buf_fp32(cidx) = gate * up;
2458  }
2459  }
2460  else
2461  {
2462  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2463  if constexpr(MulRoutedWeight)
2464  {
2465  c_thread_buf_fp32(cidx) =
2466  topk_weights.AsType<float>()[m5] *
2467  c_thread_buf_fp32[cidx];
2468  }
2469  }
2470  });
2471  });
2472  });
2473  });
2474  });
2475  });
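            // Summary of the loop above: for the fused input GEMM the epilogue applies
            // the selected activation to the gate accumulator and multiplies by the up
            // accumulator (out = Act(gate) * up); for the output GEMM it copies the
            // accumulator directly. In both cases the per-token routed weight from
            // p_ds_grid[I2] is multiplied in when MulRoutedWeight is set.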
2476 
2477  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2478  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2479 
2480  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2481  static_cast<CShuffleDataType*>(p_shared_0),
2482  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2483 
2484  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2485  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2486  make_tuple(
2487  make_freeze_transform(I0),
2488  make_unmerge_transform(make_tuple(
2489  Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
2490  // shuffle
2491  M1, // M1 = MWave
2492  M2, // M2 * M3 * M4 = MPerXdl
2493  M3,
2494  M4,
2495  M5)),
2496  make_freeze_transform(I0),
2497  make_unmerge_transform(make_tuple(
2498  Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave)
2499  // per shuffle
2500  N1, // N1 = NWave
2501  N2, // N2 = NXdlPack
2502  N3))), // N3 = NPerXdl
2503  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2504  make_tuple(Sequence<>{},
2505  Sequence<0, 2, 4, 6, 7, 8>{},
2506  Sequence<>{},
2507  Sequence<1, 3, 5, 9>{}));
2508 
2509  // calculate origin of thread output tensor on global memory
2510  // blockwise GEMM c matrix starting index
2511  const auto c_thread_mtx_on_block =
2512  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2513 
2514  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2515  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2516 
2517  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2518  make_single_stage_tensor_adaptor(
2519  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
2520  make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
2521  make_tuple(Sequence<0>{}));
2522 
2523  const auto m_thread_data_on_block_idx =
2524  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2525  make_multi_index(m_thread_data_on_block));
2526 
2527  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2528  make_single_stage_tensor_adaptor(
2529  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
2530  make_tuple(Sequence<0, 1, 2, 3>{}),
2531  make_tuple(Sequence<0>{}));
2532 
2533  const auto n_thread_data_on_block_idx =
2534  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2535  make_multi_index(n_thread_data_on_block));
2536 
2537  // shuffle: threadwise copy C from VGPR to LDS
2538  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
2539  AccDataType,
2540  CShuffleDataType,
2541  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2542  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2543  ck::tensor_operation::element_wise::PassThrough,
2544  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2545  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2546  I1,
2547  I1,
2548  M2,
2549  N2,
2550  M3,
2551  I1,
2552  M5,
2553  I1>,
2554  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2555  9,
2556  1,
2557  InMemoryDataOperationEnum::Set,
2558  1,
2559  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2560  make_multi_index(0,
2561  0,
2562  m_thread_data_on_block_idx[I1],
2563  n_thread_data_on_block_idx[I1],
2564  m_thread_data_on_block_idx[I2],
2565  n_thread_data_on_block_idx[I2],
2566  m_thread_data_on_block_idx[I3],
2567  m_thread_data_on_block_idx[I4],
2568  m_thread_data_on_block_idx[I5],
2569  n_thread_data_on_block_idx[I3]),
2570  ck::tensor_operation::element_wise::PassThrough{}};
2571 
2572  using EDataType = CDataType;
2573 
2574  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2575  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2576 
2577  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2578  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2579  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2580 
2581  const auto ds_grid_buf = generate_tuple(
2582  [&](auto i) {
2583  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2584  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2585  },
2586  Number<NumDTensor>{});
2587 
2588  // tuple of reference to C/Ds tensor descriptors
2589  const auto c_ds_desc_refs = concat_tuple_of_reference(
2590  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2591  generate_tie([&](auto i) -> const auto& // return type should be reference
2592  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2593  Number<NumDTensor>{}));
2594 
2595  // tuple of reference to C/Ds tensor descriptors
2596  const auto c_ds_buf_refs = concat_tuple_of_reference(
2597  tie(c_shuffle_block_buf),
2598  generate_tie([&](auto i) -> const auto& // return type should be reference
2599  { return ds_grid_buf[i]; },
2600  Number<NumDTensor>{}));
2601 
2602  // tuple of starting index of C/Ds blockwise copy
2603  const auto idx_c_ds_block_begin =
2604  container_concat(make_tuple(make_multi_index(0, 0, 0, 0)),
2605  generate_tuple(
2606  [&](auto) {
2607  return make_multi_index(block_m_id, 0, block_n_id, 0);
2608  // return make_multi_index(block_work_idx[I0], 0,
2609  // block_work_idx[I1], 0);
2610  },
2611  Number<NumDTensor>{}));
2612 
2613  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2614  c_grid_desc_mblock_mperblock_nblock_nperblock;
2615 
2616  using CDEBlockTransferCluster =
2617  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2618  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2619  constexpr index_t scatter_weight_idx = 3; // hack fix felix
2620  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2621  ThisThreadBlock,
2622  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2623  Tuple<EDataType>,
2624  decltype(c_ds_desc_refs),
2625  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2626  CElementwiseOperation,
2627  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2628  // Sequence support
2629  // arbitrary type
2630  Sequence<1,
2631  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2632  1,
2633  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2634  CDEBlockTransferCluster,
2635  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2636  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2637  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2638  3, // index_t SrcVectorDim,
2639  3, // index_t DstVectorDim,
2640  CDEShuffleBlockTransferScalarPerVectors,
2642  sequence_merge_t<
2643  Sequence<true>,
2644  uniform_sequence_gen_t<NumDTensor,
2645  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2646  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2647  IndexType,
2648  1, // ScatterDim
2649  true, // OutputScatter: false, only use scatter weights
2650  scatter_weight_idx // ScatterWeightIdx: ascale
2651  >{c_ds_desc_refs,
2652  idx_c_ds_block_begin,
2653  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2654  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2655  c_element_op};
2656 
2657  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2658  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2659 
2660  constexpr auto sfc_c_vgpr =
2661  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2662  NXdlPerWave / NXdlPack,
2663  1,
2664  1,
2665  MXdlPack,
2666  NXdlPack,
2667  M2,
2668  1,
2669  M4,
2670  1>,
2671  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2672  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2673  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2674  1,
2675  1,
2676  MXdlPack,
2677  NXdlPack,
2678  M2,
2679  1,
2680  M4,
2681  1>>{};
2682 
2683  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2684 
2685  // space filling curve for shuffled blockwise C/D/E
2686  constexpr auto sfc_cde_block =
2687  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2688  Sequence<0, 2, 1, 3>,
2689  Sequence<1,
2690  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2691  1,
2692  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2693 
2694  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2695  constexpr auto EMThreads =
2696  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2697  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2698  constexpr auto ENThreads =
2699  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2700  static_for<0, num_access, 1>{}([&](auto access_id) {
2701  // make sure it's safe to write to LDS
2702  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2703 
2704  auto dstidx = sfc_cde_block.GetIndex(access_id);
2705  const index_t c_token_pos =
2706  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2707  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2708  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2709  IndexType token_offset = fused_token & 0xffffff;
2710  if constexpr(IsInputGemm)
2711  {
2712  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2713  }
2714  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2715  });
2716 
2717  block_sync_lds();
2718 
2719  // each thread write its data from VGPR to LDS
2720  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2721  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2722  c_thread_buf_fp32,
2723  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2724  c_shuffle_block_buf);
2725 
2726  // make sure it's safe to read from LDS
2727  block_sync_lds();
2728 
2729  // each block copy its data from LDS to global
2730  cde_block_copy_lds_and_global.Run(
2731  c_ds_desc_refs,
2732  c_ds_buf_refs,
2733  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2734  tie(c_grid_buf),
2735  scatter_offsets);
2736 
2737  if constexpr(access_id < num_access - 1)
2738  {
2739  constexpr auto cde_lds_and_global_step =
2740  sfc_cde_block.GetForwardStep(access_id);
2741 
2742  // move on Ds
2743  static_for<0, NumDTensor, 1>{}([&](auto i) {
2744  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2745  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2746  });
2747 
2748  // move on E
2749  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2750  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2751  I0,
2752  cde_lds_and_global_step);
2753  }
2754  });
2755  }
2756  }
2757 };
2758 
2759 } // namespace ck