Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"

namespace ck {

// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1

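// Added illustrative comment: with two global prefetch stages and a single
// LDS buffer, a steady-state iteration i of the hot loop overlaps roughly as
//
//   buffer_load  tile i+2   global memory -> copy registers
//   ds_write     tile i+1   copy registers -> LDS
//   ds_read      tile i+1   LDS -> a/b_thread_buf (after the second barrier)
//   v_mfma       tile i     a/b_thread_buf -> c_thread_buf
//
// HotLoopScheduler() below interleaves exactly these instruction classes.
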
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_v3
{
};

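// Added note: the primary template is intentionally empty; concrete pipelines
// are supplied as partial specializations per BlockGemmPipelineScheduler
// value. This file provides the Intrawave schedule.
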
template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                                       BlockSize,
                                       ADataType,
                                       BDataType,
                                       ComputeDataType,
                                       AccDataType,
                                       ATileDesc,
                                       BTileDesc,
                                       AMmaTileDesc,
                                       BMmaTileDesc,
                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,
                                       MPerBlock,
                                       NPerBlock,
                                       KPerBlock,
                                       MPerXDL,
                                       NPerXDL,
                                       MRepeat,
                                       NRepeat,
                                       KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeDataType,
                                        AccDataType,
                                        ATileDesc,
                                        BTileDesc,
                                        AMmaTileDesc,
                                        BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerXDL,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>

{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
    using Base::I0;
    using Base::I1;
    using Base::KRepeat;
    using Base::xdlops_gemm;
    using typename Base::HotLoopInstList;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

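    // Worked example (added comment): a GEMM with K == 4 * KPerBlock gives
    // num_loop == 4, so BlockHasHotloop(4) == true (4 > PrefetchStages == 2)
    // and Run() is instantiated with HasMainLoop == true.
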
    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }

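    // Added reference comment: sched_group_barrier masks used below, as
    // defined by the LLVM AMDGPU backend:
    //   0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write.
    // Each call appends a group of `size` instructions of the masked class to
    // the intended schedule, so the pipeline controls the exact interleaving.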
    __device__ static constexpr auto HotLoopScheduler()
    {
#if !defined(__gfx11__) && !defined(__gfx12__)
        // A/B split schedule
        // The compiler is likely to fuse reads into ds_read2 when the
        // instruction width is smaller than 16 bytes.
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;

        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;

        constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
        constexpr auto mfma_cycle    = HotLoopInstList::C_MFMA_Inst_Cycle;

        constexpr auto ds_read_a_issue_cycle =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
        constexpr auto ds_read_b_issue_cycle =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
        constexpr auto ds_read_a_mfma_rate =
            (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
        constexpr auto ds_read_b_mfma_rate =
            (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);

        constexpr auto num_dsread_a_mfma =
            (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
        constexpr auto num_dsread_b_mfma =
            (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;

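        // Worked example (added comment, assuming mfma_cycle == 32 for
        // illustration): with a 16-byte A-side read (ds_read_a_issue_cycle
        // == 8), the ceiling division gives
        //   ds_read_a_mfma_rate = (32 - 4 + 16 - 1) / 16 = 2,
        // i.e. up to two ds_reads can be hidden under one MFMA;
        // num_dsread_a_mfma then counts how many MFMAs are needed to cover
        // all A-side ds_reads.
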
        // stage 1
        // Separate this part?
        // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
        //                                               sizeof(ComputeDataType) / sizeof(BDataType)
        //                                           ? sizeof(ComputeDataType) / sizeof(ADataType)
        //                                           : sizeof(ComputeDataType) / sizeof(BDataType);
        constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
        constexpr auto num_mfma_per_issue =
            num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
        constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
        constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;

        static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
        });
        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
        });

        // stage 2
        static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
                         ds_read_a_mfma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
                                                                              ds_read_a_mfma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        });

        static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
                         ds_read_b_mfma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
                                                                              ds_read_b_mfma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        });
#endif
    }

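    // Added overview comment: Run() below is structured as
    //   (1) prologue - two global prefetches, one LDS prefill, one LDS->VGPR
    //       prefetch and C initialization;
    //   (2) hot loop - MFMAs on tile i while LDS is refilled for tile i+1 and
    //       global loads are issued for tile i+2, paced by HotLoopScheduler();
    //   (3) tail     - MFMAs on the final prefetched tile.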
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        __builtin_amdgcn_sched_barrier(0);
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Global prefetch 2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

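        // Added note (interpretation): the second prefetched tile is still in
        // the copy operators' register staging buffers here; with a single
        // shared-memory buffer it can only be committed to LDS after the
        // block_sync_lds() at the top of each hot-loop iteration.
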
        // Local prefetch 1
        block_sync_lds();
        static_for<0, KRepeat, 1>{}([&](auto k0) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(m0, I0, k0, I0),
                                   a_thread_buf);
            });
            static_for<0, NRepeat, 1>{}([&](auto n0) {
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
                                   b_block_buf,
                                   b_thread_desc_,
                                   make_tuple(n0, I0, k0, I0),
                                   b_thread_buf);
            });
        });

        __builtin_amdgcn_sched_barrier(0);

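        // Added note: mask 0 forbids the compiler from reordering any
        // instruction across this point, keeping the prologue schedule intact.
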
        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0, I0, k0, ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0, I0, k0, ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            xdlops_gemm.Run(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });

                block_sync_lds();

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
                                           a_block_buf,
                                           a_thread_desc_,
                                           make_tuple(m0, I0, k0, I0),
                                           a_thread_buf);
                    });
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
                                           b_block_buf,
                                           b_thread_desc_,
                                           make_tuple(n0, I0, k0, I0),
                                           b_thread_buf);
                    });
                });

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 1));
        }
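        // Added note: the hot loop runs num_loop - 1 iterations, computing
        // tiles 0 .. num_loop - 2; the last prefetched tile is consumed by the
        // TailNumber::Full epilogue below.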
        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
            // Let the last MFMA block leak into the epilogue region to cover
            // the potential lds-shuffle latency
            // __builtin_amdgcn_sched_barrier(0);
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

} // namespace ck