Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp"

namespace ck {

// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1

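// A sketch of the overlap this configuration produces in the hot loop (see Run() below):
// while the WMMA instructions consume the A/B fragments already prefetched into VGPRs,
// the tile fetched from global memory in the previous iteration is written to LDS and the
// next tile's global load is issued, so a single LDS buffer per operand is sufficient.
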
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmWmmaops_pipeline_v3
{
};

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                                        BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeTypeA,
                                        ComputeTypeB,
                                        AccDataType,
                                        AWmmaTileDesc,
                                        BWmmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerWmma,
                                        NPerWmma,
                                        MRepeat,
                                        NRepeat,
                                        KPack>
    : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                         ADataType,
                                         BDataType,
                                         ComputeTypeA,
                                         ComputeTypeB,
                                         AccDataType,
                                         AWmmaTileDesc,
                                         BWmmaTileDesc,
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         MPerBlock,
                                         NPerBlock,
                                         KPerBlock,
                                         MPerWmma,
                                         NPerWmma,
                                         MRepeat,
                                         NRepeat,
                                         KPack>
{
    using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                                    ADataType,
                                                    BDataType,
                                                    ComputeTypeA,
                                                    ComputeTypeB,
                                                    AccDataType,
                                                    AWmmaTileDesc,
                                                    BWmmaTileDesc,
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    MPerBlock,
                                                    NPerBlock,
                                                    KPerBlock,
                                                    MPerWmma,
                                                    NPerWmma,
                                                    MRepeat,
                                                    NRepeat,
                                                    KPack>;
    using Base::I0;

    using Base::A_K1;
    using Base::A_KRow;
    using Base::B_K1;
    using Base::B_KRow;
    using Base::KRepeat;
    using Base::WmmaK;

    using Base::wmma_gemm;
    using typename Base::HotLoopInstList;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;

    using typename Base::Empty;

    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }
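
    // Hypothetical caller sketch (not part of this header, names are illustrative): the
    // enclosing gridwise GEMM kernel typically resolves both helpers at compile time and
    // instantiates Run() accordingly, e.g.
    //
    //   constexpr bool has_hot_loop = BlockwisePipe::BlockHasHotloop(NumKLoop);
    //   constexpr auto tail_number  = BlockwisePipe::BlockLoopTailNum(NumKLoop);
    //   pipe.template Run<has_hot_loop, tail_number>(..., num_loop, num_loop_per_scale);
    //
    // For this pipeline the tail is always reported as TailNumber::Full, so only the
    // HasMainLoop branch varies.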

    __device__ static constexpr auto HotLoopScheduler()
    {
        // TODO: Calculation of the number of instructions may require changes for WMMA
        /*
        // A/B split schedule
        // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;

        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;

        constexpr auto num_wmma_inst = HotLoopInstList::C_WMMA_Inst_Num;

        constexpr auto wmma_cycle = NPerWmma == 16 ? 16 : 32;
        constexpr auto ds_read_a_issue_cycle =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
        constexpr auto ds_read_b_issue_cycle =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
        constexpr auto ds_read_a_wmma_rate =
            (wmma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
        constexpr auto ds_read_b_wmma_rate =
            (wmma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);

        constexpr auto num_dsread_a_wmma =
            (num_ds_read_inst_a + ds_read_a_wmma_rate - 1) / ds_read_a_wmma_rate;
        constexpr auto num_dsread_b_wmma =
            (num_ds_read_inst_b + ds_read_b_wmma_rate - 1) / ds_read_b_wmma_rate;

        // stage 1
        // Separate this part?
        // constexpr auto num_wmma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
        //                                               sizeof(ComputeDataType) / sizeof(BDataType)
        //                                           ? sizeof(ComputeDataType) / sizeof(ADataType)
        //                                           : sizeof(ComputeDataType) / sizeof(BDataType);
        constexpr auto num_wmma_stage1 = num_wmma_inst - (num_dsread_a_wmma + num_dsread_b_wmma);
        constexpr auto num_wmma_per_issue =
            num_wmma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
        constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
        constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;

        static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_wmma_per_issue - num_dswrite_per_issue_a, 0); // WMMA
        });
        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_wmma_per_issue - num_dswrite_per_issue_b, 0); // WMMA
        });

        // stage 2
        static_for<0, num_dsread_a_wmma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_wmma_rate) >=
                         ds_read_a_wmma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_wmma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_a - (num_dsread_a_wmma - 1) *
                                                                              ds_read_a_wmma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
        });

        static_for<0, num_dsread_b_wmma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_wmma_rate) >=
                         ds_read_b_wmma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_wmma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_b - (num_dsread_b_wmma - 1) *
                                                                              ds_read_b_wmma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
        });
        */
    }
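
    // Notes on the (currently disabled) HotLoopScheduler() schedule above:
    // __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) asks the scheduler to group
    // `size` instructions matching `mask` at that point; the masks used here are 0x008
    // (MFMA/WMMA), 0x020 (VMEM read), 0x100 (DS read) and 0x200 (DS write). As a worked
    // example of the rate formula, with wmma_cycle = 16 and a 16-byte ds_read (issue cycle 8):
    // ds_read_a_wmma_rate = ceil((16 - 4) / (2 * 8)) = 1, i.e. stage 2 pairs one ds_read with
    // each WMMA.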

    template <typename ABlockBuffer,
              typename AThreadBuffer,
              typename BBlockBuffer,
              typename BThreadBuffer,
              typename BScaleStruct>
    __device__ inline void LocalLoad(ABlockBuffer& a_block_buf,
                                     AThreadBuffer& a_thread_buf,
                                     BBlockBuffer& b_block_buf,
                                     BThreadBuffer& b_thread_buf,
                                     BScaleStruct& b_scale_struct) const
    {
        static_for<0, KRepeat, 1>{}([&](auto k0) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                a_thread_copy_.Run(
                    a_block_desc_k0_m0_m1_m2_k1,
                    make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                    a_block_buf,
                    a_thread_desc_,
                    make_tuple(I0, m0, k0, I0, I0, I0),
                    a_thread_buf);
            });

            if constexpr(ck::is_same_v<BScaleStruct, Empty>)
            {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(
                        b_block_desc_k0_n0_n1_n2_k1,
                        make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                        b_block_buf,
                        b_thread_desc_,
                        make_tuple(I0, n0, k0, I0, I0, I0),
                        b_thread_buf);
                });
            }
            else
            {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(
                        b_block_desc_k0_n0_n1_n2_k1,
                        make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                        b_block_buf,
                        b_scale_struct.b_scale_thread_bufs(
                            I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                                       k0 / BScaleStruct::num_scale_krepeat>{}],
                        b_thread_desc_,
                        make_tuple(I0, n0, k0, I0, I0, I0),
                        b_thread_buf);
                });
            }
        });
    }

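    // Pipelined GEMM loop over the K dimension.
    // Prologue: global prefetch of K-tiles 1 and 2, LDS prefill with tile 1, C initialization
    // and a first LDS->VGPR prefetch. Main body (HasMainLoop): each iteration writes the
    // previously fetched global tile to LDS, issues the next global prefetch, runs the WMMA
    // math on the prefetched VGPR fragments and then prefetches the next fragments from LDS,
    // with instruction ordering shaped by HotLoopScheduler(). Tail (TailNumber::Full): one
    // final round of WMMA on the last prefetched fragments.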
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleStruct>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        // BScaleThreadCopy
                        BScaleStruct& b_scale_struct,
                        index_t num_loop,
                        index_t num_loop_per_scale) const
    {
        __builtin_amdgcn_sched_barrier(0);
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Global prefetch 2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        // Local prefetch 1
        block_sync_lds();

        LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

        __builtin_amdgcn_sched_barrier(0);

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                            vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                            static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(Number<ik / A_K1>{},
                                                   m0,
                                                   k0,
                                                   I0,
                                                   I0,
                                                   Number<ik % A_K1>{}))>{}];
                            });
                            static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(Number<ik / B_K1>{},
                                                   n0,
                                                   k0,
                                                   I0,
                                                   I0,
                                                   Number<ik % B_K1>{}))>{}];
                            });

                            using wmma_input_type_a =
                                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                            using wmma_input_type_b =
                                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                            wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                          b_thread_vec.template AsType<wmma_input_type_b>(),
                                          c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });

                block_sync_lds();

                LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 1));
        }
        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                        vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                        static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
                                    Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
                        });
                        static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
                                    Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
                        });

                        using wmma_input_type_a =
                            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                        using wmma_input_type_b =
                            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                      b_thread_vec.template AsType<wmma_input_type_b>(),
                                      c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
            // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
            // latency
            // __builtin_amdgcn_sched_barrier(0);
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

} // namespace ck