/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp Source File
blockwise_gemm_pipeline_wmmaops_base.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
12 
13 namespace ck {
14 
15 template <index_t BlockSize,
16  typename ADataType,
17  typename BDataType,
18  typename ComputeTypeA,
19  typename ComputeTypeB,
20  typename AccDataType,
21  typename AWmmaTileDesc,
22  typename BWmmaTileDesc,
23  index_t ABlockTransferSrcScalarPerVector,
24  index_t BBlockTransferSrcScalarPerVector,
25  index_t MPerBlock,
26  index_t NPerBlock,
27  index_t KPerBlock,
28  index_t MPerWmma,
29  index_t NPerWmma,
30  index_t MRepeat,
31  index_t NRepeat,
32  index_t KPack,
33  bool TransposeC = false>
35 {
36  static constexpr auto I0 = Number<0>{};
37  static constexpr auto I1 = Number<1>{};
38  static constexpr auto I2 = Number<2>{};
39  static constexpr auto I3 = Number<3>{};
40  static constexpr auto I5 = Number<5>{};
41 
43 
44  static constexpr index_t WaveSize = 32;
45 
46  static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
47  static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
48 
49 #if defined(__gfx12__)
50  static constexpr index_t A_KRow = 2;
51  static constexpr index_t B_KRow = 2;
52 #else
53  static constexpr index_t A_KRow = 1;
54  static constexpr index_t B_KRow = 1;
55 #endif
56 
57  static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5);
58  static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5);
59 
60  static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!");
61  static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!");
62 
63  static constexpr auto wmma_gemm =
65 
66  static constexpr index_t KRepeat = KPerBlock / KPack;
67 
68  static constexpr auto WmmaK = Number<wmma_gemm.wmma_instr.k_per_wmma>{};
69 
72  MPerBlock,
73  NPerBlock,
74  KPerBlock,
75  ABlockTransferSrcScalarPerVector,
76  BBlockTransferSrcScalarPerVector,
77  A_K1,
78  B_K1,
79  A_K1,
80  B_K1,
81  MRepeat,
82  NRepeat,
83  MPerWmma,
84  NPerWmma,
85  wmma_gemm.wmma_instr.k_per_wmma>;
86 
88  AccDataType,
89  MRepeat * NRepeat,
90  wmma_gemm.GetRegSizePerWmma(),
91  true>
93 
94  struct Empty
95  {
96  __device__ Empty() {};
97  template <index_t NBuffer>
98  __device__ void GlobalLoad(bool cond)
99  {
100  ignore = NBuffer;
101  ignore = cond;
102  }
103  };
104 
105  template <index_t ScaleSliceSizeN,
106  index_t ScaleSliceSizeK,
107  index_t NWaves,
108  index_t ScaleBlockK,
109  index_t NumberOfBuffers,
110  typename GridDesc,
111  typename ThreadCopy,
112  typename GridBuffer,
113  typename ThreadStaticBuffer,
114  typename BScaleThreadDesc>
115  struct BScale
116  {
117  __device__ BScale(GridDesc b_scale_grid_desc_,
118  ThreadCopy b_scale_thread_copy_,
119  GridBuffer b_scale_grid_buf_)
120  : b_scale_thread_copy(b_scale_thread_copy_),
121  b_scale_grid_desc(b_scale_grid_desc_),
122  b_scale_grid_buf(b_scale_grid_buf_) {};
123 
124  static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{});
126 
127  static constexpr auto b_scale_thread_desc = BScaleThreadDesc{};
128 
129  static constexpr auto b_scale_thread_copy_step =
130  make_tuple(make_multi_index(NWaves * NPerWmma, 0),
131  make_multi_index(-NPerBlock, 0),
132  make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK));
133 
134  template <index_t NBuffer>
135  __device__ void GlobalLoad(bool cond)
136  {
137  static_for<0, NRepeat, 1>{}([&](auto n0) {
141  make_tuple(n0, Number<0>{}),
143 
144  b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
146  });
147 
148  if(cond)
149  {
150  b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
152  }
153  else
154  {
155  b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
157  }
158  }
159 
162  GridBuffer b_scale_grid_buf;
164  };
165 
166  __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
167 
168  __device__ static auto GetWaveIdx()
169  {
170  const index_t thread_id = ThisThreadBlock::GetThreadId();
171 
172  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
173  make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
176 
177  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
178  }
179 
180  __device__ static auto CalculateAThreadOriginDataIndex()
181  {
182  const auto wave_idx = GetWaveIdx();
183 
184  const auto waveId_m = wave_idx[I0];
185 
186  const auto wmma_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
187 
188 #if defined(__gfx12__)
189  const auto wmma_krow = wmma_gemm.GetSubGroupId();
190 #else
191  const auto wmma_krow = 0;
192 #endif
193 
194  // |KRepeat |MRepeat|MWave |KRow |MLane |KPack
195  return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
196  }
197 
198  __device__ static auto CalculateBThreadOriginDataIndex()
199  {
200  const auto wave_idx = GetWaveIdx();
201 
202  const auto waveId_n = wave_idx[I1];
203 
204  const auto wmma_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
205 
206 #if defined(__gfx12__)
207  const auto wmma_krow = wmma_gemm.GetSubGroupId();
208 #else
209  const auto wmma_krow = 0;
210 #endif
211 
212  // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
213  return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
214  }
215 
216  template <index_t m0, index_t n0>
218  {
219  const auto wave_idx = GetWaveIdx();
220 
221  const auto waveId_m = wave_idx[I0];
222  const auto waveId_n = wave_idx[I1];
223 
224  const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
225 
226  constexpr auto mrepeat_mwave_mperwmma_to_m_adaptor = make_single_stage_tensor_adaptor(
227  make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWmma))),
230 
231  constexpr auto nrepeat_nwave_nperwmma_to_n_adaptor = make_single_stage_tensor_adaptor(
232  make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWmma))),
235 
236  const index_t c_thread_m = mrepeat_mwave_mperwmma_to_m_adaptor.CalculateBottomIndex(
237  make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
238  const index_t c_thread_n = nrepeat_nwave_nperwmma_to_n_adaptor.CalculateBottomIndex(
239  make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
240 
241  return make_tuple(c_thread_m, c_thread_n);
242  }
243 
244  using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
245 
263  __host__ __device__
264  BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
265  Tuple6 b_origin = CalculateBThreadOriginDataIndex())
266  : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
267  {
268  static_assert(AWmmaTileDesc::IsKnownAtCompileTime() &&
269  BWmmaTileDesc::IsKnownAtCompileTime(),
270  "wrong! Desc should be known at compile-time");
271 
272  static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
273  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize");
274 
275  static_assert(MPerBlock % (MPerWmma * MRepeat) == 0 &&
276  NPerBlock % (NPerWmma * NRepeat) == 0,
277  "wrong!");
278  }
279 
280  __host__ __device__ static constexpr auto
282  {
283  constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
284  wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
285 
286  constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
287  constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
289  // |MRepeat |MWave |MSubGroup |NRepeat |NWave
290  // |NThreadPerSubGroup |MAccVgprs
291  make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
292  make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
293  Number<NRepeat>{} * MAccVgprs * AccStride,
294  Number<NRepeat>{} * MAccVgprs * AccStride,
295  MAccVgprs * AccStride,
296  MAccVgprs * AccStride,
297  MAccVgprs * AccStride,
298  AccStride));
299  }
300 
301  __host__ __device__ static constexpr auto
303  {
304  constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
306  Number<MWaves>{},
308  Number<NRepeat>{},
309  Number<NWaves>{},
310  Number<NPerWmma>{}));
311 
312  return wmma_gemm
313  .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
314  c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
315  }
316 
317  // Describes how data is laid out in the thread-copy source buffer: these are
     // (presumably) the source descriptors of AThreadCopy / BThreadCopy, whose
     // destinations are a_thread_desc_ / b_thread_desc_.
