blockwise_gemm_pipeline_xdlops_v3.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"

namespace ck {

// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1

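// A rough timeline sketch (inferred from Run() below): iteration 0 issues
// global prefetch 1, writes it to LDS (prefill 1), issues global prefetch 2,
// and reads LDS into registers (local prefetch 1); each hot-loop iteration
// then overlaps the LDS write of the previous global read, the next global
// read, and the MFMAs on the registers filled one iteration earlier.
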
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_v3
{
};

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                                       BlockSize,
                                       ADataType,
                                       BDataType,
                                       ComputeDataType,
                                       AccDataType,
                                       ATileDesc,
                                       BTileDesc,
                                       AMmaTileDesc,
                                       BMmaTileDesc,
                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,
                                       MPerBlock,
                                       NPerBlock,
                                       KPerBlock,
                                       MPerXDL,
                                       NPerXDL,
                                       MRepeat,
                                       NRepeat,
                                       KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeDataType,
                                        AccDataType,
                                        ATileDesc,
                                        BTileDesc,
                                        AMmaTileDesc,
                                        BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerXDL,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>

{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
    using Base::I0;
    using Base::I1;
    using Base::KRepeat;
    using Base::xdlops_gemm;
    using typename Base::HotLoopInstList;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

    using typename Base::ComputeDataTypeBuf;

    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }
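
    // With PrefetchStages = 2, a hot loop exists only for num_loop > 2, and
    // this pipeline always drains with a TailNumber::Full epilogue.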

    __device__ static constexpr auto HotLoopScheduler()
    {
#if !defined(__gfx11__) && !defined(__gfx12__)
        // A/B split schedule
        // The compiler is likely to use ds_read2 when the instruction width is
        // smaller than 16 bytes
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;

        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;

        constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
        constexpr auto mfma_cycle    = HotLoopInstList::C_MFMA_Inst_Cycle;

        constexpr auto ds_read_a_issue_cycle =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
        constexpr auto ds_read_b_issue_cycle =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
        constexpr auto ds_read_a_mfma_rate =
            (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
        constexpr auto ds_read_b_mfma_rate =
            (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
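
        // Example with hypothetical values: mfma_cycle = 32 and a 16-byte
        // ds_read (issue cycle 8) give (32 - 4 + 2*8 - 1) / (2*8) = 2, i.e.
        // up to two ds_read issues are paired with each MFMA.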

        constexpr auto num_dsread_a_mfma =
            (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
        constexpr auto num_dsread_b_mfma =
            (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;

        // stage 1
        // Separate this part?
        // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) >
        //                                               sizeof(ComputeDataTypeBuf) / sizeof(BDataType)
        //                                           ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType)
        //                                           : sizeof(ComputeDataTypeBuf) / sizeof(BDataType);
        constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
        constexpr auto num_mfma_per_issue =
            num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
        constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
        constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
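
        // Stage 1 sketch: each buffer_load slot first pairs its share of DS
        // writes with MFMAs one-to-one, then issues the VMEM read followed by
        // the remaining MFMAs of the slot (see the sched_group_barriers below).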
        static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
        });
        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
        });

        // stage 2
        static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
                         ds_read_a_mfma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
                                                                              ds_read_a_mfma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        });
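        // The else branch above issues the leftover ds_reads of the final
        // batch, e.g. 7 ds_reads at rate 2 schedule as 2 + 2 + 2, then 1.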

        static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
                         ds_read_b_mfma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
                                                                              ds_read_b_mfma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        });
#endif
    }

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        __builtin_amdgcn_sched_barrier(0);
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Global prefetch 2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        // Local prefetch 1
        block_sync_lds();
        static_for<0, KRepeat, 1>{}([&](auto k0) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(m0, I0, k0, I0),
                                   a_thread_buf);
            });
            static_for<0, NRepeat, 1>{}([&](auto n0) {
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
                                   b_block_buf,
                                   b_thread_desc_,
                                   make_tuple(n0, I0, k0, I0),
                                   b_thread_buf);
            });
        });

        __builtin_amdgcn_sched_barrier(0);

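        // Note: __builtin_amdgcn_sched_barrier(0) (mask 0) forbids the
        // scheduler from moving any instruction across it, fencing the
        // prologue off from the software-pipelined hot loop below.
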
        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
                            vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0, I0, k0, ik))>{}];
                                b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0, I0, k0, ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataTypeBuf,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            xdlops_gemm.Run(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });

                block_sync_lds();

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
                                           a_block_buf,
                                           a_thread_desc_,
                                           make_tuple(m0, I0, k0, I0),
                                           a_thread_buf);
                    });
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
                                           b_block_buf,
                                           b_thread_desc_,
                                           make_tuple(n0, I0, k0, I0),
                                           b_thread_buf);
                    });
                });

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 1));
        }
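        // The do-while above executes num_loop - 1 times; together with the
        // tail below this yields exactly one MFMA pass per K-step of the GEMM.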
        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
                        vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
            // Leak the last MFMA block into the epilogue region to cover the
            // potential lds-shuffle latency
            // __builtin_amdgcn_sched_barrier(0);
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

} // namespace ck
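
A minimal standalone sketch of the issue-rate arithmetic used by HotLoopScheduler() above. All input values (mfma_cycle, the LDS read width, the instruction count) are hypothetical stand-ins for the corresponding HotLoopInstList members, not values taken from any real configuration:

// Standalone sketch of HotLoopScheduler()'s rate arithmetic; all inputs are
// hypothetical stand-ins for HotLoopInstList members.
#include <cstdio>

int main()
{
    constexpr int mfma_cycle         = 32; // hypothetical C_MFMA_Inst_Cycle
    constexpr int lds_read_width_a   = 16; // bytes per ds_read of A (assumed)
    constexpr int num_ds_read_inst_a = 7;  // hypothetical A_LDS_Read_Inst_Num

    // Issue cycle: 8 for a 16-byte (b128) ds_read, else 4, as in the file above.
    constexpr int ds_read_a_issue_cycle = lds_read_width_a == 16 ? 8 : 4;

    // ds_reads that fit under one MFMA, rounded up:
    // (mfma_cycle - 4 + 2*issue - 1) / (2*issue) = (32 - 4 + 16 - 1) / 16 = 2
    constexpr int ds_read_a_mfma_rate =
        (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);

    // MFMA slots needed to cover all A ds_reads: ceil(7 / 2) = 4
    constexpr int num_dsread_a_mfma =
        (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;

    std::printf("rate=%d, mfma slots for A ds_reads=%d\n",
                ds_read_a_mfma_rate, num_dsread_a_mfma);
    // Prints: rate=2, mfma slots for A ds_reads=4
}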