Composable Kernel: blockwise_gemm_pipeline_xdlops_v2.hpp Source File
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
7 
8 namespace ck {
9 
10 // Maximum Global Memory throughput pipeline with >=32KB data in flight
11 // GlobalPrefetchStages: >=2
12 // LocalPreFillStages: 1
13 // LocalPreFetchStages: 0
14 // LocalSharedMemoryBuffer: 1
15 
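// The Run() implementations below realize this schedule roughly as follows: issue
// PrefetchStages global reads up front, pre-fill one LDS buffer, then in the hot loop
// consume the current K-tile from LDS with xdlops MFMA while writing the next
// prefetched tile to LDS and issuing a fresh global read; the tail variants drain the
// remaining (num_loop % PrefetchStages) iterations without further global reads.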
16 template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17  index_t BlockSize,
18  typename ADataType,
19  typename BDataType,
20  typename ComputeDataType,
21  typename AccDataType,
22  typename ATileDesc,
23  typename BTileDesc,
24  typename AMmaTileDesc,
25  typename BMmaTileDesc,
26  index_t ABlockTransferSrcScalarPerVector,
27  index_t BBlockTransferSrcScalarPerVector,
28  index_t MPerBlock,
29  index_t NPerBlock,
30  index_t KPerBlock,
31  index_t MPerXDL,
32  index_t NPerXDL,
33  index_t MRepeat,
34  index_t NRepeat,
35  index_t KPacks>
36 struct BlockwiseGemmXdlops_pipeline_v2
37 {
38 };
39 
40 template <index_t BlockSize,
41  typename ADataType,
42  typename BDataType,
43  typename ComputeDataType,
44  typename AccDataType,
45  typename ATileDesc,
46  typename BTileDesc,
47  typename AMmaTileDesc,
48  typename BMmaTileDesc,
49  index_t ABlockTransferSrcScalarPerVector,
50  index_t BBlockTransferSrcScalarPerVector,
51  index_t MPerBlock,
52  index_t NPerBlock,
53  index_t KPerBlock,
54  index_t MPerXDL,
55  index_t NPerXDL,
56  index_t MRepeat,
57  index_t NRepeat,
58  index_t KPack
59  // ,bool TransposeC //disable transposec right now...
60  >
61 struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
62  BlockSize,
63  ADataType,
64  BDataType,
65  ComputeDataType,
66  AccDataType,
67  ATileDesc,
68  BTileDesc,
69  AMmaTileDesc,
70  BMmaTileDesc,
71  ABlockTransferSrcScalarPerVector,
72  BBlockTransferSrcScalarPerVector,
73  MPerBlock,
74  NPerBlock,
75  KPerBlock,
76  MPerXDL,
77  NPerXDL,
78  MRepeat,
79  NRepeat,
80  KPack>
81  : BlockwiseGemmXdlops_pipeline_base<BlockSize,
82  ADataType,
83  BDataType,
84  ComputeDataType,
85  AccDataType,
86  ATileDesc,
87  BTileDesc,
88  AMmaTileDesc,
89  BMmaTileDesc,
90  ABlockTransferSrcScalarPerVector,
91  BBlockTransferSrcScalarPerVector,
92  MPerBlock,
93  NPerBlock,
94  KPerBlock,
95  MPerXDL,
96  NPerXDL,
97  MRepeat,
98  NRepeat,
99  KPack>
100 
101 {
102  using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
103  ADataType,
104  BDataType,
105  ComputeDataType,
106  AccDataType,
107  ATileDesc,
108  BTileDesc,
109  AMmaTileDesc,
110  BMmaTileDesc,
111  ABlockTransferSrcScalarPerVector,
112  BBlockTransferSrcScalarPerVector,
113  MPerBlock,
114  NPerBlock,
115  KPerBlock,
116  MPerXDL,
117  NPerXDL,
118  MRepeat,
119  NRepeat,
120  KPack>;
121  using Base::I0;
122  using Base::KRepeat;
123  using Base::xdlops_gemm;
124 
125  using Base::CalculateCThreadOriginDataIndex;
126  using Base::CalculateCThreadOriginDataIndex8D;
127  using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
128  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
129  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
130  using Base::GetCThreadBuffer;
131  using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
132  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
133  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
134  using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
135  using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
136 
137  using Base::a_block_desc_m0_m1_m2_k;
138  using Base::b_block_desc_n0_n1_n2_k;
139 
140  using Base::AMmaKStride;
141  using Base::BMmaKStride;
142  using Base::WaveSize;
143 
144  using typename Base::ComputeDataTypeBuf;
145 
146  static constexpr index_t WgpPerCU =
147  (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
148  static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
149  32768 / WgpPerCU,
150  (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
151  static constexpr index_t PrefetchStages =
152  FullMemBandPrefetchStages >= 2
153  ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
154  : 2;
155 
156  static constexpr index_t PrefillStages = 1;
157  static constexpr index_t GlobalBufferNum = PrefetchStages;
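// Worked example (hypothetical configuration, not taken from this file): with
// BlockSize = 256, WaveSize = 64, fp16 A/B, MPerBlock = 256, NPerBlock = 128 and
// KPerBlock = 32: WgpPerCU = max(4 * 64 / 256, 1) = 1, bytes per stage =
// (256 * 2 + 128 * 2) * 32 = 24576, FullMemBandPrefetchStages =
// ceil(32768 / 24576) = 2, so PrefetchStages = 2 (clamped to the range [2, 8]).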
158 
159  __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
160  {
161  return num_loop > PrefetchStages;
162  }
163 
164  __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
165  {
166  if(num_loop % PrefetchStages == 1)
167  {
168  return TailNumber::One;
169  }
170  else if(num_loop % PrefetchStages == 2)
171  {
172  return TailNumber::Two;
173  }
174  else if(num_loop % PrefetchStages == 3)
175  {
176  return TailNumber::Three;
177  }
178  else if(num_loop % PrefetchStages == 4)
179  {
180  return TailNumber::Four;
181  }
182  else if(num_loop % PrefetchStages == 5)
183  {
184  return TailNumber::Five;
185  }
186  else if(num_loop % PrefetchStages == 6)
187  {
188  return TailNumber::Six;
189  }
190  else if(num_loop % PrefetchStages == 7)
191  {
192  return TailNumber::Seven;
193  }
194  else
195  {
196  return TailNumber::Full;
197  }
198  }
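// Example (hypothetical values): with PrefetchStages = 4 and num_loop = 10,
// BlockHasHotloop(10) returns true (10 > 4) and BlockLoopTailNum(10) returns
// TailNumber::Two (10 % 4 == 2), so the hot loop runs in groups of 4 stages and the
// tail path consumes the remaining 2 iterations.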
199 
200  template <bool HasMainLoop,
201  TailNumber TailNum,
202  typename AGridDesc,
203  typename ABlockDesc,
204  typename ABlockTransfer,
205  typename AGridBuffer,
206  typename ABlockBuffer,
207  typename ABlockTransferStep,
208  typename BGridDesc,
209  typename BBlockDesc,
210  typename BBlockTransfer,
211  typename BGridBuffer,
212  typename BBlockBuffer,
213  typename BBlockTransferStep,
214  typename CThreadBuffer>
215  __device__ void Run(const AGridDesc& a_grid_desc,
216  const ABlockDesc& a_block_desc,
217  ABlockTransfer& a_blockwise_copy,
218  const AGridBuffer& a_grid_buf,
219  ABlockBuffer& a_block_buf,
220  const ABlockTransferStep& a_block_copy_step,
221  const BGridDesc& b_grid_desc,
222  const BBlockDesc& b_block_desc,
223  BBlockTransfer& b_blockwise_copy,
224  const BGridBuffer& b_grid_buf,
225  BBlockBuffer& b_block_buf,
226  const BBlockTransferStep& b_block_copy_step,
227  CThreadBuffer& c_thread_buf,
228  index_t num_loop) const
229  {
230  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
231  a_thread_desc_.GetElementSpaceSize());
232  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
233  b_thread_desc_.GetElementSpaceSize());
234 
235  // Global prefetch 1
236  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
237  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
238 
239  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
240  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
241 
242  // Initialize C
243  c_thread_buf.Clear();
244 
245  // Local prefill 1
246  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
247  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
248 
249  // Global prefetch [2, PrefetchStages]
250  static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) {
251  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
252  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
253 
254  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
255  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
256  });
257 
258  // main body
259  if constexpr(HasMainLoop)
260  {
261  index_t i = 0;
262  do
263  {
264  static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) {
265  // -------------------------------------------------------------------------------------------
266  block_sync_lds();
267  static_for<0, KRepeat, 1>{}([&](auto k) {
268  static_for<0, MRepeat, 1>{}([&](auto m0) {
269  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
270  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
271  a_block_buf,
272  a_thread_desc_,
273  make_tuple(m0, I0, k, I0),
274  a_thread_buf);
275  });
276  static_for<0, NRepeat, 1>{}([&](auto n0) {
277  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
278  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
279  b_block_buf,
280  b_thread_desc_,
281  make_tuple(n0, I0, k, I0),
282  b_thread_buf);
283  });
284  });
285 
286  static_for<0, KRepeat, 1>{}([&](auto k0) {
287  static_for<0, MRepeat, 1>{}([&](auto m0) {
288  static_for<0, NRepeat, 1>{}([&](auto n0) {
289  vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
290  vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
291 
292  static_for<0, KPack, 1>{}([&](auto ik) {
293  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
294  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
295  make_tuple(m0, I0, k0, ik))>{}];
296  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
297  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
298  make_tuple(n0, I0, k0, ik))>{}];
299  });
300 
301  using mfma_input_type =
302  typename vector_type<ComputeDataTypeBuf,
303  xdlops_gemm.K1PerXdlops>::type;
304 
305  constexpr index_t c_offset =
306  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
307 
308  xdlops_gemm.Run(
309  a_thread_vec.template AsType<mfma_input_type>(),
310  b_thread_vec.template AsType<mfma_input_type>(),
311  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
312  });
313  });
314  });
315 
316  block_sync_lds();
317  a_blockwise_copy.RunWrite(
318  a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
319  b_blockwise_copy.RunWrite(
320  b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
321 
322  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
323  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
324 
325  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
326  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
327  });
328 
329  i += PrefetchStages;
330  } while(i < (num_loop - PrefetchStages));
331  }
332 
333  // tail
334 
335  auto LoopTailFunc = [&](auto tail_num) {
336  static_for<1, tail_num, 1>{}([&](auto iprefetch) {
337  block_sync_lds();
338  static_for<0, KRepeat, 1>{}([&](auto k) {
339  static_for<0, MRepeat, 1>{}([&](auto m0) {
340  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
341  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
342  a_block_buf,
343  a_thread_desc_,
344  make_tuple(m0, I0, k, I0),
345  a_thread_buf);
346  });
347  static_for<0, NRepeat, 1>{}([&](auto n0) {
348  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
349  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
350  b_block_buf,
351  b_thread_desc_,
352  make_tuple(n0, I0, k, I0),
353  b_thread_buf);
354  });
355  });
356 
357  static_for<0, KRepeat, 1>{}([&](auto k0) {
358  static_for<0, MRepeat, 1>{}([&](auto m0) {
359  static_for<0, NRepeat, 1>{}([&](auto n0) {
360  vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
361  vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
362 
363  static_for<0, KPack, 1>{}([&](auto ik) {
364  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
365  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
366  make_tuple(m0, I0, k0, ik))>{}];
367  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
368  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
369  make_tuple(n0, I0, k0, ik))>{}];
370  });
371 
372  using mfma_input_type =
373  typename vector_type<ComputeDataTypeBuf,
374  xdlops_gemm.K1PerXdlops>::type;
375 
376  constexpr index_t c_offset =
377  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
378 
379  xdlops_gemm.Run(
380  a_thread_vec.template AsType<mfma_input_type>(),
381  b_thread_vec.template AsType<mfma_input_type>(),
382  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
383  });
384  });
385  });
386 
387  block_sync_lds();
388  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
389  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
390  });
391 
392  block_sync_lds();
393  static_for<0, KRepeat, 1>{}([&](auto k) {
394  static_for<0, MRepeat, 1>{}([&](auto m0) {
395  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
396  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
397  a_block_buf,
398  a_thread_desc_,
399  make_tuple(m0, I0, k, I0),
400  a_thread_buf);
401  });
402  static_for<0, NRepeat, 1>{}([&](auto n0) {
403  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
404  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
405  b_block_buf,
406  b_thread_desc_,
407  make_tuple(n0, I0, k, I0),
408  b_thread_buf);
409  });
410  });
411 
412  static_for<0, KRepeat, 1>{}([&](auto k0) {
413  static_for<0, MRepeat, 1>{}([&](auto m0) {
414  static_for<0, NRepeat, 1>{}([&](auto n0) {
415  vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
416  vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
417 
418  static_for<0, KPack, 1>{}([&](auto ik) {
419  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
420  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
421  make_tuple(m0, I0, k0, ik))>{}];
422  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
423  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
424  make_tuple(n0, I0, k0, ik))>{}];
425  });
426 
427  using mfma_input_type =
428  typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
429 
430  constexpr index_t c_offset =
431  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
432 
433  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
434  b_thread_vec.template AsType<mfma_input_type>(),
435  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
436  });
437  });
438  });
439  };
440 
441  if constexpr(TailNum == TailNumber::One)
442  {
443  block_sync_lds();
444  static_for<0, KRepeat, 1>{}([&](auto k) {
445  static_for<0, MRepeat, 1>{}([&](auto m0) {
446  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
447  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
448  a_block_buf,
449  a_thread_desc_,
450  make_tuple(m0, I0, k, I0),
451  a_thread_buf);
452  });
453  static_for<0, NRepeat, 1>{}([&](auto n0) {
454  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
455  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
456  b_block_buf,
457  b_thread_desc_,
458  make_tuple(n0, I0, k, I0),
459  b_thread_buf);
460  });
461  });
462 
463  static_for<0, KRepeat, 1>{}([&](auto k0) {
464  static_for<0, MRepeat, 1>{}([&](auto m0) {
465  static_for<0, NRepeat, 1>{}([&](auto n0) {
466  vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
467  vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
468 
469  static_for<0, KPack, 1>{}([&](auto ik) {
470  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
471  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
472  make_tuple(m0, I0, k0, ik))>{}];
473  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
474  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
475  make_tuple(n0, I0, k0, ik))>{}];
476  });
477 
478  using mfma_input_type =
479  typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
480 
481  constexpr index_t c_offset =
482  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
483 
484  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
485  b_thread_vec.template AsType<mfma_input_type>(),
486  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
487  });
488  });
489  });
490  }
491  else if constexpr(TailNum == TailNumber::Two)
492  {
493  LoopTailFunc(Number<2>{});
494  }
495  else if constexpr(TailNum == TailNumber::Three)
496  {
497  LoopTailFunc(Number<3>{});
498  }
499  else if constexpr(TailNum == TailNumber::Four)
500  {
501  LoopTailFunc(Number<4>{});
502  }
503  else if constexpr(TailNum == TailNumber::Five)
504  {
505  LoopTailFunc(Number<5>{});
506  }
507  else if constexpr(TailNum == TailNumber::Six)
508  {
509  LoopTailFunc(Number<6>{});
510  }
511  else if constexpr(TailNum == TailNumber::Seven)
512  {
513  LoopTailFunc(Number<7>{});
514  }
515  else if constexpr(TailNum == TailNumber::Full)
516  {
517  LoopTailFunc(Number<PrefetchStages>{});
518  }
519  }
520 
521  protected:
522  using Base::a_thread_copy_;
523  using Base::a_thread_desc_;
524  using Base::b_thread_copy_;
525  using Base::b_thread_desc_;
526  using Base::c_thread_desc_;
527 };
528 
529 template <index_t BlockSize,
530  typename ADataType,
531  typename BDataType,
532  typename ComputeDataType,
533  typename AccDataType,
534  typename ATileDesc,
535  typename BTileDesc,
536  typename AMmaTileDesc,
537  typename BMmaTileDesc,
538  index_t ABlockTransferSrcScalarPerVector,
539  index_t BBlockTransferSrcScalarPerVector,
540  index_t MPerBlock,
541  index_t NPerBlock,
542  index_t KPerBlock,
543  index_t MPerXDL,
544  index_t NPerXDL,
545  index_t MRepeat,
546  index_t NRepeat,
547  index_t KPack
548  // ,bool TransposeC //disable transposec right now...
549  >
550 struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
551  BlockSize,
552  ADataType,
553  BDataType,
554  ComputeDataType,
555  AccDataType,
556  ATileDesc,
557  BTileDesc,
558  AMmaTileDesc,
559  BMmaTileDesc,
560  ABlockTransferSrcScalarPerVector,
561  BBlockTransferSrcScalarPerVector,
562  MPerBlock,
563  NPerBlock,
564  KPerBlock,
565  MPerXDL,
566  NPerXDL,
567  MRepeat,
568  NRepeat,
569  KPack>
570  : BlockwiseGemmXdlops_pipeline_base<BlockSize,
571  ADataType,
572  BDataType,
573  ComputeDataType,
574  AccDataType,
575  ATileDesc,
576  BTileDesc,
577  AMmaTileDesc,
578  BMmaTileDesc,
579  ABlockTransferSrcScalarPerVector,
580  BBlockTransferSrcScalarPerVector,
581  MPerBlock,
582  NPerBlock,
583  KPerBlock,
584  MPerXDL,
585  NPerXDL,
586  MRepeat,
587  NRepeat,
588  KPack>
589 
590 {
591  using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
592  ADataType,
593  BDataType,
594  ComputeDataType,
595  AccDataType,
596  ATileDesc,
597  BTileDesc,
598  AMmaTileDesc,
599  BMmaTileDesc,
600  ABlockTransferSrcScalarPerVector,
601  BBlockTransferSrcScalarPerVector,
602  MPerBlock,
603  NPerBlock,
604  KPerBlock,
605  MPerXDL,
606  NPerXDL,
607  MRepeat,
608  NRepeat,
609  KPack>;
610  using Base::A_K1;
611  using Base::B_K1;
612  using Base::I0;
613  using Base::I1;
614  using Base::KPerThread;
615  using Base::xdlops_gemm;
616 
617  using Base::CalculateCThreadOriginDataIndex;
618  using Base::CalculateCThreadOriginDataIndex8D;
619  using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
620  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
621  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
622  using Base::GetCThreadBuffer;
623  using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
624  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
625  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
626  using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
627  using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
628 
629  using Base::a_block_desc_m0_m1_m2_k;
630  using Base::b_block_desc_n0_n1_n2_k;
631  using Base::WaveSize;
632 
633  using typename Base::ComputeDataTypeBuf;
634 
635  static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
636  static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
637  static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
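// Example (hypothetical values): with KPerThread = 16, KPack = 4 and
// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS = 2, KPerInnerLoop =
// max(16 / 2, 4) = 8 and KRepeat = 16 / 8 = 2, i.e. each K iteration is split into
// two MAC clusters of 8 K-elements per thread.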
638 
639  static constexpr index_t WgpPerCU =
640  (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
641  static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
642  32768 / WgpPerCU,
643  (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
644  static constexpr index_t PrefetchStages =
645  FullMemBandPrefetchStages >= 2
646  ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
647  : 2;
648 
649  static constexpr index_t PrefillStages = 1;
650  static constexpr index_t GlobalBufferNum = PrefetchStages;
651 
652  __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
653  {
654  return num_loop > PrefetchStages;
655  }
656 
657  __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
658  {
659  if(num_loop % PrefetchStages == 1)
660  {
661  return TailNumber::One;
662  }
663  else if(num_loop % PrefetchStages == 2)
664  {
665  return TailNumber::Two;
666  }
667  else if(num_loop % PrefetchStages == 3)
668  {
669  return TailNumber::Three;
670  }
671  else if(num_loop % PrefetchStages == 4)
672  {
673  return TailNumber::Four;
674  }
675  else if(num_loop % PrefetchStages == 5)
676  {
677  return TailNumber::Five;
678  }
679  else if(num_loop % PrefetchStages == 6)
680  {
681  return TailNumber::Six;
682  }
683  else if(num_loop % PrefetchStages == 7)
684  {
685  return TailNumber::Seven;
686  }
687  else
688  {
689  return TailNumber::Full;
690  }
691  }
692 
693  template <bool HasMainLoop,
694  TailNumber TailNum,
695  typename AGridDesc,
696  typename ABlockDesc,
697  typename ABlockTransfer,
698  typename AGridBuffer,
699  typename ABlockBuffer,
700  typename ABlockTransferStep,
701  typename BGridDesc,
702  typename BBlockDesc,
703  typename BBlockTransfer,
704  typename BGridBuffer,
705  typename BBlockBuffer,
706  typename BBlockTransferStep,
707  typename CThreadBuffer>
708  __device__ void Run(const AGridDesc& a_grid_desc,
709  const ABlockDesc& a_block_desc,
710  ABlockTransfer& a_blockwise_copy,
711  const AGridBuffer& a_grid_buf,
712  ABlockBuffer& a_block_buf,
713  const ABlockTransferStep& a_block_copy_step,
714  const BGridDesc& b_grid_desc,
715  const BBlockDesc& b_block_desc,
716  BBlockTransfer& b_blockwise_copy,
717  const BGridBuffer& b_grid_buf,
718  BBlockBuffer& b_block_buf,
719  const BBlockTransferStep& b_block_copy_step,
720  CThreadBuffer& c_thread_buf,
721  index_t num_loop) const
722  {
723  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
724  a_thread_desc_.GetElementSpaceSize());
725  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
726  b_thread_desc_.GetElementSpaceSize());
727 
728  // Global prefetch 1
729  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
730  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
731 
732  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
733  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
734 
735  // Initialize C
736  c_thread_buf.Clear();
737 
738  // Local prefill 1
739  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
740  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
741 
742  // Global prefetch [2, PrefetchStages]
743  static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) {
744  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
745  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
746 
747  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
748  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
749  });
750 
751  // main body
752  if constexpr(HasMainLoop)
753  {
754  index_t i = 0;
755  do
756  {
757  static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) {
758  // -------------------------------------------------------------------------------------------
759  block_sync_lds();
760  static_for<0, KRepeat, 1>{}([&](auto k0) {
761  static_for<0, MRepeat, 1>{}([&](auto m0) {
762  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
763  make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
764  a_block_buf,
765  a_thread_desc_,
766  make_tuple(m0, I0, k0, I0),
767  a_thread_buf);
768  });
769  static_for<0, NRepeat, 1>{}([&](auto n0) {
770  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
771  make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
772  b_block_buf,
773  b_thread_desc_,
774  make_tuple(n0, I0, k0, I0),
775  b_thread_buf);
776  });
777  __builtin_amdgcn_sched_barrier(0);
778  // NOTE: Synchronize threads in a workgroup at the start of each MAC
779  // cluster, but except the first, as we can shorten non-MAC cluster a bit
780  // and there's no observable negative impact. The desired effect is waves in
781  // a workgroup executing MAC in sync. This avoids some out-of-sync waves
782  // hijacking MAC resource from other workgroups and reducing the chance of
783  // latency hiding by waiting for the rest of the workgroup at the eventual
784  // sync point.
785  if constexpr(k0.value != 0 || KRepeat == 1)
786  {
787  __builtin_amdgcn_s_barrier();
788  __builtin_amdgcn_sched_barrier(0);
789  }
790  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
791  static_for<0, MRepeat, 1>{}([&](auto m0) {
792  static_for<0, NRepeat, 1>{}([&](auto n0) {
793  vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
794  vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
795 
796  static_for<0, KPack, 1>{}([&](auto ik) {
797  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
798  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
799  make_tuple(m0, I0, k0, k_ + ik))>{}];
800  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
801  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
802  make_tuple(n0, I0, k0, k_ + ik))>{}];
803  });
804 
805  using mfma_input_type =
806  typename vector_type<ComputeDataTypeBuf,
807  xdlops_gemm.K1PerXdlops>::type;
808 
809  constexpr index_t c_offset =
810  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
811 
812  // The block_sync_lds() here performs double duty:
813  // A) safeguard against data hazard because barrier from
814  // blockwise_gemm is moved here B) reduce VMEM FIFO congestion
815  // by applying small delays to different wavefronts It is
816  // performed near the end of MAC cluster to minimize lgkmcnt
817  // penalty
818  if constexpr(k0.value == KRepeat - 1 &&
819  k_.value == KPerInnerLoop - KPack &&
820  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
821  {
822  __builtin_amdgcn_sched_barrier(0);
823  block_sync_lds();
824  __builtin_amdgcn_sched_barrier(0);
825  }
826  xdlops_gemm.Run(
827  a_thread_vec.template AsType<mfma_input_type>(),
828  b_thread_vec.template AsType<mfma_input_type>(),
829  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
830  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
831  {
832  __builtin_amdgcn_sched_barrier(0);
833  __builtin_amdgcn_s_setprio(1);
834  __builtin_amdgcn_sched_barrier(0);
835  }
836  });
837  });
838  });
839  __builtin_amdgcn_sched_barrier(0);
840  __builtin_amdgcn_s_setprio(0);
841  __builtin_amdgcn_sched_barrier(0);
842  });
843 
844  // block_sync_lds();
845  a_blockwise_copy.RunWrite(
846  a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
847  b_blockwise_copy.RunWrite(
848  b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
849 
850  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
851  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
852 
853  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
854  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
855  });
856  i += PrefetchStages;
857  } while(i < (num_loop - PrefetchStages));
858  }
859 
860  // tail
861 
862  auto LoopTailFunc = [&](auto tail_num) {
863  static_for<1, tail_num, 1>{}([&](auto iprefetch) {
864  block_sync_lds();
865  static_for<0, KRepeat, 1>{}([&](auto k0) {
866  static_for<0, MRepeat, 1>{}([&](auto m0) {
867  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
868  make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
869  a_block_buf,
870  a_thread_desc_,
871  make_tuple(m0, I0, k0, I0),
872  a_thread_buf);
873  });
874  static_for<0, NRepeat, 1>{}([&](auto n0) {
875  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
876  make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
877  b_block_buf,
878  b_thread_desc_,
879  make_tuple(n0, I0, k0, I0),
880  b_thread_buf);
881  });
882 
883  __builtin_amdgcn_sched_barrier(0);
884  if constexpr(k0.value != 0 || KRepeat == 1)
885  {
886  __builtin_amdgcn_s_barrier();
887  __builtin_amdgcn_sched_barrier(0);
888  }
889  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
890  static_for<0, MRepeat, 1>{}([&](auto m0) {
891  static_for<0, NRepeat, 1>{}([&](auto n0) {
892  vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
893  vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
894 
895  static_for<0, KPack, 1>{}([&](auto ik) {
896  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
897  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
898  make_tuple(m0, I0, k0, k_ + ik))>{}];
899  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
900  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
901  make_tuple(n0, I0, k0, k_ + ik))>{}];
902  });
903 
904  using mfma_input_type =
905  typename vector_type<ComputeDataTypeBuf,
906  xdlops_gemm.K1PerXdlops>::type;
907 
908  constexpr index_t c_offset =
909  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
910 
911  if constexpr(k0.value == KRepeat - 1 &&
912  k_.value == KPerInnerLoop - KPack &&
913  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
914  {
915  __builtin_amdgcn_sched_barrier(0);
916  block_sync_lds();
917  __builtin_amdgcn_sched_barrier(0);
918  }
919  xdlops_gemm.Run(
920  a_thread_vec.template AsType<mfma_input_type>(),
921  b_thread_vec.template AsType<mfma_input_type>(),
922  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
923  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
924  {
925  __builtin_amdgcn_sched_barrier(0);
926  __builtin_amdgcn_s_setprio(1);
927  __builtin_amdgcn_sched_barrier(0);
928  }
929  });
930  });
931  });
932  __builtin_amdgcn_sched_barrier(0);
933  __builtin_amdgcn_s_setprio(0);
934  __builtin_amdgcn_sched_barrier(0);
935  });
936 
937  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
938  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
939  });
940  block_sync_lds();
941  static_for<0, KRepeat, 1>{}([&](auto k0) {
942  static_for<0, MRepeat, 1>{}([&](auto m0) {
943  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
944  make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
945  a_block_buf,
946  a_thread_desc_,
947  make_tuple(m0, I0, k0, I0),
948  a_thread_buf);
949  });
950  static_for<0, NRepeat, 1>{}([&](auto n0) {
951  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
952  make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
953  b_block_buf,
954  b_thread_desc_,
955  make_tuple(n0, I0, k0, I0),
956  b_thread_buf);
957  });
958 
959  __builtin_amdgcn_sched_barrier(0);
960  if constexpr(k0.value != 0 || KRepeat == 1)
961  {
962  __builtin_amdgcn_s_barrier();
963  __builtin_amdgcn_sched_barrier(0);
964  }
965  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
966  static_for<0, MRepeat, 1>{}([&](auto m0) {
967  static_for<0, NRepeat, 1>{}([&](auto n0) {
968  vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
969  vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
970 
971  static_for<0, KPack, 1>{}([&](auto ik) {
972  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
973  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
974  make_tuple(m0, I0, k0, k_ + ik))>{}];
975  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
976  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
977  make_tuple(n0, I0, k0, k_ + ik))>{}];
978  });
979 
980  using mfma_input_type =
981  typename vector_type<ComputeDataTypeBuf,
982  xdlops_gemm.K1PerXdlops>::type;
983 
984  constexpr index_t c_offset =
985  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
986 
987  if constexpr(k0.value == KRepeat - 1 &&
988  k_.value == KPerInnerLoop - KPack &&
989  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
990  {
991  __builtin_amdgcn_sched_barrier(0);
992  block_sync_lds();
993  __builtin_amdgcn_sched_barrier(0);
994  }
995  xdlops_gemm.Run(
996  a_thread_vec.template AsType<mfma_input_type>(),
997  b_thread_vec.template AsType<mfma_input_type>(),
998  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
999  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
1000  {
1001  __builtin_amdgcn_sched_barrier(0);
1002  __builtin_amdgcn_s_setprio(1);
1003  __builtin_amdgcn_sched_barrier(0);
1004  }
1005  });
1006  });
1007  });
1008  __builtin_amdgcn_sched_barrier(0);
1009  __builtin_amdgcn_s_setprio(0);
1010  __builtin_amdgcn_sched_barrier(0);
1011  });
1012  };
1013 
1014  if constexpr(TailNum == TailNumber::One)
1015  {
1016  block_sync_lds();
1017  static_for<0, KRepeat, 1>{}([&](auto k0) {
1018  static_for<0, MRepeat, 1>{}([&](auto m0) {
1019  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
1020  make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
1021  a_block_buf,
1022  a_thread_desc_,
1023  make_tuple(m0, I0, k0, I0),
1024  a_thread_buf);
1025  });
1026  static_for<0, NRepeat, 1>{}([&](auto n0) {
1027  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
1028  make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
1029  b_block_buf,
1030  b_thread_desc_,
1031  make_tuple(n0, I0, k0, I0),
1032  b_thread_buf);
1033  });
1034 
1035  __builtin_amdgcn_sched_barrier(0);
1036  if constexpr(k0.value != 0 || KRepeat == 1)
1037  {
1038  __builtin_amdgcn_s_barrier();
1039  __builtin_amdgcn_sched_barrier(0);
1040  }
1041  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
1042  static_for<0, MRepeat, 1>{}([&](auto m0) {
1043  static_for<0, NRepeat, 1>{}([&](auto n0) {
1044  vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
1045  vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
1046 
1047  static_for<0, KPack, 1>{}([&](auto ik) {
1048  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1049  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1050  make_tuple(m0, I0, k0, k_ + ik))>{}];
1051  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1052  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
1053  make_tuple(n0, I0, k0, k_ + ik))>{}];
1054  });
1055 
1056  using mfma_input_type =
1057  typename vector_type<ComputeDataTypeBuf,
1058  xdlops_gemm.K1PerXdlops>::type;
1059 
1060  constexpr index_t c_offset =
1061  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
1062 
1063  if constexpr(k0.value == KRepeat - 1 &&
1064  k_.value == KPerInnerLoop - KPack &&
1065  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
1066  {
1067  __builtin_amdgcn_sched_barrier(0);
1068  block_sync_lds();
1069  __builtin_amdgcn_sched_barrier(0);
1070  }
1071  xdlops_gemm.Run(
1072  a_thread_vec.template AsType<mfma_input_type>(),
1073  b_thread_vec.template AsType<mfma_input_type>(),
1074  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1075  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
1076  {
1077  __builtin_amdgcn_sched_barrier(0);
1078  __builtin_amdgcn_s_setprio(1);
1079  __builtin_amdgcn_sched_barrier(0);
1080  }
1081  });
1082  });
1083  });
1084  __builtin_amdgcn_sched_barrier(0);
1085  __builtin_amdgcn_s_setprio(0);
1086  __builtin_amdgcn_sched_barrier(0);
1087  });
1088  }
1089  else if constexpr(TailNum == TailNumber::Two)
1090  {
1091  LoopTailFunc(Number<2>{});
1092  }
1093  else if constexpr(TailNum == TailNumber::Three)
1094  {
1095  LoopTailFunc(Number<3>{});
1096  }
1097  else if constexpr(TailNum == TailNumber::Four)
1098  {
1099  LoopTailFunc(Number<4>{});
1100  }
1101  else if constexpr(TailNum == TailNumber::Five)
1102  {
1103  LoopTailFunc(Number<5>{});
1104  }
1105  else if constexpr(TailNum == TailNumber::Six)
1106  {
1107  LoopTailFunc(Number<6>{});
1108  }
1109  else if constexpr(TailNum == TailNumber::Seven)
1110  {
1111  LoopTailFunc(Number<7>{});
1112  }
1113  else if constexpr(TailNum == TailNumber::Full)
1114  {
1115  LoopTailFunc(Number<PrefetchStages>{});
1116  }
1117  }
1118 
1119  protected:
1120  // K->M loopover
1121  static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
1122  make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
1123  make_tuple(Number<KPerInnerLoop>{},
1124  Number<KRepeat * MRepeat * KPerInnerLoop>{},
1125  Number<MRepeat * KPerInnerLoop>{},
1126  I1));
1127 
1128  static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
1129  make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
1130  make_tuple(Number<KPerInnerLoop>{},
1131  Number<KRepeat * NRepeat * KPerInnerLoop>{},
1132  Number<NRepeat * KPerInnerLoop>{},
1133  I1));
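// Example (hypothetical values): with MRepeat = 4, KRepeat = 2 and KPerInnerLoop = 8,
// a_thread_desc_ has lengths (4, 1, 2, 8) and strides (8, 64, 32, 1), so
// CalculateOffset(make_tuple(m0, 0, k0, ik)) evaluates to m0 * 8 + k0 * 32 + ik and
// the per-thread A buffer holds 4 * 1 * 2 * 8 = 64 ComputeDataTypeBuf elements.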
1134 
1135  using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
1136  ComputeDataTypeBuf,
1137  decltype(a_block_desc_m0_m1_m2_k),
1138  decltype(a_thread_desc_),
1139  Sequence<1, 1, 1, KPerInnerLoop>,
1140  Sequence<0, 1, 2, 3>,
1141  3,
1142  A_K1,
1143  A_K1>;
1144 
1145  using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
1146  ComputeDataTypeBuf,
1147  decltype(b_block_desc_n0_n1_n2_k),
1148  decltype(b_thread_desc_),
1149  Sequence<1, 1, 1, KPerInnerLoop>,
1150  Sequence<0, 1, 2, 3>,
1151  3,
1152  B_K1,
1153  B_K1>;
1154 
1155  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
1156  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
1157  using Base::c_thread_desc_;
1158 };
1159 
1160 } // namespace ck
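The host-side scheduling decisions above (how many prefetch stages to use, whether a
hot loop exists, and which tail variant to dispatch) reduce to simple integer
arithmetic. The standalone sketch below re-implements that arithmetic outside of CK so
it can be checked on the host; the tile sizes, element sizes, and num_loop value are
hypothetical assumptions, not values taken from this header.

// Standalone sketch of the prefetch-stage / tail-number arithmetic (assumed values).
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
    // Hypothetical tile/block configuration (assumptions, not from the header).
    const std::int32_t BlockSize = 256, WaveSize = 64;
    const std::int32_t MPerBlock = 256, NPerBlock = 128, KPerBlock = 32;
    const std::int32_t ABytes = 2, BBytes = 2; // e.g. fp16 A and B

    // Same formulas as the static constexpr members in the pipeline structs.
    const std::int32_t WgpPerCU = std::max(4 * WaveSize / BlockSize, 1);
    const std::int32_t BytesPerStage =
        (MPerBlock * ABytes + NPerBlock * BBytes) * KPerBlock;
    const std::int32_t FullMemBandPrefetchStages =
        (32768 / WgpPerCU + BytesPerStage - 1) / BytesPerStage; // integer ceil
    const std::int32_t PrefetchStages =
        std::min(std::max(FullMemBandPrefetchStages, 2), 8);

    // Loop bookkeeping mirroring BlockHasHotloop / BlockLoopTailNum.
    const std::int32_t num_loop    = 10; // K / KPerBlock, hypothetical
    const bool         has_hotloop = num_loop > PrefetchStages;
    const std::int32_t tail        = num_loop % PrefetchStages; // 0 maps to TailNumber::Full

    std::printf("PrefetchStages=%d hotloop=%d tail=%d\n",
                PrefetchStages, static_cast<int>(has_hotloop), tail);
    return 0;
}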