318  // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
     // NOTE(review): despite the k0_m0_m1_m2_k1 / k0_n0_n1_n2_k1 names, these
     // descriptors have 6 dimensions (A_K1 / B_K1 are read from dimension 5 and the
     // thread-copy slice lengths have 6 entries); confirm the exact dimension order.
319  static constexpr AWmmaTileDesc a_block_desc_k0_m0_m1_m2_k1;
320  static constexpr BWmmaTileDesc b_block_desc_k0_n0_n1_n2_k1;
321 
322  protected:
323  static constexpr auto a_thread_desc_ =
325  Number<MRepeat>{},
326  Number<KRepeat>{},
327  I1,
328  I1,
329  Number<A_K1>{}),
330  make_tuple(Number<A_K1>{},
331  Number<KPack / A_KRow>{},
332  Number<KPack / A_KRow * MRepeat>{},
333  I0,
334  I0,
335  I1));
336 
337  static constexpr auto b_thread_desc_ =
339  Number<NRepeat>{},
340  Number<KRepeat>{},
341  I1,
342  I1,
343  Number<B_K1>{}),
344  make_tuple(Number<B_K1>{},
345  Number<KPack / B_KRow>{},
346  Number<KPack / B_KRow * NRepeat>{},
347  I0,
348  I0,
349  I1));
350 
351  // C[M, N, NumRegWmma]
352  static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
353  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
354 
355  using AThreadCopy =
357  ComputeTypeA,
358  decltype(a_block_desc_k0_m0_m1_m2_k1),
359  decltype(a_thread_desc_),
360  Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
362  5,
363  A_K1,
364  A_K1>;
365 
366  using BThreadCopy =
368  ComputeTypeB,
369  decltype(b_block_desc_k0_n0_n1_n2_k1),
370  decltype(b_thread_desc_),
371  Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
373  5,
374  B_K1,
375  B_K1>;
376 
379 };
380 
381 } // namespace ck
Definition: ck.hpp:267
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
integral_constant< index_t, N > Number
Definition: number.hpp:12
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:116
static constexpr index_t num_scale_krepeat
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:125
GridBuffer b_scale_grid_buf
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:162
__device__ BScale(GridDesc b_scale_grid_desc_, ThreadCopy b_scale_thread_copy_, GridBuffer b_scale_grid_buf_)
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:117
StaticallyIndexedArray< ThreadStaticBuffer, Number< NumberOfBuffers >{}> b_scale_thread_bufs
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:163
static constexpr auto b_scale_thread_copy_step
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:129
static constexpr index_t num_scale_k_block
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:124
static constexpr auto b_scale_thread_desc
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:127
__device__ void GlobalLoad(bool cond)
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:135
ThreadCopy b_scale_thread_copy
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:160
GridDesc b_scale_grid_desc
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:161
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:95
__device__ Empty()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:96
__device__ void GlobalLoad(bool cond)
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:98
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:35
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:377
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:166
static constexpr BWmmaTileDesc b_block_desc_k0_n0_n1_n2_k1
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:320
static constexpr auto I1
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:37
static constexpr index_t A_K1
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:57
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >)
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:217
static __device__ auto CalculateBThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:198
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, AccDataType, MRepeat *NRepeat, wmma_gemm.GetRegSizePerWmma(), true > c_thread_buf_
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:92
__host__ __device__ BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin=CalculateAThreadOriginDataIndex(), Tuple6 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmWmmaops_pipeline_base.
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:264
static constexpr AWmmaTileDesc a_block_desc_k0_m0_m1_m2_k1
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:319
static constexpr index_t MWaves
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:46
static constexpr auto wmma_gemm
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:63
static constexpr index_t B_KRow
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:54
static __device__ auto GetWaveIdx()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:168
static constexpr auto I3
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:39
static constexpr auto I0
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:36
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:42
static constexpr index_t B_K1
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:58
__host__ static constexpr __device__ auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:281
static constexpr index_t WaveSize
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:44
__host__ static constexpr __device__ auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:302
static constexpr auto WmmaK
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:68
static constexpr auto I5
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:40
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:378
decltype(CalculateAThreadOriginDataIndex()) Tuple6
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:244
static constexpr index_t KRepeat
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:66
static constexpr index_t NWaves
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:47
static __device__ auto CalculateAThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:180
static constexpr index_t A_KRow
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:53
static constexpr auto I2
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:38
Definition: blockwise_gemm_pipeline_wmmaops.hpp:26
Definition: sequence.hpp:43
Definition: static_buffer.hpp:75
Definition: wmma_gemm.hpp:663
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33