Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp Source File

gridwise_moe_mx_gemm.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
16 
20 
21 #define DEBUG_LOG 0
22 
23 namespace ck {
24 
25 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
26 // pipelines in the same kernel function. Blockers:
27 // 1. Two separate declarations of the __shared__ pointer are the key to making sure data accesses
28 // operate on two distinct LDS chunks.
29 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may not
30 // use the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
31 
32 enum Activation
33 {
34  gelu_and_mul = 0,
35  silu_and_mul = 1
36 };
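// Illustrative note: both activations fuse the MoE "gate" and "up" GEMM outputs. A scalar sketch
// of what the epilogue computes (compare the IsInputGemm path in the shuffle/write-out stage
// below); silu/gelu here stand for the element_wise::Silu / element_wise::Gelu functors:
//
//   // float act_and_mul(Activation op, float gate, float up)
//   // {
//   //     return (op == Activation::silu_and_mul) ? silu(gate) * up   // silu(x) = x / (1 + exp(-x))
//   //                                             : gelu(gate) * up;  // gelu_and_mul
//   // }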
37 
38 #if 0
39 template <typename GridwiseGemm,
40  bool HasMainKBlockLoop,
41  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
42  index_t MinimumOccupancy = 1,
43  TailNumber TailNum = TailNumber::Even>
44 __global__ void
45 #if CK_USE_LAUNCH_BOUNDS
46 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
47 #endif
48  // __attribute__((amdgpu_waves_per_eu(1, 1)))
49  kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
50 {
51 #if defined(__gfx9__)
52  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
53  {
54  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
55 
56  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
57 
58  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
59  karg.p_sorted_token_ids,
60  karg.p_sorted_expert_ids,
61  karg.p_max_token_id,
62  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
 63  karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
64  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
 65  karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
66  karg.p_ds_grid,
67  karg.p_c_grid,
68  p_shared,
69  karg,
70  karg.a_element_op,
71  karg.b_element_op,
72  karg.c_element_op);
73  }
74 #else
75  ignore = karg;
76 #endif // end of if (defined(__gfx9__))
77 }
78 #endif
79 
80 template <typename GridwiseGemm,
81  bool HasMainKBlockLoop,
82  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
83  index_t MinimumOccupancy = 1,
84  TailNumber TailNum = TailNumber::Even>
85 __global__ void
86 #if CK_USE_LAUNCH_BOUNDS
87 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
88 #endif
89  // __attribute__((amdgpu_waves_per_eu(1, 1)))
90  kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
91 {
92 #if defined(__gfx9__)
93  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
94  {
95  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
96  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
97 
98  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
99 
100  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
101  karg.p_sorted_token_ids,
102  karg.p_sorted_expert_ids,
103  karg.p_max_token_id,
104  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
105  karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
106  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
107  karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
108  karg.p_ds_grid,
109  karg.p_c_grid,
110  p_shared_0,
111  p_shared_1,
112  karg,
113  karg.a_element_op,
114  karg.b_element_op,
115  karg.c_element_op);
116  }
117 #else
118  ignore = karg;
119 #endif // end of if (defined(__gfx9__))
120 }
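// Launch sketch (illustrative only; host-side names such as `stream` are assumptions): the 2-LDS
// variant ping-pongs between p_shared_0 and p_shared_1, so it is launched like the single-LDS
// kernel, with blockIdx.z enumerating the split-K batches consumed by SplitKBatchOffset:
//
//   // const auto [gx, gy, gz] = GridwiseGemm::CalculateGridSize(M, N);
//   // kernel_moe_mxgemm_2lds<GridwiseGemm, true, InMemoryDataOperationEnum::Set>
//   //     <<<dim3(gx, gy, karg.KBatch), BlockSize, 0, stream>>>(karg);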
121 
122 template <typename ALayout,
123  typename BLayout,
124  typename DsLayout,
125  typename CLayout,
126  typename ADataType,
127  typename AScaleDataType,
128  typename BDataType,
129  typename BScaleDataType,
130  typename AccDataType,
131  typename CShuffleDataType,
132  typename DsDataType,
133  typename CDataType,
134  typename AElementwiseOperation,
135  typename BElementwiseOperation,
136  typename CElementwiseOperation,
138  index_t ScaleBlockSize, // Scaling block size
139  index_t BlockSize, // Thread block size
140  index_t MPerBlock,
141  index_t NPerBlock,
142  index_t KPerBlock,
143  index_t AK1Value,
144  index_t BK1Value,
145  index_t MPerXdl,
146  index_t NPerXdl,
147  index_t MXdlPerWave,
148  index_t NXdlPerWave,
149  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
150  typename ABlockTransferThreadClusterArrangeOrder,
151  typename ABlockTransferSrcAccessOrder,
152  index_t ABlockTransferSrcVectorDim,
153  index_t ABlockTransferSrcScalarPerVector,
154  index_t ABlockTransferDstScalarPerVector_AK1,
155  bool AThreadTransferSrcResetCoordinateAfterRun,
156  index_t ABlockLdsExtraM,
157  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
158  typename BBlockTransferThreadClusterArrangeOrder,
159  typename BBlockTransferSrcAccessOrder,
160  index_t BBlockTransferSrcVectorDim,
161  index_t BBlockTransferSrcScalarPerVector,
162  index_t BBlockTransferDstScalarPerVector_BK1,
163  bool BThreadTransferSrcResetCoordinateAfterRun,
164  index_t BBlockLdsExtraN,
165  index_t CShuffleMXdlPerWavePerShuffle,
166  index_t CShuffleNXdlPerWavePerShuffle,
167  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
168  typename CDEShuffleBlockTransferScalarPerVectors,
171  index_t ActivationOperation = 0,
172  bool NSwizzle = false,
173  bool IsInputGemm = true,
174  bool MulRoutedWeight = true,
175  typename IndexType = index_t,
176  typename ComputeTypeA = ADataType,
177  typename ComputeTypeB = BDataType>
179 {
180  using LDSTypeA = ADataType;
181  using LDSTypeB = BDataType;
182 
183  static constexpr auto I0 = Number<0>{};
184  static constexpr auto I1 = Number<1>{};
185  static constexpr auto I2 = Number<2>{};
186  static constexpr auto I3 = Number<3>{};
187  static constexpr auto I4 = Number<4>{};
188  static constexpr auto I5 = Number<5>{};
189  static constexpr auto I6 = Number<6>{};
190  static constexpr auto I7 = Number<7>{};
191  static constexpr auto I8 = Number<8>{};
192  static constexpr auto I9 = Number<9>{};
193 
 194  static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
 195  CDEShuffleBlockTransferScalarPerVectors{}[I0];
196  // K1 should be Number<...>
197  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
198  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
199  static constexpr auto AK1Number = Number<AK1Value>{};
200  static constexpr auto BK1Number = Number<BK1Value>{};
201 
202  static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
203  static constexpr bool is_single_rate_mfma = false;
204  static constexpr auto is_scale_mfma = true;
205 
206  static constexpr index_t NumDTensor = DsDataType::Size();
207 
208  static constexpr auto MXdlPack = 2;
209  static constexpr auto NXdlPack = 2;
210  static constexpr auto KXdlPack = 2;
211 
 212  //> KPack is at least the k_per_blk of the selected mfma
 213  //
 214  // It should be a multiple of k_per_blk.
 215  // TODO: Move this to the blockwise pipeline base
 216  // KPack is expressed in packed data types for pk A/B
217 
218  static constexpr index_t APackedSize = packed_size_v<ADataType>;
219  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
220 
221  using mfma_selector = MfmaSelector<ComputeTypeA,
222  MPerXdl,
223  NPerXdl,
224  ComputeTypeB,
226  is_scale_mfma>;
227  static constexpr index_t KPack =
229 
230  // static constexpr index_t NumTokens = 1;
231  static constexpr index_t SortedTileSize = MPerBlock;
232 
233  static constexpr auto MakeDsGridPointer()
234  {
235  return generate_tuple(
236  [&](auto i) {
237  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
238 
239  return static_cast<const DDataType*>(nullptr);
240  },
242  }
243 
244  using DsGridPointer = decltype(MakeDsGridPointer());
245 
247 
248  __host__ static auto CalculateGridSize(index_t M, index_t N)
249  {
250  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
251  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
252  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
253  const index_t gridy = NSwizzle ? 1 : mblock;
254 
255  return std::make_tuple(gridx, gridy, 1);
256  }
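// Worked example (illustrative): with MPerBlock = 128, NPerBlock = 128, M = 512, N = 384:
//   nblock = ceil(384 / 128) = 3, mblock = ceil(512 / 128) = 4
//   NSwizzle == false -> grid = (3, 4, 1); NSwizzle == true -> grid = (3 * 4, 1, 1) = (12, 1, 1)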
257 
258  __host__ static auto CalculateMPadded(index_t M)
259  {
260  return math::integer_least_multiple(M, MPerBlock);
261  }
262 
263  __host__ static auto CalculateNPadded(index_t N)
264  {
265  return math::integer_least_multiple(N, NPerBlock);
266  }
267 
268  __host__ static auto CalculateKPadded(index_t K)
269  {
270  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
271  }
272 
273  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
274  {
275  auto K_t = K_Batch * KPerBlock;
276  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
277  }
278 
279  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
280  {
281  auto K_t = K_Batch * KPerBlock;
282  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
283  }
284 
285  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
286  {
287  auto K_t = K_Batch * KPerBlock;
288  return (K + K_t - 1) / K_t * KPerBlock;
289  }
290 
291  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
292  {
293  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
294  auto K_t = K_Batch * KReadVec;
295  return (K + K_t - 1) / K_t * KReadVec;
296  }
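// Worked example (illustrative) for the split-K helpers above, assuming KPerBlock = 256,
// AK1Value = BK1Value = 16, K = 1000 and K_Batch = 2:
//   CalculateKPadded(1000, 2):  K_t = 512 -> ceil(1000 / 512) * 256     = 512 (per batch)
//   CalculateAK0Padded(1000, 2):           ceil(1000 / 512) * (256 / 16) = 32
//   CalculateKRead(1000, 2):    KReadVec = lcm(16, 16) = 16, K_t = 32
//                                        -> ceil(1000 / 32) * 16         = 512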
297 
298  __host__ static auto CalculateMBlock(index_t M)
299  {
300  return math::integer_divide_ceil(M, MPerBlock);
301  }
302 
303  __host__ static auto CalculateNBlock(index_t N)
304  {
305  return math::integer_divide_ceil(N, NPerBlock);
306  }
307 
308  template <index_t MNXdlPerWave,
309  index_t MNWaves,
310  index_t MNXdlPack,
311  index_t MNPerXdl,
312  typename TileDesc_K0_MN_K1>
313  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
314  {
315  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
316  constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
317  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
318 
319  constexpr auto permuted_desc = transform_tensor_descriptor(
320  TileDesc_K0_MN_K1{},
325 
327  permuted_desc,
330  Number<MNWaves>{},
332  Number<MNPerXdl>{}))),
335  }
336 
337  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
338  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
339  {
340  const auto a_grid_desc_mraw_kraw = [&]() {
341  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
342  {
343  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
344  }
345  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
346  {
347  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
348  }
349  }();
350 
352 
353  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
354  GemmSpec == GemmSpecialization::MNKPadding)
355  {
356  // pad both M and K
357  const auto a_grid_desc_m_k =
358  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
360  make_right_pad_transform(K, KPad - K)),
363 
364  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
365  a_grid_desc_m_k,
370 
371  return a_grid_desc_ak0_m_ak1;
372  }
373  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
374  GemmSpec == GemmSpecialization::MNPadding)
375  {
376  // pad M, but not K
377  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
378  a_grid_desc_mraw_kraw,
379  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
380  make_right_pad_transform(M, MPad - M)),
383 
384  const auto a_grid_desc_permuted = transform_tensor_descriptor(
385  a_grid_desc_ak0_m_ak1,
388  make_pass_through_transform(AK1Value)),
391 
392  const auto a_grid_desc = transform_tensor_descriptor(
393  a_grid_desc_permuted,
394  make_tuple(
397  make_pass_through_transform(AK1Value)),
400  return a_grid_desc;
401  }
402  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
403  GemmSpec == GemmSpecialization::NKPadding)
404  {
405  // pad K, but not M
406  const auto a_grid_desc_m_k = transform_tensor_descriptor(
407  a_grid_desc_mraw_kraw,
411 
412  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
413  a_grid_desc_m_k,
418 
419  return a_grid_desc_ak0_m_ak1;
420  }
421  else
422  {
423  // not pad M or K
424  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
425  a_grid_desc_mraw_kraw,
426  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
430 
431  const auto a_grid_desc_permuted = transform_tensor_descriptor(
432  a_grid_desc_ak0_m_ak1,
435  make_pass_through_transform(AK1Value)),
438 
439  const auto a_grid_desc = transform_tensor_descriptor(
440  a_grid_desc_permuted,
441  make_tuple(
444  make_pass_through_transform(AK1Value)),
447 
448  return a_grid_desc;
449  }
450  }
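// Note (descriptive): the routine above builds a naive (M, K) descriptor from StrideA, applies
// M/K right-padding according to GemmSpec, and unmerges K into
// (K / KPerBlock, KPerBlock / AK1Value, AK1Value) blocks that are permuted and re-merged into the
// final [AK0, M, AK1] view, with AK1 = AK1Value.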
451 
452  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
453  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
454  {
455  const auto b_grid_desc_nraw_kraw = [&]() {
457  {
458  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
459  }
461  {
462  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
463  }
464  }();
465 
467 
468  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
469  GemmSpec != GemmSpecialization::Default),
470  "pk_i4_t does not support padding");
471  static_assert(!(is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t> &&
472  (GemmSpec != GemmSpecialization::Default &&
473  GemmSpec != GemmSpecialization::MPadding)),
474  "f4x2_pk_t does not support K padding");
475 
476  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
477  GemmSpec == GemmSpecialization::MNKPadding)
478  {
479  // pad both N and K
480  const auto b_grid_desc_n_k =
481  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
483  make_right_pad_transform(K, KPad - K)),
486 
487  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
488  b_grid_desc_n_k,
493 
494  return b_grid_desc_bk0_n_bk1;
495  }
496  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
497  GemmSpec == GemmSpecialization::MNPadding)
498  {
499  // pad N, but not K
500  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
501  b_grid_desc_nraw_kraw,
503  make_right_pad_transform(N, NPad - N)),
506 
507  return b_grid_desc_bk0_n_bk1;
508  }
509  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
510  GemmSpec == GemmSpecialization::MKPadding)
511  {
512  // pad K, but not N
513  const auto b_grid_desc_n_k = transform_tensor_descriptor(
514  b_grid_desc_nraw_kraw,
518 
519  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
520  b_grid_desc_n_k,
525 
526  return b_grid_desc_bk0_n_bk1;
527  }
528  else
529  {
530  // not pad N or K
531  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
532  b_grid_desc_nraw_kraw,
533  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
537 
538  const auto b_grid_desc_permuted = transform_tensor_descriptor(
539  b_grid_desc_bk0_n_bk1,
542  make_pass_through_transform(BK1Value)),
545 
546  const auto b_grid_desc = transform_tensor_descriptor(
547  b_grid_desc_permuted,
548  make_tuple(
551  make_pass_through_transform(BK1Value)),
554 
555  return b_grid_desc;
556  }
557  }
558 
559  template <typename ABlockDesc_AK0_M_AK1>
560  __host__ __device__ static constexpr auto
561  MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
562  {
563  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
564 
565  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl>(
566  ABlockDesc_AK0_M_AK1{});
567  }
568 
569  template <typename BBlockDesc_BK0_N_BK1>
570  __host__ __device__ static constexpr auto
571  MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
572  {
573  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
574 
575  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl>(
576  BBlockDesc_BK0_N_BK1{});
577  }
578 
579  template <typename ELayout>
580  __host__ __device__ static auto MakeCGridDescriptor_M_N(
581  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
582  {
583  const auto c_grid_desc_mraw_nraw = [&]() {
585  {
586  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
587  }
589  {
590  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
591  }
592  }();
593 
594  // pad M and N
595  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
597  make_right_pad_transform(N, NPad - N)),
600  }
601 
602  template <typename DLayout>
603  __host__ __device__ static auto
605  {
606  const auto c_grid_desc_mraw_nraw = [&]() {
608  {
609  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
610  }
612  {
613  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
614  }
615  }();
616 
617  // pad M and N
618  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
620  make_right_pad_transform(N, NPad - N)),
623  }
624 
625  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
626  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
627  {
628  return generate_tuple(
629  [&](auto i) {
630  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
631  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
632  },
634  }
635 
636  template <typename DsGridDesc>
638  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
639  {
640  return generate_tuple(
641  [&](auto i) {
643  ds_grid_desc_m_n[i], MBlock, NBlock);
644  },
646  }
647 
648  struct Problem
649  {
650  __host__ Problem(index_t NumTokens_,
651  index_t TopK_,
652  index_t M_,
653  index_t N_,
654  index_t K_,
655  index_t StrideA_,
656  index_t StrideScaleA_,
657  index_t StrideB_,
658  index_t StrideScaleB_,
659  std::array<index_t, NumDTensor> StrideDs_,
660  index_t StrideC_,
661  index_t KBatch_)
662  : NumTokens{NumTokens_},
663  TopK{TopK_},
664  M{M_},
665  N{N_},
666  K{K_},
667  StrideA{StrideA_},
668  StrideScaleA{StrideScaleA_},
669  StrideB{StrideB_},
670  StrideScaleB{StrideScaleB_},
671  StrideDs{StrideDs_},
672  StrideC{StrideC_},
673  KBatch{KBatch_},
676  KRead{CalculateKRead(K_, KBatch_)},
677  KPadded{CalculateKPadded(K_, KBatch_)},
678  AK0{CalculateAK0Padded(K_, KBatch_)},
679  BK0{CalculateBK0Padded(K_, KBatch_)},
680  MBlock{CalculateMBlock(M_)},
682  {
683  }
684 
685  __host__ void Print() const
686  {
687  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
688  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
689  << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
690  << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
691  << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
692  << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
693  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
694  << ", " << "NBlock: " << NBlock << "}" << std::endl;
695  }
696 
706  std::array<index_t, NumDTensor> StrideDs;
717  };
718 
719  // Argument
721  {
722  __host__ Argument(const index_t* p_sorted_token_ids_,
723  const index_t* p_sorted_expert_ids_,
724  const index_t* p_max_token_id_,
725  const ADataType* p_a_grid_,
726  const AScaleDataType* p_a_scale_grid_,
727  const BDataType* p_b_grid_,
728  const BScaleDataType* p_b_scale_grid_,
729  std::array<const void*, NumDTensor> p_ds_grid_,
730  CDataType* p_c_grid_,
731  index_t NumTokens_,
732  index_t TopK_,
733  index_t M_,
734  index_t N_,
735  index_t K_,
736  index_t StrideA_,
737  index_t StrideScaleA_,
738  index_t StrideB_,
739  index_t StrideScaleB_,
740  std::array<index_t, NumDTensor> StrideDs_,
741  index_t StrideC_,
742  index_t k_batch_,
743  AElementwiseOperation a_element_op_,
744  BElementwiseOperation b_element_op_,
745  CElementwiseOperation c_element_op_)
746  : Problem{NumTokens_,
747  TopK_,
748  M_,
749  N_,
750  K_ / APackedSize,
751  StrideA_ / APackedSize,
752  StrideScaleA_,
753  StrideB_ / BPackedSize,
754  StrideScaleB_,
755  StrideDs_,
756  StrideC_,
757  k_batch_},
758  p_sorted_token_ids{p_sorted_token_ids_},
759  p_sorted_expert_ids{p_sorted_expert_ids_},
760  p_max_token_id{p_max_token_id_},
761  p_a_grid{p_a_grid_},
762  p_a_scale_grid{p_a_scale_grid_},
763  p_b_grid{p_b_grid_},
764  p_b_scale_grid{p_b_scale_grid_},
765  p_ds_grid{},
766  p_c_grid{p_c_grid_},
767  a_element_op{a_element_op_},
768  b_element_op{b_element_op_},
769  c_element_op{c_element_op_}
770  {
771 
772  // populate pointer, desc for Ds
773  static_for<0, NumDTensor, 1>{}([&](auto i) {
774  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
775 
776  // D pointer
777  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
778  });
779  }
780 
784  const ADataType* p_a_grid;
785  const AScaleDataType* p_a_scale_grid;
786  const BDataType* p_b_grid;
787  const BScaleDataType* p_b_scale_grid;
789  CDataType* p_c_grid;
790 
791  const AElementwiseOperation a_element_op;
792  const BElementwiseOperation b_element_op;
793  const CElementwiseOperation c_element_op;
794  };
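// Construction note (descriptive) with an illustrative sketch (identifiers on the right-hand side
// are placeholders): the constructor forwards K_ / APackedSize and StrideA_ / APackedSize (and
// StrideB_ / BPackedSize) to Problem, i.e. callers pass K and the strides counted in unpacked
// elements and the packed-type adjustment happens here.
//
//   // Argument karg{p_sorted_token_ids, p_sorted_expert_ids, p_max_token_id,
//   //               p_a, p_a_scale, p_b, p_b_scale, p_ds, p_c,
//   //               NumTokens, TopK, M, N, K,
//   //               StrideA, StrideScaleA, StrideB, StrideScaleB, StrideDs, StrideC,
//   //               KBatch, PassThrough{}, PassThrough{}, c_element_op};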
795 
797  {
798  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
799  {
800  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
801  {
802  a_k_split_offset = k_id * karg.KRead;
803  }
804  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
805  {
806  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
807  }
808 
809  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
810  {
811  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
812  }
813  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
814  {
815  // KPack * NLane * KLane * K0 * N0
816  b_k_split_offset = k_id * karg.KRead;
817  }
818 
819  // Calculate A scale offset
820  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
821  {
822  a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
823  }
824  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
825  {
827  k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
828  }
829 
830  // Calculate B scale offset
831  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
832  {
834  k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
835  }
836  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
837  {
838  b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
839  }
840 
841  if(k_id < karg.KBatch - 1)
842  {
843  karg.K = karg.KRead;
844  }
845  else
846  {
847  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
848  }
849  }
850 
855  };
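// Worked example (illustrative): with row-major A, column-major B, KRead = 512, KBatch = 2,
// K = 1000 and k_id = 1 (the last split-K batch):
//   a_k_split_offset = 1 * 512 = 512 elements, b_k_split_offset = 1 * 512 = 512 elements,
//   and karg.K becomes 1000 - 512 * (2 - 1) = 488, i.e. the last batch consumes the remainder.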
856 
857  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
858  {
859  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
860  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
861  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
862 
863  // A matrix in LDS memory, dst of blockwise copy
864  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
865  {
866  // contiguous in LDS
870  }
 871  // The xor tensor transformation requires extra, unnecessary VGPR usage and can cause register
 872  // spills in some cases.
874  {
875  constexpr auto a_lds_block_desc =
878 
879  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
880  a_lds_block_desc,
886 
887  return a_lds_block_desc_permuted;
888  }
889  else // ColumnMajor A
890  {
 891  // The kfold and mpair dimensions are not always required.
 892  // More dimensions in the merge_transform increase the difficulty of generating immarg offsets
 893  // for the compiler.
894  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
895  constexpr auto M1 = MPerBlock / M0;
896 
897  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
898  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
899  constexpr auto KThreadRead = WaveSize / MPerXdl;
900  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
901 
902  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
903  ? 1
904  : 128 / (AK1Number * M0 * sizeof(ADataType));
905  constexpr auto KThreadReadPerm =
906  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
907  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
908  : KThreadRead;
909 
 910  // 1<=mpair<=m0
911  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
912  ? 1
913  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
914  ? M0
915  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
916 
917  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
921  Number<kfold * M0 / mpair>{},
922  Number<mpair>{},
923  AK1Number));
924 
925  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
926  a_lds_block_desc,
927  make_tuple(
931  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
934  make_tuple(
936  make_tuple(
938 
939  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
940  a_lds_block_desc_permuted,
941  make_tuple(
949  Sequence<1>{},
950  Sequence<2>{},
951  Sequence<3>{},
952  Sequence<4>{},
953  Sequence<5>{}),
955  Sequence<2>{},
956  Sequence<0, 3>{},
957  Sequence<4, 5>{},
958  Sequence<6>{},
959  Sequence<7>{}));
960 
961  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
962  a_lds_block_desc_unmerged,
965  Number<KThreadWrite / kfold / KThreadReadPerm>{},
966  Number<kfold>{},
973 
974  return a_lds_block_desc_ak0_m_ak1;
975  }
976  }
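// Worked example (illustrative) for the bank-conflict parameters above, assuming
// sizeof(ADataType) == 1, AK1Number == 16, M0 == 4 and MPerXdl == 16:
//   kfold = 128 / (16 * 4 * 1) = 2   (16 * 4 * 1 = 64 <= 128)
//   mpair = 1                        (16 * 16 * 1 = 256 > 128)
// so two K chunks are folded to fill a 128-byte span and no M pairs are interleaved.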
977 
978  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
979  {
980  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
981  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
982  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
983 
984  // B matrix in LDS memory, dst of blockwise copy
985  if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
986  {
987  // contiguous in lds
991  }
993  {
994  // NLdsLayer * K0 as logical Bank
995  constexpr auto b_lds_block_desc =
998 
999  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
1000  b_lds_block_desc,
1006 
1007  return b_lds_block_desc_permuted;
1008  }
1009  else // RowMajor B
1010  {
1011  constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
1012  constexpr auto N1 = NPerBlock / N0;
1013 
1014  constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
1015  constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
1016  constexpr auto KThreadRead = WaveSize / NPerXdl;
1017  constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
1018 
1019  constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
1020  ? 1
1021  : 128 / (BK1Number * N0 * sizeof(BDataType));
1022  constexpr auto KThreadReadPerm =
1023  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
1024  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
1025  : KThreadRead;
1026 
1027  // 1<=npair<=n0
1028  constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
1029  ? 1
1030  : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
1031  ? N0
1032  : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
1033 
1034  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
1038  Number<kfold * N0 / npair>{},
1039  Number<npair>{},
1040  BK1Number));
1041 
1042  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
1043  b_lds_block_desc,
1044  make_tuple(
1048  make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
1051  make_tuple(
1053  make_tuple(
1055 
1056  constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
1057  b_lds_block_desc_permuted,
1058  make_tuple(
1066  Sequence<1>{},
1067  Sequence<2>{},
1068  Sequence<3>{},
1069  Sequence<4>{},
1070  Sequence<5>{}),
1072  Sequence<2>{},
1073  Sequence<0, 3>{},
1074  Sequence<4, 5>{},
1075  Sequence<6>{},
1076  Sequence<7>{}));
1077 
1078  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1079  b_lds_block_desc_unmerged,
1082  Number<KThreadWrite / kfold / KThreadReadPerm>{},
1083  Number<kfold>{},
1090 
1091  return b_lds_block_desc_bk0_n_bk1;
1092  }
1093  }
1094 
1096  {
1097  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1098  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1099 
1100  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1102  make_tuple(I1,
1104  I1,
1106 
1107  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1108  }
1109 
1112  BlkGemmPipelineVer,
1113  BlkGemmPipeSched,
1114  BlockSize,
1115  ScaleBlockSize,
1116  ADataType,
1117  AScaleDataType,
1118  BDataType,
1119  BScaleDataType,
1120  ComputeTypeA,
1121  AccDataType,
1128  ABlockTransferSrcScalarPerVector,
1129  BBlockTransferSrcScalarPerVector,
1130  MPerBlock,
1131  NPerBlock,
1132  KPerBlock,
1133  MPerXdl,
1134  NPerXdl,
1135  MXdlPerWave,
1136  NXdlPerWave,
1137  KPack,
1138  IsInputGemm>())>;
1139 
1140  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1141  {
1142  // LDS allocation for A and B: be careful of alignment
1143  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1144  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1145 
1146  // lds max alignment
1147  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1148 
1149  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1150  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1151 
1152  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1153  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1154 
1155  // LDS allocation for C shuffle in LDS
1156  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1158 
1159  constexpr auto c_block_size =
1160  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1161 
1162  if constexpr(IsInputGemm)
1163  {
1164  return math::max(a_block_space_size_aligned * sizeof(ADataType) +
1165  b_block_space_size_aligned * sizeof(BDataType) * 2,
1166  c_block_size * sizeof(CShuffleDataType));
1167  }
1168  else
1169  {
1170  return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1171  b_block_space_size_aligned * sizeof(BDataType)),
1172  c_block_size * sizeof(CShuffleDataType));
1173  }
1174  }
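// Worked example (illustrative): if the aligned A tile occupies 32 KiB of ADataType, the aligned
// B tile 32 KiB of BDataType and the C-shuffle tile 16 KiB of CShuffleDataType, then
//   IsInputGemm == true  -> max(32 KiB + 2 * 32 KiB, 16 KiB) = 96 KiB (gate and up B tiles)
//   IsInputGemm == false -> max(32 KiB + 32 KiB,     16 KiB) = 64 KiB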
1175 
1177 
 1178  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1179  __host__ static constexpr bool CheckValidity(const Argument& karg)
1180  {
1181  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1182  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1183  "Invalid tuning param!");
1184 
1185  static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1186  "KPerBlock should be multiple of ScaleBlockSize");
1187 
1188  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1193  {
1194  if(!(karg.M % MPerBlock == 0))
1195  {
1196  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1197  {
1198  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1199  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1200  << std::endl;
1201  }
1202  return false;
1203  }
1204  }
1205 
1206  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1211  {
1212  if(!(karg.N % NPerBlock == 0))
1213  {
1214  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1215  {
1216  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1217  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1218  << std::endl;
1219  }
1220  return false;
1221  }
1222  }
1223 
1224  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1228  {
1229  auto K_t = karg.KBatch * KPerBlock;
1230  if(!(karg.K % K_t == 0))
1231  {
1232  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1233  {
1234  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1235  << karg.K << " " << __FILE__ << ":" << __LINE__
1236  << ", in function: " << __func__ << std::endl;
1237  }
1238  return false;
1239  }
1240  }
1241  else
1242  {
1243  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1244  auto K_t = karg.KBatch * KReadVec;
1245  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1246  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1247  {
1248  return false;
1249  }
1250  }
1251 
1253  {
1254  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1255  {
1256  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1257  {
1258  std::cout << "Arg K (" << karg.K
1259  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1260  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1261  << __LINE__ << ", in function: " << __func__ << std::endl;
1262  }
1263  return false;
1264  }
1265  }
1266  else
1267  {
1268  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1269  {
1270  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1271  {
1272  std::cout << "Arg M (" << karg.M
1273  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1274  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1275  << __LINE__ << ", in function: " << __func__ << std::endl;
1276  }
1277  return false;
1278  }
1279  }
1280 
1282  {
1283  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1284  {
1285  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1286  {
1287  std::cout << "Arg N (" << karg.N
1288  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1289  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1290  << __LINE__ << ", in function: " << __func__ << std::endl;
1291  }
1292  return false;
1293  }
1294  }
1295  else
1296  {
1297  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1298  {
1299  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1300  {
1301  std::cout << "Arg K (" << karg.K
1302  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1303  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1304  << __LINE__ << ", in function: " << __func__ << std::endl;
1305  }
1306  return false;
1307  }
1308  }
1309 
1311  {
1313  {
1314  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1315  {
1316  std::cout << "Arg N (" << karg.N
1317  << ") value is not a multiple of "
1318  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1320  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1321  << std::endl;
1322  }
1323  return false;
1324  }
1325  }
1326  else
1327  {
1329  {
1330  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1331  {
1332  std::cout << "Arg M (" << karg.M
1333  << ") value is not a multiple of "
1334  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1336  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1337  << std::endl;
1338 
1339  return false;
1340  }
1341  }
1342  }
1343 
1344  // check gridwise gemm pipeline
1345 #if 0
1346  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1347 
1348  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1349  {
1350  return false;
1351  }
1352 #endif
1353  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1354  return true;
1355  }
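// Host-side usage sketch (illustrative): CheckValidity is meant to gate the launch, e.g.
//   // if(!GridwiseGemm::CheckValidity(karg)) { return; /* fall back to another instance */ }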
1356 
1357  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1358  {
1359  const index_t num_loop = K / KPerBlock;
1360 
1361  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1362  }
1363 
1364  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1365  {
1366  const index_t num_loop = K / KPerBlock;
1367 
1368  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1369  }
1370 
1371  template <typename CGridDesc>
1372  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1373  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1374  {
1375  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1376  c_grid_desc_m_n,
1381 
1382  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1383  }
1384 
1385  // return block_id to C matrix tile idx (m0, n0) mapping
1386  // if arch = gfx942
1387  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1388  // NPerBlock>;
1389 
1391  static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
1392  static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
1393  static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
1394  "A scale pack data type too large!");
1395  static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
1396  "B scale pack data type too large!");
1397 
1398  static_assert(is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
1399  is_same_v<BElementwiseOperation, tensor_operation::element_wise::PassThrough>,
1400  "A/B ElementwiseOperation should be PassThrough as load_to_lds is used!");
1401 
1402 #if 0
1403  template <bool HasMainKBlockLoop,
1404  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1405  TailNumber TailNum = TailNumber::Odd>
1406  __device__ static void Run(const index_t* p_sorted_token_ids,
1407  const index_t* p_sorted_expert_ids,
1408  const index_t* p_max_token_id,
1409  const ADataType* p_a_grid,
1410  const AScaleDataType* p_a_scale_grid,
1411  const BDataType* p_b_grid,
1412  const BScaleDataType* p_b_scale_grid,
1413  DsGridPointer& p_ds_grid,
1414  CDataType* p_c_grid,
1415  void* p_shared,
1416  const Problem& problem,
1417  AElementwiseOperation a_element_op,
1418  BElementwiseOperation b_element_op,
1419  CElementwiseOperation c_element_op)
1420  {
1421  ignore = a_element_op;
1422  ignore = b_element_op;
1423  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1424  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1425  problem.MPadded,
1426  problem.K,
1427  problem.KPadded,
1428  problem.StrideA,
1429  problem.AK0);
1430  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1431  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1432  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1433  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1434  problem.MPadded,
1435  problem.N,
1436  problem.NPadded,
1437  problem.StrideC);
1438 
1439  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
1440  make_tuple(problem.M / (MXdlPack * MPerXdl),
1441  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1442  (KXdlPack * 64 / MPerXdl),
1443  64 * KXdlPack * MXdlPack / scale_pack_size_a));
1444 
1445  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
1446  make_tuple(problem.N / (NXdlPack * NPerXdl),
1447  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1448  (KXdlPack * 64 / NPerXdl),
1449  64 * KXdlPack * NXdlPack / scale_pack_size_b));
1450 
1451  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1453  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1454 
1455  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1456  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1457  if(expert_block_id * MPerBlock >= max_token_id)
1458  return;
1459  const index_t expert_id =
1460  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1461 
1462  const auto block_mn = [&]() -> std::pair<int, int> {
1463  if constexpr(NSwizzle)
1464  {
1465  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1466  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1467  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1468  const index_t expert_swizzle =
1469  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1470  const index_t bid_new = blockIdx.x - prefix_block;
1471  const index_t nid = __builtin_amdgcn_readfirstlane(
1472  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1473  const index_t mid =
1474  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1475  return {nid, mid};
1476  }
1477  else
1478  {
1479  return {blockIdx.x, blockIdx.y};
1480  }
1481  }();
1482 
1483  const index_t block_n_id = block_mn.first;
1484  const index_t block_m_id = block_mn.second;
1485  const index_t token0 =
1486  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1487 
1488  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1489  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1490  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1491  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1492  constexpr auto AKThreads = AK0Threads * AK1Threads;
1493  constexpr auto AMRepeats = MPerBlock / AMThreads;
1494  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1495 
1496  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1497  return;
1498  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
1499  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1500  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1501  index_t token_offset = fused_token & 0xffffff;
1502  if constexpr(!IsInputGemm)
1503  {
1504  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1505  }
1506  gather_offsets(m0) = static_cast<IndexType>(token_offset);
1507  });
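// Note (descriptive): each entry of p_sorted_token_ids packs the token id in the low 24 bits and
// the top-k slot in the high 8 bits; for the second GEMM (IsInputGemm == false) the gather row
// becomes token_id * TopK + topk_slot, matching the TopK-expanded activation layout.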
1508 
1509  const index_t expert_stride =
1510  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1511  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1512  problem.N * (IsInputGemm ? 2 : 1) *
1513  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
1514 
1515  // N0, K0, Blocksize*KPack
1516  const index_t n_block_data_idx_on_grid =
1517  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1518 
 1519  // Grid buffer creation
1520  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1521  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1522  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1523  p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1524 
1525  // A, B scale buffer
1526  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1527  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1528  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1529  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
1530  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1531 
1532  // lds max alignment
1533  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1534 
1535  // A matrix in LDS memory, dst of blockwise copy
1536  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1537 
1538  // B matrix in LDS memory, dst of blockwise copy
1539  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1540 
1541  // A matrix blockwise direct to LDS copy
1542  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
1544  Sequence<AK0Number, MPerBlock, AK1Number>,
1545  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1546  ABlockTransferThreadClusterArrangeOrder,
1547  ADataType,
1548  ADataType,
1549  decltype(a_grid_desc_ak0_m_ak1),
1550  decltype(a_block_desc_ak0_m_ak1),
1551  ABlockTransferSrcAccessOrder,
1552  ABlockTransferSrcVectorDim,
1553  2,
1554  ABlockTransferSrcScalarPerVector,
1555  IndexType,
1556  1>(a_grid_desc_ak0_m_ak1,
1557  make_multi_index(0, 0, 0),
1558  a_block_desc_ak0_m_ak1,
1559  make_multi_index(0, 0, 0),
1560  gather_offsets);
1561 
1562  // B matrix blockwise copy
1563  auto b_blockwise_copy =
1564  ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
1565  Sequence<BK0Number, NPerBlock, BK1Number>,
1566  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1567  BBlockTransferThreadClusterArrangeOrder,
1568  BDataType,
1569  BDataType,
1570  decltype(b_grid_desc_bk0_n_bk1),
1571  decltype(b_block_desc_bk0_n_bk1),
1572  BBlockTransferSrcAccessOrder,
1573  BBlockTransferSrcVectorDim,
1574  2,
1575  BBlockTransferSrcScalarPerVector>(
1576  b_grid_desc_bk0_n_bk1,
1577  make_multi_index(0, n_block_data_idx_on_grid, 0),
1578  b_block_desc_bk0_n_bk1,
1579  make_multi_index(0, 0, 0));
1580 
1581  // LDS allocation for A and B: be careful of alignment
1582  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1583  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1584 
1585  // Cast after lds
1586  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1587  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1588 
1589  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1590  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1591  a_block_space_size_aligned * sizeof(ADataType)),
1592  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1593 
1594  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1595  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1596 
1597  // Blockwise GEMM pipeline
1598  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1599  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1600  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1601  decltype(c_thread_buf) c_thread_buf_up;
1602 
1603  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
1604  float,
1605  c_thread_buf.num_of_v_,
1606  c_thread_buf.s_per_v,
1607  true>
1608  c_thread_buf_fp32;
1609 
1610  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1611  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1612  KPerBlock);
1613 
1614  // a and b scale processing
1615  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1616  const auto waveId_m = wave_idx[I0];
1617  const auto waveId_n = wave_idx[I1];
1618 
1619  auto thread_offset_shuffled =
1620  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1621 
1622  auto a_thread_offset_m = waveId_m;
1623 
1624  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1625  AScaleDataType,
1626  AScaleDataType,
1627  decltype(a_scale_grid_desc_am_ak),
1628  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1629  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1630  Sequence<0, 1, 2>, // DimAccessOrder
1631  2, // SrcVectorDim
1632  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1633  1, // SrcScalarStrideInVector
1634  true>(a_scale_grid_desc_am_ak,
1635  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1636  0,
1637  thread_offset_shuffled / scale_pack_size_a));
1638 
1639  // B scale load
1640  auto b_thread_offset_n = waveId_n;
1641 
1642  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1643  BScaleDataType,
1644  BScaleDataType,
1645  decltype(b_scale_grid_desc_bn_ak),
1646  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1647  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1648  Sequence<0, 1, 2>, // DimAccessOrder
1649  2, // SrcVectorDim
1650  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1651  1, // SrcScalarStrideInVector
1652  true>(b_scale_grid_desc_bn_ak,
1653  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1654  0,
1655  thread_offset_shuffled / scale_pack_size_b));
1656 
1657  if constexpr(IsInputGemm)
1658  {
1659  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1660  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1661  auto b_block_buf_up = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1662  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1663  a_block_space_size_aligned * sizeof(ADataType) +
1664  b_block_space_size_aligned * sizeof(BDataType)),
1665  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1666 
1667  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
1668  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1669  p_b_grid_up + expert_id * expert_stride,
1670  b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1671 
1672  auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
1674  Sequence<BK0Number, NPerBlock, BK1Number>,
1675  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1676  BBlockTransferThreadClusterArrangeOrder,
1677  BDataType,
1678  BDataType,
1679  decltype(b_grid_desc_bk0_n_bk1),
1680  decltype(b_block_desc_bk0_n_bk1),
1681  BBlockTransferSrcAccessOrder,
1682  BBlockTransferSrcVectorDim,
1683  2,
1684  BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
1685  make_multi_index(0, n_block_data_idx_on_grid, 0),
1686  b_block_desc_bk0_n_bk1,
1687  make_multi_index(0, 0, 0));
1688 
1689  const BScaleDataType* p_b_scale_grid_up =
1690  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
1691  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1692  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
1693  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1694 
1695  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1696  BScaleDataType,
1697  BScaleDataType,
1698  decltype(b_scale_grid_desc_bn_ak),
1699  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1700  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1701  Sequence<0, 1, 2>, // DimAccessOrder
1702  2, // SrcVectorDim
1703  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1704  1, // SrcScalarStrideInVector
1705  true>(
1706  b_scale_grid_desc_bn_ak,
1707  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1708  0,
1709  thread_offset_shuffled / scale_pack_size_b));
1710 
1711  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1712  // A
1713  a_grid_desc_ak0_m_ak1,
1714  a_block_desc_ak0_m_ak1,
1715  a_blockwise_copy,
1716  a_grid_buf,
1717  a_block_buf,
1718  a_block_slice_copy_step,
1719  // Gate and Up
1720  b_grid_desc_bk0_n_bk1,
1721  b_block_desc_bk0_n_bk1,
1722  b_blockwise_copy,
1723  b_blockwise_copy_up,
1724  b_grid_buf,
1725  b_grid_buf_up,
1726  b_block_buf,
1727  b_block_buf_up,
1728  b_block_slice_copy_step,
1729  // C
1730  c_thread_buf,
1731  c_thread_buf_up,
1732  // A scale
1733  a_scale_grid_desc_am_ak,
1734  a_scale_thread_copy,
1735  a_scale_grid_buf,
1736  // Gate and Up scale
1737  b_scale_grid_desc_bn_ak,
1738  b_scale_thread_copy,
1739  b_scale_thread_copy_up,
1740  b_scale_grid_buf,
1741  b_scale_grid_buf_up,
1742  num_k_block_main_loop);
1743  }
1744  else
1745  {
1746  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1747  a_grid_desc_ak0_m_ak1, // A
1748  a_block_desc_ak0_m_ak1,
1749  a_blockwise_copy,
1750  a_grid_buf,
1751  a_block_buf,
1752  a_block_slice_copy_step,
1753  b_grid_desc_bk0_n_bk1, // B
1754  b_block_desc_bk0_n_bk1,
1755  b_blockwise_copy,
1756  b_grid_buf,
1757  b_block_buf,
1758  b_block_slice_copy_step,
1759  c_thread_buf, // C
1760  a_scale_grid_desc_am_ak, // A scale
1761  a_scale_thread_copy,
1762  a_scale_grid_buf,
1763  b_scale_grid_desc_bn_ak, // B scale
1764  b_scale_thread_copy,
1765  b_scale_grid_buf,
1766  num_k_block_main_loop);
1767  }
1768 
1769  // shuffle C and write out
1770  {
1771  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1772  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1773  "wrong!");
1774  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
1775  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
1776  "wrong!");
1777 
1778  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1779  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1780 
1781  // TODO: hacky, fix it!
1782  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1783  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1784 
1785  // TODO: hacky, fix it!
1786  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1787  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1788  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1789 
1790  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1791  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1792  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1793  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1794  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1795  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1796  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1797  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1798  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
1799  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
1800 
1801  // mul scales
1802  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
1803  static_assert(M5 == 4);
1804  const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id
1805  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
1806 
1807  vector_type<float, 4> topk_weights; // for gemm2 only
1808  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
1809  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
1810  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
1811  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
1812  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
1813  const index_t m_pos = block_m_id * MPerBlock +
1814  m0 * M2 * M1 * M3 * M4 * M5 +
1815  m1 * M2 * M3 * M4 * M5 +
1816  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
1817  if constexpr(MulRoutedWeight)
1818  {
1819  topk_weights =
1820  *c_style_pointer_cast<const vector_type<float, M5>*>(
1821  p_ds_grid[I2] + m_pos);
1822  }
1823  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
1824  constexpr index_t c_offset =
1825  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1826  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
1827  constexpr auto cidx = Number<c_offset>{};
1828 
 1829  if constexpr(IsInputGemm) // gate & up (gu) fusion
1830  {
1831  if constexpr(ActivationOperation ==
1833  {
1834  float gate = c_thread_buf[cidx];
1835  float up = c_thread_buf_up[cidx];
1836  if constexpr(MulRoutedWeight)
1837  {
1838  gate = gate * topk_weights.AsType<float>()[m5];
1839  up = up * topk_weights.AsType<float>()[m5];
1840  }
1841  tensor_operation::element_wise::Silu{}(gate, gate);
1842  c_thread_buf_fp32(cidx) = gate * up;
1843  }
1844  else if(ActivationOperation == Activation::gelu_and_mul)
1845  {
1846  float gate = c_thread_buf[cidx];
1847  float up = c_thread_buf_up[cidx];
1848  if constexpr(MulRoutedWeight)
1849  {
1850  gate = gate * topk_weights.AsType<float>()[m5];
1851  up = up * topk_weights.AsType<float>()[m5];
1852  }
1853  tensor_operation::element_wise::Gelu{}(gate, gate);
1854  c_thread_buf_fp32(cidx) = gate * up;
1855 
1856  /*float gate = c_thread_buf[cidx];
1857  float up = c_thread_buf_up[cidx];
1858  if constexpr(MulRoutedWeight)
1859  {
1860  gate = gate * topk_weights.AsType<float>()[m5];
1861  //up = up * topk_weights.AsType<float>()[m5];
1862  }
1863  tensor_operation::element_wise::Gelu{}(gate, gate);
1864  c_thread_buf_fp32(cidx) = up;*/
1865  }
1866  }
1867  else
1868  {
1869  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1870  if constexpr(MulRoutedWeight)
1871  {
1872  c_thread_buf_fp32(cidx) =
1873  topk_weights.AsType<float>()[m5] *
1874  c_thread_buf_fp32[cidx];
1875  }
1876  }
1877  });
1878  });
1879  });
1880  });
1881  });
1882  });
1883 
1884  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1885  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
1886 
1887  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1888  static_cast<CShuffleDataType*>(p_shared),
1889  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1890 
1891  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1892  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1893  make_tuple(
1894  make_freeze_transform(I0),
1895  make_unmerge_transform(make_tuple(
1896  Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave)
1897  // per shuffle
1898  M1, // M1 = MWave
1899  M2, // M2 = MXdlPack
1900  M3, // M3 * M4 * M5 = MPerXdl
1901  M4,
1902  M5)),
1903  make_freeze_transform(I0),
1904  make_unmerge_transform(make_tuple(
1905  Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave)
1906  // per shuffle
1907  N1, // N1 = NWave
1908  N2, // N2 = NXdlPack
1909  N3))), // N3 = NPerXdl
1910  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
1911  make_tuple(Sequence<>{},
1912  Sequence<0, 2, 4, 6, 7, 8>{},
1913  Sequence<>{},
1914  Sequence<1, 3, 5, 9>{}));
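      // The transform above freezes the MBlock/NBlock dimensions of the LDS shuffle
      // buffer and unmerges MPerBlock/NPerBlock into the (M0..M5)/(N0..N3) coordinates
      // used by the per-thread VGPR-to-LDS copy below.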
1915 
1916  // calculate origin of thread output tensor on global memory
1917  // blockwise GEMM c matrix starting index
1918  const auto c_thread_mtx_on_block =
1919  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1920 
1921  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1922  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1923 
1924  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1925  make_single_stage_tensor_adaptor(
1926  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
1927  make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
1928  make_tuple(Sequence<0>{}));
1929 
1930  const auto m_thread_data_on_block_idx =
1931  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1932  make_multi_index(m_thread_data_on_block));
1933 
1934  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1935  make_single_stage_tensor_adaptor(
1936  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
1937  make_tuple(Sequence<0, 1, 2, 3>{}),
1938  make_tuple(Sequence<0>{}));
1939 
1940  const auto n_thread_data_on_block_idx =
1941  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1942  make_multi_index(n_thread_data_on_block));
1943 
1944  // shuffle: threadwise copy C from VGPR to LDS
1945  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1946  AccDataType,
1947  CShuffleDataType,
1948  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1949  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1950  tensor_operation::element_wise::PassThrough,
1951  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1952  CShuffleNXdlPerWavePerShuffle / NXdlPack,
1953  I1,
1954  I1,
1955  M2,
1956  N2,
1957  M3,
1958  I1,
1959  M5,
1960  I1>,
1961  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
1962  9,
1963  1,
1964  InMemoryDataOperationEnum::Set,
1965  1,
1966  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1967  make_multi_index(0,
1968  0,
1969  m_thread_data_on_block_idx[I1],
1970  n_thread_data_on_block_idx[I1],
1971  m_thread_data_on_block_idx[I2],
1972  n_thread_data_on_block_idx[I2],
1973  m_thread_data_on_block_idx[I3],
1974  m_thread_data_on_block_idx[I4],
1975  m_thread_data_on_block_idx[I5],
1976  n_thread_data_on_block_idx[I3]),
1977  tensor_operation::element_wise::PassThrough{}};
1978 
1979  using EDataType = CDataType;
1980 
1981  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1982  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1983 
1984  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1985  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1986  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1987 
1988  const auto ds_grid_buf = generate_tuple(
1989  [&](auto i) {
1990  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1991  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1992  },
1993  Number<NumDTensor>{});
1994 
1995  // tuple of reference to C/Ds tensor descriptors
1996  const auto c_ds_desc_refs = concat_tuple_of_reference(
1997  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1998  generate_tie([&](auto i) -> const auto& // return type should be reference
1999  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2000  Number<NumDTensor>{}));
2001 
2002  // tuple of reference to C/Ds buffers
2003  const auto c_ds_buf_refs = concat_tuple_of_reference(
2004  tie(c_shuffle_block_buf),
2005  generate_tie([&](auto i) -> const auto& // return type should be reference
2006  { return ds_grid_buf[i]; },
2007  Number<NumDTensor>{}));
2008 
2009  // tuple of starting index of C/Ds blockwise copy
2010  const auto idx_c_ds_block_begin =
2011  container_concat(make_tuple(make_multi_index(0, 0, 0, 0)),
2012  generate_tuple(
2013  [&](auto) {
2014  return make_multi_index(block_m_id, 0, block_n_id, 0);
2015  // return make_multi_index(block_work_idx[I0], 0,
2016  // block_work_idx[I1], 0);
2017  },
2018  Number<NumDTensor>{}));
2019 
2020  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2021  c_grid_desc_mblock_mperblock_nblock_nperblock;
2022 
2023  using CDEBlockTransferCluster =
2024  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2025  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2026  constexpr index_t scatter_weight_idx = 3; // hack fix felix
2027  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2028  ThisThreadBlock,
2029  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2030  Tuple<EDataType>,
2031  decltype(c_ds_desc_refs),
2032  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2033  CElementwiseOperation,
2034  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2035  // Sequence support
2036  // arbitrary type
2037  Sequence<1,
2038  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2039  1,
2040  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2041  CDEBlockTransferCluster,
2042  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2043  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2044  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2045  3, // index_t SrcVectorDim,
2046  3, // index_t DstVectorDim,
2047  CDEShuffleBlockTransferScalarPerVectors,
2048  sequence_merge_t<
2050  Sequence<true>,
2051  uniform_sequence_gen_t<NumDTensor,
2052  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2053  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2054  IndexType,
2055  1, // ScatterDim
2056  true, // OutputScatter: false, only use scatter weights
2057  scatter_weight_idx // ScatterWeightIdx: ascale
2058  >{c_ds_desc_refs,
2059  idx_c_ds_block_begin,
2060  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2061  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2062  c_element_op};
2063 
2064  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2065  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2066 
2067  constexpr auto sfc_c_vgpr =
2068  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2069  NXdlPerWave / NXdlPack,
2070  1,
2071  1,
2072  MXdlPack,
2073  NXdlPack,
2074  M2,
2075  1,
2076  M4,
2077  1>,
2078  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2079  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2080  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2081  1,
2082  1,
2083  MXdlPack,
2084  NXdlPack,
2085  M2,
2086  1,
2087  M4,
2088  1>>{};
2089 
2090  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2091 
2092  // space filling curve for shuffled blockwise C/D/E
2093  constexpr auto sfc_cde_block =
2094  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2095  Sequence<0, 2, 1, 3>,
2096  Sequence<1,
2097  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2098  1,
2099  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
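      // sfc_c_vgpr walks the accumulator tile held in VGPRs, while sfc_cde_block walks
      // the matching MPerBlock x NPerBlock tile of C/Ds/E; the static_assert below checks
      // that both curves take the same number of steps, so each LDS round trip moves one
      // corresponding chunk.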
2100 
2101  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2102  constexpr auto EMThreads =
2103  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2104  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2105  constexpr auto ENThreads =
2106  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2107  static_for<0, num_access, 1>{}([&](auto access_id) {
2108  // make sure it's safe to write to LDS
2109  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2110 
2111  auto dstidx = sfc_cde_block.GetIndex(access_id);
2112  const index_t c_token_pos =
2113  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2114  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2115  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2116  IndexType token_offset = fused_token & 0xffffff;
2117  if constexpr(IsInputGemm)
2118  {
2119  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2120  }
2121  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2122  });
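        // Each sorted token id packs the token index in its low 24 bits and the top-k
        // slot in its high 8 bits (e.g. (2 << 24) | 71 is token 71, slot 2). For the
        // input (gate/up) GEMM the output row is token * TopK + slot; the scatter offset
        // is that row times problem.N.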
2123 
2124  block_sync_lds();
2125 
2126  // each thread write its data from VGPR to LDS
2127  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2128  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2129  c_thread_buf_fp32,
2130  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2131  c_shuffle_block_buf);
2132 
2133  // make sure it's safe to read from LDS
2134  block_sync_lds();
2135 
2136  // each block copy its data from LDS to global
2137  cde_block_copy_lds_and_global.Run(
2138  c_ds_desc_refs,
2139  c_ds_buf_refs,
2140  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2141  tie(c_grid_buf),
2142  scatter_offsets);
2143 
2144  if constexpr(access_id < num_access - 1)
2145  {
2146  constexpr auto cde_lds_and_global_step =
2147  sfc_cde_block.GetForwardStep(access_id);
2148 
2149  // move on Ds
2150  static_for<0, NumDTensor, 1>{}([&](auto i) {
2151  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2152  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2153  });
2154 
2155  // move on E
2156  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2157  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2158  I0,
2159  cde_lds_and_global_step);
2160  }
2161  });
2162  }
2163  }
2164 #endif
2165 
2166  template <bool HasMainKBlockLoop,
2167  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2168  TailNumber TailNum = TailNumber::Odd>
2169  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
2170  const index_t* p_sorted_expert_ids,
2171  const index_t* p_max_token_id,
2172  const ADataType* p_a_grid,
2173  const AScaleDataType* p_a_scale_grid,
2174  const BDataType* p_b_grid,
2175  const BScaleDataType* p_b_scale_grid,
2176  DsGridPointer& p_ds_grid,
2177  CDataType* p_c_grid,
2178  void* p_shared_0,
2179  void* p_shared_1,
2180  const Problem& problem,
2181  AElementwiseOperation a_element_op,
2182  BElementwiseOperation b_element_op,
2183  CElementwiseOperation c_element_op)
2184  {
2185  ignore = a_element_op;
2186  ignore = b_element_op;
2187  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2188  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2189  problem.MPadded,
2190  problem.K,
2191  problem.KPadded,
2192  problem.StrideA,
2193  problem.AK0);
2194  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2195  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2196  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2197  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2198  problem.MPadded,
2199  problem.N,
2200  problem.NPadded,
2201  problem.StrideC);
2202 
2203  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
2204  make_tuple(problem.M / (MXdlPack * MPerXdl),
2205  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2206  (KXdlPack * 64 / MPerXdl),
2207  64 * KXdlPack * MXdlPack / scale_pack_size_a));
2208 
2209  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
2210  make_tuple(problem.N / (NXdlPack * NPerXdl),
2211  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2212  (KXdlPack * 64 / NPerXdl),
2213  64 * KXdlPack * NXdlPack / scale_pack_size_b));
2214 
2215  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2216  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2217  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2218 
2219  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2220  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2221  if(expert_block_id * MPerBlock >= max_token_id)
2222  return;
2223  const index_t expert_id =
2224  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2225  const auto block_mn = [&]() -> std::pair<int, int> {
2226  if constexpr(NSwizzle)
2227  {
2228  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2229  const index_t prefix_block = ecnt_prefix * problem.NBlock;
2230  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2231  const index_t expert_swizzle =
2232  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2233  const index_t bid_new = blockIdx.x - prefix_block;
2234  const index_t nid = __builtin_amdgcn_readfirstlane(
2235  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2236  const index_t mid =
2237  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2238  return {nid, mid};
2239  }
2240  else
2241  {
2242  return {blockIdx.x, blockIdx.y};
2243  }
2244  }();
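      // With NSwizzle enabled, the flat blockIdx.x is remapped inside each expert's range
      // of blocks: consecutive ids sweep 8 N-blocks before advancing to the expert's next
      // M-block, using the expert prefix counts stored in p_max_token_id. Without
      // swizzling, the (x, y) grid coordinates are used directly.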
2245 
2246  const index_t block_n_id = block_mn.first;
2247  const index_t block_m_id = block_mn.second;
2248  const index_t token0 =
2249  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2250 
2251  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2252  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2253  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2254  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2255  constexpr auto AKThreads = AK0Threads * AK1Threads;
2256  constexpr auto AMRepeats = MPerBlock / AMThreads;
2257  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
2258 
2259  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2260  return;
2261  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
2262  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2263  const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
2264  index_t token_offset = fused_token & 0xffffff;
2265  if constexpr(!IsInputGemm)
2266  {
2267  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2268  }
2269  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2270  });
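      // gather_offsets holds one A-row offset per row handled by this thread: the fused
      // token id is decoded (low 24 bits = token, high 8 bits = top-k slot) and, for the
      // second (non-input) GEMM, expanded to token * TopK + slot, then multiplied by
      // problem.K to address the row's first element.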
2271 
2272  const index_t expert_stride =
2273  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2274  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2275  problem.N * (IsInputGemm ? 2 : 1) *
2276  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
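      // For the input GEMM each expert stores its gate weights followed by its up
      // weights, so both the per-expert weight stride and scale stride are doubled; the
      // up halves are addressed later at expert_stride / 2 and expert_scale_stride / 2.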
2277 
2278  // N0, K0, Blocksize*KPack
2279  const index_t n_block_data_idx_on_grid =
2280  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
2281 
2282  // Grid buffer creation
2283  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2284  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2285  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2286  p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
2287 
2288  // A, B scale buffer
2289  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2290  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2291  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2292  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2293  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2294 
2295  // lds max alignment
2296  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
2297 
2298  // A matrix in LDS memory, dst of blockwise copy
2299  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2300 
2301  // B matrix in LDS memory, dst of blockwise copy
2302  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2303 
2304  // A matrix blockwise direct to LDS copy
2305  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
2306  ThisThreadBlock,
2307  Sequence<AK0Number, MPerBlock, AK1Number>,
2308  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2309  ABlockTransferThreadClusterArrangeOrder,
2310  ADataType,
2311  ADataType,
2312  decltype(a_grid_desc_ak0_m_ak1),
2313  decltype(a_block_desc_ak0_m_ak1),
2314  ABlockTransferSrcAccessOrder,
2315  ABlockTransferSrcVectorDim,
2316  2,
2317  ABlockTransferSrcScalarPerVector,
2318  IndexType,
2319  1>(a_grid_desc_ak0_m_ak1,
2320  make_multi_index(0, 0, 0),
2321  a_block_desc_ak0_m_ak1,
2322  make_multi_index(0, 0, 0),
2323  gather_offsets);
2324 
2325  // B matrix blockwise copy
2326  auto b_blockwise_copy =
2327  ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
2328  Sequence<BK0Number, NPerBlock, BK1Number>,
2329  BBlockTransferThreadClusterLengths_BK0_N_BK1,
2330  BBlockTransferThreadClusterArrangeOrder,
2331  BDataType,
2332  BDataType,
2333  decltype(b_grid_desc_bk0_n_bk1),
2334  decltype(b_block_desc_bk0_n_bk1),
2335  BBlockTransferSrcAccessOrder,
2336  BBlockTransferSrcVectorDim,
2337  2,
2338  BBlockTransferSrcScalarPerVector>(
2339  b_grid_desc_bk0_n_bk1,
2340  make_multi_index(0, n_block_data_idx_on_grid, 0),
2341  b_block_desc_bk0_n_bk1,
2342  make_multi_index(0, 0, 0));
2343 
2344  // LDS allocation for A and B: be careful of alignment
2345  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
2346  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
2347 
2348  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2349  static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2350 
2351  auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2352  bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
2353  a_block_space_size_aligned * sizeof(ADataType)),
2354  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2355 
2356  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2357  static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2358 
2359  auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2360  bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
2361  a_block_space_size_aligned * sizeof(ADataType)),
2362  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2363 
2364  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2365  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
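      // Ping-pong double buffering: A and B tiles alternate between the two LDS
      // allocations (p_shared_0 / p_shared_1) so the pipeline can load the next K-block
      // into one buffer while the MFMA stage consumes the other.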
2366 
2367  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2368  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
2369 
2370  // Blockwise GEMM pipeline
2371  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2372  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2373  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2374  decltype(c_thread_buf) c_thread_buf_up;
2375 
2376  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
2377  float,
2378  c_thread_buf.num_of_v_,
2379  c_thread_buf.s_per_v,
2380  true>
2381  c_thread_buf_fp32;
2382 
2383  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2384  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2385  KPerBlock);
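      // Main-loop trip count: the K extent of the A descriptor, recovered as AK0 * AK1,
      // divided by KPerBlock.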
2386 
2387  // a and b scale processing
2388  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2389  const auto waveId_m = wave_idx[I0];
2390  const auto waveId_n = wave_idx[I1];
2391 
2392  auto thread_offset_shuffled =
2393  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2394 
2395  auto a_thread_offset_m = waveId_m;
2396 
2397  // get each thread's offset into the scale tensor
2398  const index_t token_scale_pos = block_m_id * MPerBlock;
2399  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2400  return;
2401 
2402  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2403  AScaleDataType,
2404  AScaleDataType,
2405  decltype(a_scale_grid_desc_am_ak),
2406  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2407  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2408  Sequence<0, 1, 2>, // DimAccessOrder
2409  2, // SrcVectorDim
2410  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2411  1, // SrcScalarStrideInVector
2412  true>(a_scale_grid_desc_am_ak,
2413  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2414  0,
2415  thread_offset_shuffled / scale_pack_size_a));
2416 
2417  // B scale load
2418  auto b_thread_offset_n = waveId_n;
2419 
2420  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2421  BScaleDataType,
2422  BScaleDataType,
2423  decltype(b_scale_grid_desc_bn_ak),
2424  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2425  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2426  Sequence<0, 1, 2>, // DimAccessOrder
2427  2, // SrcVectorDim
2428  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2429  1, // SrcScalarStrideInVector
2430  true>(b_scale_grid_desc_bn_ak,
2431  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2432  0,
2433  thread_offset_shuffled / scale_pack_size_b));
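      // Per-thread scale readers: each step fetches one packed group of
      // KXdlPack * MXdlPack (for A) or KXdlPack * NXdlPack (for B) block scales,
      // addressed by the wave's M/N position plus the shuffled lane offset computed above.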
2434 
2435  if constexpr(IsInputGemm)
2436  {
2437  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
2438  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2439  p_b_grid_up + expert_id * expert_stride,
2440  b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
2441 
2442  // lds ping pong buffers for up
2443  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
2444  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
2445  auto b_block_buf_up_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2446  bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
2447  a_block_space_size_aligned * sizeof(ADataType) +
2448  b_block_space_size_aligned * sizeof(BDataType)),
2449  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2450  auto b_block_buf_up_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2451  bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
2452  a_block_space_size_aligned * sizeof(ADataType) +
2453  b_block_space_size_aligned * sizeof(BDataType)),
2454  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2455 
2456  auto b_block_bufs_up = make_tuple(b_block_buf_up_ping, b_block_buf_up_pong);
2457 
2458  auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
2459  ThisThreadBlock,
2460  Sequence<BK0Number, NPerBlock, BK1Number>,
2461  BBlockTransferThreadClusterLengths_BK0_N_BK1,
2462  BBlockTransferThreadClusterArrangeOrder,
2463  BDataType,
2464  BDataType,
2465  decltype(b_grid_desc_bk0_n_bk1),
2466  decltype(b_block_desc_bk0_n_bk1),
2467  BBlockTransferSrcAccessOrder,
2468  BBlockTransferSrcVectorDim,
2469  2,
2470  BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
2471  make_multi_index(0, n_block_data_idx_on_grid, 0),
2472  b_block_desc_bk0_n_bk1,
2473  make_multi_index(0, 0, 0));
2474 
2475  const BScaleDataType* p_b_scale_grid_up =
2476  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
2477  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2478  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
2479  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2480 
2481  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2482  BScaleDataType,
2483  BScaleDataType,
2484  decltype(b_scale_grid_desc_bn_ak),
2485  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2486  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2487  Sequence<0, 1, 2>, // DimAccessOrder
2488  2, // SrcVectorDim
2489  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2490  1, // SrcScalarStrideInVector
2491  true>(
2492  b_scale_grid_desc_bn_ak,
2493  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2494  0,
2495  thread_offset_shuffled / scale_pack_size_b));
2496 
2497  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2498  // A
2499  a_grid_desc_ak0_m_ak1,
2500  a_block_desc_ak0_m_ak1,
2501  a_blockwise_copy,
2502  a_grid_buf,
2503  a_block_bufs,
2504  a_block_slice_copy_step,
2505  // Gate and Up
2506  b_grid_desc_bk0_n_bk1,
2507  b_block_desc_bk0_n_bk1,
2508  b_blockwise_copy,
2509  b_blockwise_copy_up,
2510  b_grid_buf,
2511  b_grid_buf_up,
2512  b_block_bufs,
2513  b_block_bufs_up,
2514  b_block_slice_copy_step,
2515  // C
2516  c_thread_buf,
2517  c_thread_buf_up,
2518  // A scale
2519  a_scale_grid_desc_am_ak,
2520  a_scale_thread_copy,
2521  a_scale_grid_buf,
2522  // B scale
2523  b_scale_grid_desc_bn_ak,
2524  b_scale_thread_copy,
2525  b_scale_thread_copy_up,
2526  b_scale_grid_buf,
2527  b_scale_grid_buf_up,
2528  num_k_block_main_loop);
2529  }
2530  else
2531  {
2532  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2533  a_grid_desc_ak0_m_ak1, // A
2534  a_block_desc_ak0_m_ak1,
2535  a_blockwise_copy,
2536  a_grid_buf,
2537  a_block_bufs,
2538  a_block_slice_copy_step,
2539  b_grid_desc_bk0_n_bk1, // B
2540  b_block_desc_bk0_n_bk1,
2541  b_blockwise_copy,
2542  b_grid_buf,
2543  b_block_bufs,
2544  b_block_slice_copy_step,
2545  c_thread_buf, // C
2546  a_scale_grid_desc_am_ak, // A scale
2547  a_scale_thread_copy,
2548  a_scale_grid_buf,
2549  b_scale_grid_desc_bn_ak, // B scale
2550  b_scale_thread_copy,
2551  b_scale_grid_buf,
2552  num_k_block_main_loop);
2553  }
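      // For the input GEMM the pipeline runs with two B streams (gate and up) and two
      // accumulators; otherwise a single B stream and accumulator are used. Both paths
      // feed the same C-shuffle and scatter epilogue below.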
2554 
2555  // shuffle C and write out
2556  {
2557  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2558  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2559  "wrong!");
2560  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
2561  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
2562  "wrong!");
2563 
2564  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2565  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2566 
2567  // TODO: hacky, fix it!
2568  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2569  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2570 
2571  // TODO: hacky, fix it!
2572  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2573  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2574  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2575 
2576  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2577  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2578  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2579  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2580  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2581  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2582  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2583  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2584  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
2585  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
2586 
2587  // mul scales
2588 
2589  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
2590  static_assert(M5 == 4);
2591  const index_t m1 = get_warp_local_1d_id() / NWave;
2592  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
2593 
2594  vector_type<float, 4> topk_weights; // for gemm2 only
2595  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
2596  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
2597  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
2598  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
2599  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
2600  const index_t m_pos = block_m_id * MPerBlock +
2601  m0 * M2 * M1 * M3 * M4 * M5 +
2602  m1 * M2 * M3 * M4 * M5 +
2603  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
2604  if constexpr(MulRoutedWeight)
2605  {
2606  topk_weights =
2607  *c_style_pointer_cast<const vector_type<float, M5>*>(
2608  p_ds_grid[I2] + m_pos);
2609  }
2610  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
2611  constexpr index_t c_offset =
2612  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2613  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
2614  constexpr auto cidx = Number<c_offset>{};
2615 
2616  if constexpr(IsInputGemm) // gate-and-up (gu) fusion
2617  {
2618  if constexpr(ActivationOperation ==
2619  Activation::silu_and_mul)
2620  {
2621  float gate = c_thread_buf[cidx];
2622  float up = c_thread_buf_up[cidx];
2623  if constexpr(MulRoutedWeight)
2624  {
2625  gate = gate * topk_weights.AsType<float>()[m5];
2626  up = up * topk_weights.AsType<float>()[m5];
2627  }
2628  tensor_operation::element_wise::Silu{}(gate, gate);
2629  c_thread_buf_fp32(cidx) = gate * up;
2630  }
2631  else if(ActivationOperation == Activation::gelu_and_mul)
2632  {
2633  float gate = c_thread_buf[cidx];
2634  float up = c_thread_buf_up[cidx];
2635  if constexpr(MulRoutedWeight)
2636  {
2637  gate = gate * topk_weights.AsType<float>()[m5];
2638  up = up * topk_weights.AsType<float>()[m5];
2639  }
2640  tensor_operation::element_wise::Gelu{}(gate, gate);
2641  c_thread_buf_fp32(cidx) = gate * up;
2642  }
2643  }
2644  else
2645  {
2646  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2647  if constexpr(MulRoutedWeight)
2648  {
2649  c_thread_buf_fp32(cidx) =
2650  topk_weights.AsType<float>()[m5] *
2651  c_thread_buf_fp32[cidx];
2652  }
2653  }
2654  });
2655  });
2656  });
2657  });
2658  });
2659  });
2660 
2661  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2662  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2663 
2664  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2665  static_cast<CShuffleDataType*>(p_shared_0),
2666  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2667 
2668  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2669  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2670  make_tuple(
2671  make_freeze_transform(I0),
2672  make_unmerge_transform(make_tuple(
2673  Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
2674  // shuffle
2675  M1, // M1 = MWave
2676  M2, // M2 * M3 * M4 = MPerXdl
2677  M3,
2678  M4,
2679  M5)),
2680  make_freeze_transform(I0),
2681  make_unmerge_transform(make_tuple(
2682  Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave)
2683  // per shuffle
2684  N1, // N1 = NWave
2685  N2, // N2 = NXdlPack
2686  N3))), // N3 = NPerXdl
2687  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2688  make_tuple(Sequence<>{},
2689  Sequence<0, 2, 4, 6, 7, 8>{},
2690  Sequence<>{},
2691  Sequence<1, 3, 5, 9>{}));
2692 
2693  // calculate origin of thread output tensor on global memory
2694  // blockwise GEMM c matrix starting index
2695  const auto c_thread_mtx_on_block =
2696  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2697 
2698  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2699  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2700 
2701  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2702  make_single_stage_tensor_adaptor(
2703  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
2704  make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
2705  make_tuple(Sequence<0>{}));
2706 
2707  const auto m_thread_data_on_block_idx =
2708  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2709  make_multi_index(m_thread_data_on_block));
2710 
2711  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2712  make_single_stage_tensor_adaptor(
2713  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
2714  make_tuple(Sequence<0, 1, 2, 3>{}),
2715  make_tuple(Sequence<0>{}));
2716 
2717  const auto n_thread_data_on_block_idx =
2718  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2719  make_multi_index(n_thread_data_on_block));
2720 
2721  // shuffle: threadwise copy C from VGPR to LDS
2722  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
2723  AccDataType,
2724  CShuffleDataType,
2725  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2726  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2727  tensor_operation::element_wise::PassThrough,
2728  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2729  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2730  I1,
2731  I1,
2732  M2,
2733  N2,
2734  M3,
2735  I1,
2736  M5,
2737  I1>,
2738  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2739  9,
2740  1,
2741  InMemoryDataOperationEnum::Set,
2742  1,
2743  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2744  make_multi_index(0,
2745  0,
2746  m_thread_data_on_block_idx[I1],
2747  n_thread_data_on_block_idx[I1],
2748  m_thread_data_on_block_idx[I2],
2749  n_thread_data_on_block_idx[I2],
2750  m_thread_data_on_block_idx[I3],
2751  m_thread_data_on_block_idx[I4],
2752  m_thread_data_on_block_idx[I5],
2753  n_thread_data_on_block_idx[I3]),
2754  tensor_operation::element_wise::PassThrough{}};
2755 
2756  using EDataType = CDataType;
2757 
2758  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2759  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2760 
2761  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2762  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2763  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2764 
2765  const auto ds_grid_buf = generate_tuple(
2766  [&](auto i) {
2767  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2768  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2769  },
2770  Number<NumDTensor>{});
2771 
2772  // tuple of reference to C/Ds tensor descriptors
2773  const auto c_ds_desc_refs = concat_tuple_of_reference(
2774  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2775  generate_tie([&](auto i) -> const auto& // return type should be reference
2776  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2777  Number<NumDTensor>{}));
2778 
2779  // tuple of reference to C/Ds buffers
2780  const auto c_ds_buf_refs = concat_tuple_of_reference(
2781  tie(c_shuffle_block_buf),
2782  generate_tie([&](auto i) -> const auto& // return type should be reference
2783  { return ds_grid_buf[i]; },
2784  Number<NumDTensor>{}));
2785 
2786  // tuple of starting index of C/Ds blockwise copy
2787  const auto idx_c_ds_block_begin =
2788  container_concat(make_tuple(make_multi_index(0, 0, 0, 0)),
2789  generate_tuple(
2790  [&](auto) {
2791  return make_multi_index(block_m_id, 0, block_n_id, 0);
2792  // return make_multi_index(block_work_idx[I0], 0,
2793  // block_work_idx[I1], 0);
2794  },
2795  Number<NumDTensor>{}));
2796 
2797  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2798  c_grid_desc_mblock_mperblock_nblock_nperblock;
2799 
2800  using CDEBlockTransferCluster =
2801  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2802  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2803  constexpr index_t scatter_weight_idx = 3; // hack fix felix
2804  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2805  ThisThreadBlock,
2806  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2807  Tuple<EDataType>,
2808  decltype(c_ds_desc_refs),
2809  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2810  CElementwiseOperation,
2811  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2812  // Sequence support
2813  // arbitrary type
2814  Sequence<1,
2815  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2816  1,
2817  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2818  CDEBlockTransferCluster,
2819  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2820  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2821  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2822  3, // index_t SrcVectorDim,
2823  3, // index_t DstVectorDim,
2824  CDEShuffleBlockTransferScalarPerVectors,
2825  sequence_merge_t<
2827  Sequence<true>,
2828  uniform_sequence_gen_t<NumDTensor,
2829  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2830  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2831  IndexType,
2832  1, // ScatterDim
2833  true, // OutputScatter: false, only use scatter weights
2834  scatter_weight_idx // ScatterWeightIdx: ascale
2835  >{c_ds_desc_refs,
2836  idx_c_ds_block_begin,
2837  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2838  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2839  c_element_op};
2840 
2841  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2842  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2843 
2844  constexpr auto sfc_c_vgpr =
2845  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2846  NXdlPerWave / NXdlPack,
2847  1,
2848  1,
2849  MXdlPack,
2850  NXdlPack,
2851  M2,
2852  1,
2853  M4,
2854  1>,
2855  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2856  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2857  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2858  1,
2859  1,
2860  MXdlPack,
2861  NXdlPack,
2862  M2,
2863  1,
2864  M4,
2865  1>>{};
2866 
2867  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2868 
2869  // space filling curve for shuffled blockwise C/D/E
2870  constexpr auto sfc_cde_block =
2871  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2872  Sequence<0, 2, 1, 3>,
2873  Sequence<1,
2874  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2875  1,
2876  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2877 
2878  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2879  constexpr auto EMThreads =
2880  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2881  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2882  constexpr auto ENThreads =
2883  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2884  static_for<0, num_access, 1>{}([&](auto access_id) {
2885  // make sure it's safe to write to LDS
2886  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2887 
2888  auto dstidx = sfc_cde_block.GetIndex(access_id);
2889  const index_t c_token_pos =
2890  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2891  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2892  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2893  IndexType token_offset = fused_token & 0xffffff;
2894  if constexpr(IsInputGemm)
2895  {
2896  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2897  }
2898  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2899  });
2900 
2901  block_sync_lds();
2902 
2903  // each thread write its data from VGPR to LDS
2904  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2905  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2906  c_thread_buf_fp32,
2907  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2908  c_shuffle_block_buf);
2909 
2910  // make sure it's safe to read from LDS
2911  block_sync_lds();
2912 
2913  // each block copy its data from LDS to global
2914  cde_block_copy_lds_and_global.Run(
2915  c_ds_desc_refs,
2916  c_ds_buf_refs,
2917  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2918  tie(c_grid_buf),
2919  scatter_offsets);
2920 
2921  if constexpr(access_id < num_access - 1)
2922  {
2923  constexpr auto cde_lds_and_global_step =
2924  sfc_cde_block.GetForwardStep(access_id);
2925 
2926  // move on Ds
2927  static_for<0, NumDTensor, 1>{}([&](auto i) {
2928  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2929  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2930  });
2931 
2932  // move on E
2933  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2934  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2935  I0,
2936  cde_lds_and_global_step);
2937  }
2938  });
2939  }
2940  }
2941 };
2942 
2943 } // namespace ck