Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp Source File

blockwise_gemm_pipeline_xdlops_v2.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
 6 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
 7 
8 namespace ck {
9 
10 // Maximum global memory throughput pipeline with >=32KB of data in flight
11 // GlobalPrefetchStages: >=2
12 // LocalPreFillStages: 1
13 // LocalPreFetchStages: 0
14 // LocalSharedMemoryBuffer: 1
15 
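// In the Run() bodies below: the first global prefetch is written straight into the single
// shared LDS buffer (LocalPreFillStages: 1), the remaining PrefetchStages - 1 reads are kept
// in the per-stage register buffers of the blockwise copies (GlobalBufferNum == PrefetchStages),
// and every reuse of the one LDS buffer is guarded by block_sync_lds() (the interwave
// specialization issues the pre-write barrier from inside the last MAC cluster instead).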
16 template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17  index_t BlockSize,
18  typename ADataType,
19  typename BDataType,
20  typename ComputeDataType,
21  typename AccDataType,
22  typename ATileDesc,
23  typename BTileDesc,
24  typename AMmaTileDesc,
25  typename BMmaTileDesc,
26  index_t ABlockTransferSrcScalarPerVector,
27  index_t BBlockTransferSrcScalarPerVector,
28  index_t MPerBlock,
29  index_t NPerBlock,
30  index_t KPerBlock,
31  index_t MPerXDL,
32  index_t NPerXDL,
33  index_t MRepeat,
34  index_t NRepeat,
35  index_t KPacks>
36 struct BlockwiseGemmXdlops_pipeline_v2
37 {
38 };
39 
40 template <index_t BlockSize,
41  typename ADataType,
42  typename BDataType,
43  typename ComputeDataType,
44  typename AccDataType,
45  typename ATileDesc,
46  typename BTileDesc,
47  typename AMmaTileDesc,
48  typename BMmaTileDesc,
49  index_t ABlockTransferSrcScalarPerVector,
50  index_t BBlockTransferSrcScalarPerVector,
51  index_t MPerBlock,
52  index_t NPerBlock,
53  index_t KPerBlock,
54  index_t MPerXDL,
55  index_t NPerXDL,
56  index_t MRepeat,
57  index_t NRepeat,
58  index_t KPack
59  // ,bool TransposeC // TransposeC is disabled for now
60  >
61 struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
62  BlockSize,
63  ADataType,
64  BDataType,
65  ComputeDataType,
66  AccDataType,
67  ATileDesc,
68  BTileDesc,
69  AMmaTileDesc,
70  BMmaTileDesc,
71  ABlockTransferSrcScalarPerVector,
72  BBlockTransferSrcScalarPerVector,
73  MPerBlock,
74  NPerBlock,
75  KPerBlock,
76  MPerXDL,
77  NPerXDL,
78  MRepeat,
79  NRepeat,
80  KPack>
81  : BlockwiseGemmXdlops_pipeline_base<BlockSize,
82  ADataType,
83  BDataType,
84  ComputeDataType,
85  AccDataType,
86  ATileDesc,
87  BTileDesc,
88  AMmaTileDesc,
89  BMmaTileDesc,
90  ABlockTransferSrcScalarPerVector,
91  BBlockTransferSrcScalarPerVector,
92  MPerBlock,
93  NPerBlock,
94  KPerBlock,
95  MPerXDL,
96  NPerXDL,
97  MRepeat,
98  NRepeat,
99  KPack>
100 
101 {
102  using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
103  ADataType,
104  BDataType,
105  ComputeDataType,
106  AccDataType,
107  ATileDesc,
108  BTileDesc,
109  AMmaTileDesc,
110  BMmaTileDesc,
111  ABlockTransferSrcScalarPerVector,
112  BBlockTransferSrcScalarPerVector,
113  MPerBlock,
114  NPerBlock,
115  KPerBlock,
116  MPerXDL,
117  NPerXDL,
118  MRepeat,
119  NRepeat,
120  KPack>;
121  using Base::I0;
122  using Base::KRepeat;
123  using Base::xdlops_gemm;
124 
125  using Base::CalculateCThreadOriginDataIndex;
126  using Base::CalculateCThreadOriginDataIndex8D;
127  using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
128  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
129  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
130  using Base::GetCThreadBuffer;
131  using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
132  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
133  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
134  using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
135  using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
136 
137  using Base::a_block_desc_m0_m1_m2_k;
138  using Base::b_block_desc_n0_n1_n2_k;
139 
140  using Base::AMmaKStride;
141  using Base::BMmaKStride;
142  using Base::WaveSize;
143 
144  static constexpr index_t WgpPerCU =
145  (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
146  static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
147  32768 / WgpPerCU,
148  (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
149  static constexpr index_t PrefetchStages =
150  FullMemBandPrefetchStages >= 2
151  ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
152  : 2;
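// Worked example (illustrative numbers, not taken from this file): with BlockSize = 256,
// WaveSize = 64, MPerBlock = NPerBlock = 128, KPerBlock = 32 and 2-byte A/B data types:
//   WgpPerCU                  = max(4 * 64 / 256, 1)                   = 1
//   bytes loaded per stage    = (128 * 2 + 128 * 2) * 32               = 16384
//   FullMemBandPrefetchStages = ceil(32768 / 1 / 16384)                = 2
//   PrefetchStages            = clamp(FullMemBandPrefetchStages, 2, 8) = 2
// i.e. two k-loop iterations are kept in flight to reach the targeted >=32KB per CU.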
153 
154  static constexpr index_t PrefillStages = 1;
155  static constexpr index_t GlobalBufferNum = PrefetchStages;
156 
157  __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
158  {
159  return num_loop > PrefetchStages;
160  }
161 
162  __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
163  {
164  if(num_loop % PrefetchStages == 1)
165  {
166  return TailNumber::One;
167  }
168  else if(num_loop % PrefetchStages == 2)
169  {
170  return TailNumber::Two;
171  }
172  else if(num_loop % PrefetchStages == 3)
173  {
174  return TailNumber::Three;
175  }
176  else if(num_loop % PrefetchStages == 4)
177  {
178  return TailNumber::Four;
179  }
180  else if(num_loop % PrefetchStages == 5)
181  {
182  return TailNumber::Five;
183  }
184  else if(num_loop % PrefetchStages == 6)
185  {
186  return TailNumber::Six;
187  }
188  else if(num_loop % PrefetchStages == 7)
189  {
190  return TailNumber::Seven;
191  }
192  else
193  {
194  return TailNumber::Full;
195  }
196  }
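// Dispatch example (illustrative): for num_loop = 7 with PrefetchStages = 2,
// BlockHasHotloop(7) is true (7 > 2) and BlockLoopTailNum(7) == TailNumber::One
// (7 % 2 == 1), so a caller would instantiate Run<true, TailNumber::One>(...);
// the main body then iterates while i < num_loop - PrefetchStages and the
// TailNumber::One branch drains the last prefetched stage.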
197 
198  template <bool HasMainLoop,
199  TailNumber TailNum,
200  typename AGridDesc,
201  typename ABlockDesc,
202  typename ABlockTransfer,
203  typename AGridBuffer,
204  typename ABlockBuffer,
205  typename ABlockTransferStep,
206  typename BGridDesc,
207  typename BBlockDesc,
208  typename BBlockTransfer,
209  typename BGridBuffer,
210  typename BBlockBuffer,
211  typename BBlockTransferStep,
212  typename CThreadBuffer>
213  __device__ void Run(const AGridDesc& a_grid_desc,
214  const ABlockDesc& a_block_desc,
215  ABlockTransfer& a_blockwise_copy,
216  const AGridBuffer& a_grid_buf,
217  ABlockBuffer& a_block_buf,
218  const ABlockTransferStep& a_block_copy_step,
219  const BGridDesc& b_grid_desc,
220  const BBlockDesc& b_block_desc,
221  BBlockTransfer& b_blockwise_copy,
222  const BGridBuffer& b_grid_buf,
223  BBlockBuffer& b_block_buf,
224  const BBlockTransferStep& b_block_copy_step,
225  CThreadBuffer& c_thread_buf,
226  index_t num_loop) const
227  {
228  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
229  a_thread_desc_.GetElementSpaceSize());
230  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
231  b_thread_desc_.GetElementSpaceSize());
232 
233  // Global prefetch 1
234  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
235  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
236 
237  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
238  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
239 
240  // Initialize C
241  c_thread_buf.Clear();
242 
243  // Local prefill 1
244  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
245  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
246 
247  // Global prefetch [2, PrefetchStages]
248  static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) {
249  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
250  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
251 
252  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
253  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
254  });
255 
256  // main body
257  if constexpr(HasMainLoop)
258  {
259  index_t i = 0;
260  do
261  {
262  static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) {
263  // -------------------------------------------------------------------------------------------
264  block_sync_lds();
265  static_for<0, KRepeat, 1>{}([&](auto k) {
266  static_for<0, MRepeat, 1>{}([&](auto m0) {
267  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
268  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
269  a_block_buf,
270  a_thread_desc_,
271  make_tuple(m0, I0, k, I0),
272  a_thread_buf);
273  });
274  static_for<0, NRepeat, 1>{}([&](auto n0) {
275  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
276  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
277  b_block_buf,
278  b_thread_desc_,
279  make_tuple(n0, I0, k, I0),
280  b_thread_buf);
281  });
282  });
283 
284  static_for<0, KRepeat, 1>{}([&](auto k0) {
285  static_for<0, MRepeat, 1>{}([&](auto m0) {
286  static_for<0, NRepeat, 1>{}([&](auto n0) {
287  vector_type<ComputeDataType, KPack> a_thread_vec;
288  vector_type<ComputeDataType, KPack> b_thread_vec;
289 
290  static_for<0, KPack, 1>{}([&](auto ik) {
291  a_thread_vec.template AsType<ComputeDataType>()(ik) =
292  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
293  make_tuple(m0, I0, k0, ik))>{}];
294  b_thread_vec.template AsType<ComputeDataType>()(ik) =
295  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
296  make_tuple(n0, I0, k0, ik))>{}];
297  });
298 
299  using mfma_input_type =
300  typename vector_type<ComputeDataType,
301  xdlops_gemm.K1PerXdlops>::type;
302 
303  constexpr index_t c_offset =
304  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
305 
306  xdlops_gemm.Run(
307  a_thread_vec.template AsType<mfma_input_type>(),
308  b_thread_vec.template AsType<mfma_input_type>(),
309  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
310  });
311  });
312  });
313 
314  block_sync_lds();
315  a_blockwise_copy.RunWrite(
316  a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
317  b_blockwise_copy.RunWrite(
318  b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
319 
320  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
321  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
322 
323  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
324  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
325  });
326 
327  i += PrefetchStages;
328  } while(i < (num_loop - PrefetchStages));
329  }
330 
331  // tail
332 
333  auto LoopTailFunc = [&](auto tail_num) {
334  static_for<1, tail_num, 1>{}([&](auto iprefetch) {
335  block_sync_lds();
336  static_for<0, KRepeat, 1>{}([&](auto k) {
337  static_for<0, MRepeat, 1>{}([&](auto m0) {
338  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
339  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
340  a_block_buf,
341  a_thread_desc_,
342  make_tuple(m0, I0, k, I0),
343  a_thread_buf);
344  });
345  static_for<0, NRepeat, 1>{}([&](auto n0) {
346  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
347  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
348  b_block_buf,
349  b_thread_desc_,
350  make_tuple(n0, I0, k, I0),
351  b_thread_buf);
352  });
353  });
354 
355  static_for<0, KRepeat, 1>{}([&](auto k0) {
356  static_for<0, MRepeat, 1>{}([&](auto m0) {
357  static_for<0, NRepeat, 1>{}([&](auto n0) {
358  vector_type<ComputeDataType, KPack> a_thread_vec;
359  vector_type<ComputeDataType, KPack> b_thread_vec;
360 
361  static_for<0, KPack, 1>{}([&](auto ik) {
362  a_thread_vec.template AsType<ComputeDataType>()(ik) =
363  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
364  make_tuple(m0, I0, k0, ik))>{}];
365  b_thread_vec.template AsType<ComputeDataType>()(ik) =
366  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
367  make_tuple(n0, I0, k0, ik))>{}];
368  });
369 
370  using mfma_input_type =
371  typename vector_type<ComputeDataType,
372  xdlops_gemm.K1PerXdlops>::type;
373 
374  constexpr index_t c_offset =
375  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
376 
377  xdlops_gemm.Run(
378  a_thread_vec.template AsType<mfma_input_type>(),
379  b_thread_vec.template AsType<mfma_input_type>(),
380  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
381  });
382  });
383  });
384 
385  block_sync_lds();
386  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
387  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
388  });
389 
390  block_sync_lds();
391  static_for<0, KRepeat, 1>{}([&](auto k) {
392  static_for<0, MRepeat, 1>{}([&](auto m0) {
393  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
394  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
395  a_block_buf,
396  a_thread_desc_,
397  make_tuple(m0, I0, k, I0),
398  a_thread_buf);
399  });
400  static_for<0, NRepeat, 1>{}([&](auto n0) {
401  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
402  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
403  b_block_buf,
404  b_thread_desc_,
405  make_tuple(n0, I0, k, I0),
406  b_thread_buf);
407  });
408  });
409 
410  static_for<0, KRepeat, 1>{}([&](auto k0) {
411  static_for<0, MRepeat, 1>{}([&](auto m0) {
412  static_for<0, NRepeat, 1>{}([&](auto n0) {
413  vector_type<ComputeDataType, KPack> a_thread_vec;
414  vector_type<ComputeDataType, KPack> b_thread_vec;
415 
416  static_for<0, KPack, 1>{}([&](auto ik) {
417  a_thread_vec.template AsType<ComputeDataType>()(ik) =
418  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
419  make_tuple(m0, I0, k0, ik))>{}];
420  b_thread_vec.template AsType<ComputeDataType>()(ik) =
421  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
422  make_tuple(n0, I0, k0, ik))>{}];
423  });
424 
425  using mfma_input_type =
426  typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
427 
428  constexpr index_t c_offset =
429  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
430 
431  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
432  b_thread_vec.template AsType<mfma_input_type>(),
433  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
434  });
435  });
436  });
437  };
438 
439  if constexpr(TailNum == TailNumber::One)
440  {
441  block_sync_lds();
442  static_for<0, KRepeat, 1>{}([&](auto k) {
443  static_for<0, MRepeat, 1>{}([&](auto m0) {
444  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
445  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
446  a_block_buf,
447  a_thread_desc_,
448  make_tuple(m0, I0, k, I0),
449  a_thread_buf);
450  });
451  static_for<0, NRepeat, 1>{}([&](auto n0) {
452  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
453  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
454  b_block_buf,
455  b_thread_desc_,
456  make_tuple(n0, I0, k, I0),
457  b_thread_buf);
458  });
459  });
460 
461  static_for<0, KRepeat, 1>{}([&](auto k0) {
462  static_for<0, MRepeat, 1>{}([&](auto m0) {
463  static_for<0, NRepeat, 1>{}([&](auto n0) {
464  vector_type<ComputeDataType, KPack> a_thread_vec;
465  vector_type<ComputeDataType, KPack> b_thread_vec;
466 
467  static_for<0, KPack, 1>{}([&](auto ik) {
468  a_thread_vec.template AsType<ComputeDataType>()(ik) =
469  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
470  make_tuple(m0, I0, k0, ik))>{}];
471  b_thread_vec.template AsType<ComputeDataType>()(ik) =
472  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
473  make_tuple(n0, I0, k0, ik))>{}];
474  });
475 
476  using mfma_input_type =
477  typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
478 
479  constexpr index_t c_offset =
480  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
481 
482  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
483  b_thread_vec.template AsType<mfma_input_type>(),
484  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
485  });
486  });
487  });
488  }
489  else if constexpr(TailNum == TailNumber::Two)
490  {
491  LoopTailFunc(Number<2>{});
492  }
493  else if constexpr(TailNum == TailNumber::Three)
494  {
495  LoopTailFunc(Number<3>{});
496  }
497  else if constexpr(TailNum == TailNumber::Four)
498  {
499  LoopTailFunc(Number<4>{});
500  }
501  else if constexpr(TailNum == TailNumber::Five)
502  {
503  LoopTailFunc(Number<5>{});
504  }
505  else if constexpr(TailNum == TailNumber::Six)
506  {
507  LoopTailFunc(Number<6>{});
508  }
509  else if constexpr(TailNum == TailNumber::Seven)
510  {
511  LoopTailFunc(Number<7>{});
512  }
513  else if constexpr(TailNum == TailNumber::Full)
514  {
515  LoopTailFunc(Number<PrefetchStages>{});
516  }
517  }
518 
519  protected:
520  using Base::a_thread_copy_;
521  using Base::a_thread_desc_;
522  using Base::b_thread_copy_;
523  using Base::b_thread_desc_;
524  using Base::c_thread_desc_;
525 };
526 
527 template <index_t BlockSize,
528  typename ADataType,
529  typename BDataType,
530  typename ComputeDataType,
531  typename AccDataType,
532  typename ATileDesc,
533  typename BTileDesc,
534  typename AMmaTileDesc,
535  typename BMmaTileDesc,
536  index_t ABlockTransferSrcScalarPerVector,
537  index_t BBlockTransferSrcScalarPerVector,
538  index_t MPerBlock,
539  index_t NPerBlock,
540  index_t KPerBlock,
541  index_t MPerXDL,
542  index_t NPerXDL,
543  index_t MRepeat,
544  index_t NRepeat,
545  index_t KPack
546  // ,bool TransposeC // TransposeC is disabled for now
547  >
548 struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
549  BlockSize,
550  ADataType,
551  BDataType,
552  ComputeDataType,
553  AccDataType,
554  ATileDesc,
555  BTileDesc,
556  AMmaTileDesc,
557  BMmaTileDesc,
558  ABlockTransferSrcScalarPerVector,
559  BBlockTransferSrcScalarPerVector,
560  MPerBlock,
561  NPerBlock,
562  KPerBlock,
563  MPerXDL,
564  NPerXDL,
565  MRepeat,
566  NRepeat,
567  KPack>
568  : BlockwiseGemmXdlops_pipeline_base<BlockSize,
569  ADataType,
570  BDataType,
571  ComputeDataType,
572  AccDataType,
573  ATileDesc,
574  BTileDesc,
575  AMmaTileDesc,
576  BMmaTileDesc,
577  ABlockTransferSrcScalarPerVector,
578  BBlockTransferSrcScalarPerVector,
579  MPerBlock,
580  NPerBlock,
581  KPerBlock,
582  MPerXDL,
583  NPerXDL,
584  MRepeat,
585  NRepeat,
586  KPack>
587 
588 {
589  using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
590  ADataType,
591  BDataType,
592  ComputeDataType,
593  AccDataType,
594  ATileDesc,
595  BTileDesc,
596  AMmaTileDesc,
597  BMmaTileDesc,
598  ABlockTransferSrcScalarPerVector,
599  BBlockTransferSrcScalarPerVector,
600  MPerBlock,
601  NPerBlock,
602  KPerBlock,
603  MPerXDL,
604  NPerXDL,
605  MRepeat,
606  NRepeat,
607  KPack>;
608  using Base::A_K1;
609  using Base::B_K1;
610  using Base::I0;
611  using Base::I1;
612  using Base::KPerThread;
613  using Base::xdlops_gemm;
614 
615  using Base::CalculateCThreadOriginDataIndex;
616  using Base::CalculateCThreadOriginDataIndex8D;
617  using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
618  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
619  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
620  using Base::GetCThreadBuffer;
621  using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
622  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
623  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
624  using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
625  using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
626 
627  using Base::a_block_desc_m0_m1_m2_k;
628  using Base::b_block_desc_n0_n1_n2_k;
629  using Base::WaveSize;
630 
631  static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
632  static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
633  static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
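// Example (illustrative, assuming CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS == 2):
// with KPerThread = 16 and KPack = 4, KPerInnerLoop = max(16 / 2, 4) = 8 and
// KRepeat = 16 / 8 = 2, so each wave issues its MFMAs in two MAC clusters of
// 8 k-elements, synchronized by the s_barrier at the start of every cluster
// after the first.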
634 
635  static constexpr index_t WgpPerCU =
636  (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
637  static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
638  32768 / WgpPerCU,
639  (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
640  static constexpr index_t PrefetchStages =
641  FullMemBandPrefetchStages >= 2
642  ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
643  : 2;
644 
645  static constexpr index_t PrefillStages = 1;
646  static constexpr index_t GlobalBufferNum = PrefetchStages;
647 
648  __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
649  {
650  return num_loop > PrefetchStages;
651  }
652 
653  __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
654  {
655  if(num_loop % PrefetchStages == 1)
656  {
657  return TailNumber::One;
658  }
659  else if(num_loop % PrefetchStages == 2)
660  {
661  return TailNumber::Two;
662  }
663  else if(num_loop % PrefetchStages == 3)
664  {
665  return TailNumber::Three;
666  }
667  else if(num_loop % PrefetchStages == 4)
668  {
669  return TailNumber::Four;
670  }
671  else if(num_loop % PrefetchStages == 5)
672  {
673  return TailNumber::Five;
674  }
675  else if(num_loop % PrefetchStages == 6)
676  {
677  return TailNumber::Six;
678  }
679  else if(num_loop % PrefetchStages == 7)
680  {
681  return TailNumber::Seven;
682  }
683  else
684  {
685  return TailNumber::Full;
686  }
687  }
688 
689  template <bool HasMainLoop,
690  TailNumber TailNum,
691  typename AGridDesc,
692  typename ABlockDesc,
693  typename ABlockTransfer,
694  typename AGridBuffer,
695  typename ABlockBuffer,
696  typename ABlockTransferStep,
697  typename BGridDesc,
698  typename BBlockDesc,
699  typename BBlockTransfer,
700  typename BGridBuffer,
701  typename BBlockBuffer,
702  typename BBlockTransferStep,
703  typename CThreadBuffer>
704  __device__ void Run(const AGridDesc& a_grid_desc,
705  const ABlockDesc& a_block_desc,
706  ABlockTransfer& a_blockwise_copy,
707  const AGridBuffer& a_grid_buf,
708  ABlockBuffer& a_block_buf,
709  const ABlockTransferStep& a_block_copy_step,
710  const BGridDesc& b_grid_desc,
711  const BBlockDesc& b_block_desc,
712  BBlockTransfer& b_blockwise_copy,
713  const BGridBuffer& b_grid_buf,
714  BBlockBuffer& b_block_buf,
715  const BBlockTransferStep& b_block_copy_step,
716  CThreadBuffer& c_thread_buf,
717  index_t num_loop) const
718  {
719  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
720  a_thread_desc_.GetElementSpaceSize());
721  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
722  b_thread_desc_.GetElementSpaceSize());
723 
724  // Global prefetch 1
725  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
726  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
727 
728  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
729  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
730 
731  // Initialize C
732  c_thread_buf.Clear();
733 
734  // Local prefill 1
735  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
736  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
737 
738  // Global prefetch [2, PrefetchStages]
739  static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) {
740  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
741  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
742 
743  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
744  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
745  });
746 
747  // main body
748  if constexpr(HasMainLoop)
749  {
750  index_t i = 0;
751  do
752  {
753  static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) {
754  // -------------------------------------------------------------------------------------------
755  block_sync_lds();
756  static_for<0, KRepeat, 1>{}([&](auto k0) {
757  static_for<0, MRepeat, 1>{}([&](auto m0) {
758  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
759  make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
760  a_block_buf,
761  a_thread_desc_,
762  make_tuple(m0, I0, k0, I0),
763  a_thread_buf);
764  });
765  static_for<0, NRepeat, 1>{}([&](auto n0) {
766  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
767  make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
768  b_block_buf,
769  b_thread_desc_,
770  make_tuple(n0, I0, k0, I0),
771  b_thread_buf);
772  });
773  __builtin_amdgcn_sched_barrier(0);
774  // NOTE: Synchronize threads in a workgroup at the start of each MAC
775  // cluster, except the first, since we can shorten the non-MAC cluster a
776  // bit with no observable negative impact. The desired effect is that the
777  // waves in a workgroup execute their MACs in sync. This keeps out-of-sync
778  // waves from hijacking MAC resources from other workgroups and from losing
779  // the latency hiding gained by waiting for the rest of the workgroup at the
780  // eventual sync point.
781  if constexpr(k0.value != 0 || KRepeat == 1)
782  {
783  __builtin_amdgcn_s_barrier();
784  __builtin_amdgcn_sched_barrier(0);
785  }
786  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
787  static_for<0, MRepeat, 1>{}([&](auto m0) {
788  static_for<0, NRepeat, 1>{}([&](auto n0) {
789  vector_type<ComputeDataType, KPack> a_thread_vec;
790  vector_type<ComputeDataType, KPack> b_thread_vec;
791 
792  static_for<0, KPack, 1>{}([&](auto ik) {
793  a_thread_vec.template AsType<ComputeDataType>()(ik) =
794  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
795  make_tuple(m0, I0, k0, k_ + ik))>{}];
796  b_thread_vec.template AsType<ComputeDataType>()(ik) =
797  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
798  make_tuple(n0, I0, k0, k_ + ik))>{}];
799  });
800 
801  using mfma_input_type =
802  typename vector_type<ComputeDataType,
803  xdlops_gemm.K1PerXdlops>::type;
804 
805  constexpr index_t c_offset =
806  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
807 
808  // The block_sync_lds() here performs double duty:
809  // A) safeguard against data hazards, because the barrier from
810  //    blockwise_gemm is moved here;
811  // B) reduce VMEM FIFO congestion by applying small delays to
812  //    different wavefronts.
813  // It is performed near the end of the MAC cluster to minimize the lgkmcnt penalty.
814  if constexpr(k0.value == KRepeat - 1 &&
815  k_.value == KPerInnerLoop - KPack &&
816  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
817  {
818  __builtin_amdgcn_sched_barrier(0);
819  block_sync_lds();
820  __builtin_amdgcn_sched_barrier(0);
821  }
822  xdlops_gemm.Run(
823  a_thread_vec.template AsType<mfma_input_type>(),
824  b_thread_vec.template AsType<mfma_input_type>(),
825  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
826  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
827  {
828  __builtin_amdgcn_sched_barrier(0);
829  __builtin_amdgcn_s_setprio(1);
830  __builtin_amdgcn_sched_barrier(0);
831  }
832  });
833  });
834  });
835  __builtin_amdgcn_sched_barrier(0);
836  __builtin_amdgcn_s_setprio(0);
837  __builtin_amdgcn_sched_barrier(0);
838  });
839 
840  // block_sync_lds();
841  a_blockwise_copy.RunWrite(
842  a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
843  b_blockwise_copy.RunWrite(
844  b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
845 
846  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
847  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
848 
849  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
850  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
851  });
852  i += PrefetchStages;
853  } while(i < (num_loop - PrefetchStages));
854  }
855 
856  // tail
857 
858  auto LoopTailFunc = [&](auto tail_num) {
859  static_for<1, tail_num, 1>{}([&](auto iprefetch) {
860  block_sync_lds();
861  static_for<0, KRepeat, 1>{}([&](auto k0) {
862  static_for<0, MRepeat, 1>{}([&](auto m0) {
863  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
864  make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
865  a_block_buf,
866  a_thread_desc_,
867  make_tuple(m0, I0, k0, I0),
868  a_thread_buf);
869  });
870  static_for<0, NRepeat, 1>{}([&](auto n0) {
871  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
872  make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
873  b_block_buf,
874  b_thread_desc_,
875  make_tuple(n0, I0, k0, I0),
876  b_thread_buf);
877  });
878 
879  __builtin_amdgcn_sched_barrier(0);
880  if constexpr(k0.value != 0 || KRepeat == 1)
881  {
882  __builtin_amdgcn_s_barrier();
883  __builtin_amdgcn_sched_barrier(0);
884  }
885  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
886  static_for<0, MRepeat, 1>{}([&](auto m0) {
887  static_for<0, NRepeat, 1>{}([&](auto n0) {
888  vector_type<ComputeDataType, KPack> a_thread_vec;
889  vector_type<ComputeDataType, KPack> b_thread_vec;
890 
891  static_for<0, KPack, 1>{}([&](auto ik) {
892  a_thread_vec.template AsType<ComputeDataType>()(ik) =
893  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
894  make_tuple(m0, I0, k0, k_ + ik))>{}];
895  b_thread_vec.template AsType<ComputeDataType>()(ik) =
896  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
897  make_tuple(n0, I0, k0, k_ + ik))>{}];
898  });
899 
900  using mfma_input_type =
901  typename vector_type<ComputeDataType,
902  xdlops_gemm.K1PerXdlops>::type;
903 
904  constexpr index_t c_offset =
905  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
906 
907  if constexpr(k0.value == KRepeat - 1 &&
908  k_.value == KPerInnerLoop - KPack &&
909  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
910  {
911  __builtin_amdgcn_sched_barrier(0);
912  block_sync_lds();
913  __builtin_amdgcn_sched_barrier(0);
914  }
915  xdlops_gemm.Run(
916  a_thread_vec.template AsType<mfma_input_type>(),
917  b_thread_vec.template AsType<mfma_input_type>(),
918  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
919  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
920  {
921  __builtin_amdgcn_sched_barrier(0);
922  __builtin_amdgcn_s_setprio(1);
923  __builtin_amdgcn_sched_barrier(0);
924  }
925  });
926  });
927  });
928  __builtin_amdgcn_sched_barrier(0);
929  __builtin_amdgcn_s_setprio(0);
930  __builtin_amdgcn_sched_barrier(0);
931  });
932 
933  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
934  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
935  });
936  block_sync_lds();
937  static_for<0, KRepeat, 1>{}([&](auto k0) {
938  static_for<0, MRepeat, 1>{}([&](auto m0) {
939  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
940  make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
941  a_block_buf,
942  a_thread_desc_,
943  make_tuple(m0, I0, k0, I0),
944  a_thread_buf);
945  });
946  static_for<0, NRepeat, 1>{}([&](auto n0) {
947  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
948  make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
949  b_block_buf,
950  b_thread_desc_,
951  make_tuple(n0, I0, k0, I0),
952  b_thread_buf);
953  });
954 
955  __builtin_amdgcn_sched_barrier(0);
956  if constexpr(k0.value != 0 || KRepeat == 1)
957  {
958  __builtin_amdgcn_s_barrier();
959  __builtin_amdgcn_sched_barrier(0);
960  }
961  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
962  static_for<0, MRepeat, 1>{}([&](auto m0) {
963  static_for<0, NRepeat, 1>{}([&](auto n0) {
964  vector_type<ComputeDataType, KPack> a_thread_vec;
965  vector_type<ComputeDataType, KPack> b_thread_vec;
966 
967  static_for<0, KPack, 1>{}([&](auto ik) {
968  a_thread_vec.template AsType<ComputeDataType>()(ik) =
969  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
970  make_tuple(m0, I0, k0, k_ + ik))>{}];
971  b_thread_vec.template AsType<ComputeDataType>()(ik) =
972  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
973  make_tuple(n0, I0, k0, k_ + ik))>{}];
974  });
975 
976  using mfma_input_type =
977  typename vector_type<ComputeDataType,
978  xdlops_gemm.K1PerXdlops>::type;
979 
980  constexpr index_t c_offset =
981  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
982 
983  if constexpr(k0.value == KRepeat - 1 &&
984  k_.value == KPerInnerLoop - KPack &&
985  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
986  {
987  __builtin_amdgcn_sched_barrier(0);
988  block_sync_lds();
989  __builtin_amdgcn_sched_barrier(0);
990  }
991  xdlops_gemm.Run(
992  a_thread_vec.template AsType<mfma_input_type>(),
993  b_thread_vec.template AsType<mfma_input_type>(),
994  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
995  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
996  {
997  __builtin_amdgcn_sched_barrier(0);
998  __builtin_amdgcn_s_setprio(1);
999  __builtin_amdgcn_sched_barrier(0);
1000  }
1001  });
1002  });
1003  });
1004  __builtin_amdgcn_sched_barrier(0);
1005  __builtin_amdgcn_s_setprio(0);
1006  __builtin_amdgcn_sched_barrier(0);
1007  });
1008  };
1009 
1010  if constexpr(TailNum == TailNumber::One)
1011  {
1012  block_sync_lds();
1013  static_for<0, KRepeat, 1>{}([&](auto k0) {
1014  static_for<0, MRepeat, 1>{}([&](auto m0) {
1015  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
1016  make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
1017  a_block_buf,
1018  a_thread_desc_,
1019  make_tuple(m0, I0, k0, I0),
1020  a_thread_buf);
1021  });
1022  static_for<0, NRepeat, 1>{}([&](auto n0) {
1023  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
1024  make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
1025  b_block_buf,
1026  b_thread_desc_,
1027  make_tuple(n0, I0, k0, I0),
1028  b_thread_buf);
1029  });
1030 
1031  __builtin_amdgcn_sched_barrier(0);
1032  if constexpr(k0.value != 0 || KRepeat == 1)
1033  {
1034  __builtin_amdgcn_s_barrier();
1035  __builtin_amdgcn_sched_barrier(0);
1036  }
1037  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
1038  static_for<0, MRepeat, 1>{}([&](auto m0) {
1039  static_for<0, NRepeat, 1>{}([&](auto n0) {
1040  vector_type<ComputeDataType, KPack> a_thread_vec;
1041  vector_type<ComputeDataType, KPack> b_thread_vec;
1042 
1043  static_for<0, KPack, 1>{}([&](auto ik) {
1044  a_thread_vec.template AsType<ComputeDataType>()(ik) =
1045  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1046  make_tuple(m0, I0, k0, k_ + ik))>{}];
1047  b_thread_vec.template AsType<ComputeDataType>()(ik) =
1048  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
1049  make_tuple(n0, I0, k0, k_ + ik))>{}];
1050  });
1051 
1052  using mfma_input_type =
1053  typename vector_type<ComputeDataType,
1054  xdlops_gemm.K1PerXdlops>::type;
1055 
1056  constexpr index_t c_offset =
1057  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
1058 
1059  if constexpr(k0.value == KRepeat - 1 &&
1060  k_.value == KPerInnerLoop - KPack &&
1061  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
1062  {
1063  __builtin_amdgcn_sched_barrier(0);
1064  block_sync_lds();
1065  __builtin_amdgcn_sched_barrier(0);
1066  }
1067  xdlops_gemm.Run(
1068  a_thread_vec.template AsType<mfma_input_type>(),
1069  b_thread_vec.template AsType<mfma_input_type>(),
1070  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1071  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
1072  {
1073  __builtin_amdgcn_sched_barrier(0);
1074  __builtin_amdgcn_s_setprio(1);
1075  __builtin_amdgcn_sched_barrier(0);
1076  }
1077  });
1078  });
1079  });
1080  __builtin_amdgcn_sched_barrier(0);
1081  __builtin_amdgcn_s_setprio(0);
1082  __builtin_amdgcn_sched_barrier(0);
1083  });
1084  }
1085  else if constexpr(TailNum == TailNumber::Two)
1086  {
1087  LoopTailFunc(Number<2>{});
1088  }
1089  else if constexpr(TailNum == TailNumber::Three)
1090  {
1091  LoopTailFunc(Number<3>{});
1092  }
1093  else if constexpr(TailNum == TailNumber::Four)
1094  {
1095  LoopTailFunc(Number<4>{});
1096  }
1097  else if constexpr(TailNum == TailNumber::Five)
1098  {
1099  LoopTailFunc(Number<5>{});
1100  }
1101  else if constexpr(TailNum == TailNumber::Six)
1102  {
1103  LoopTailFunc(Number<6>{});
1104  }
1105  else if constexpr(TailNum == TailNumber::Seven)
1106  {
1107  LoopTailFunc(Number<7>{});
1108  }
1109  else if constexpr(TailNum == TailNumber::Full)
1110  {
1111  LoopTailFunc(Number<PrefetchStages>{});
1112  }
1113  }
1114 
1115  protected:
1116  // K->M loopover
1117  static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
1118  make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
1119  make_tuple(Number<KPerInnerLoop>{},
1120  Number<KRepeat * MRepeat * KPerInnerLoop>{},
1121  Number<MRepeat * KPerInnerLoop>{},
1122  I1));
1123 
1124  static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
1125  make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
1126  make_tuple(Number<KPerInnerLoop>{},
1127  Number<KRepeat * NRepeat * KPerInnerLoop>{},
1128  Number<NRepeat * KPerInnerLoop>{},
1129  I1));
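// Layout note: in both thread descriptors the innermost k dimension has stride 1, so the
// KPerInnerLoop elements of one cluster are contiguous, while the repeat dimension strides
// by KPerInnerLoop; this matches the make_tuple(m0/n0, I0, k0, k_ + ik) indexing used by
// the copies and the MFMA loops in Run() above.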
1130 
1131  using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
1132  ComputeDataType,
1133  decltype(a_block_desc_m0_m1_m2_k),
1134  decltype(a_thread_desc_),
1135  Sequence<1, 1, 1, KPerInnerLoop>,
1136  Sequence<0, 1, 2, 3>,
1137  3,
1138  A_K1,
1139  A_K1>;
1140 
1141  using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
1142  ComputeDataType,
1143  decltype(b_block_desc_n0_n1_n2_k),
1144  decltype(b_thread_desc_),
1145  Sequence<1, 1, 1, KPerInnerLoop>,
1146  Sequence<0, 1, 2, 3>,
1147  3,
1148  B_K1,
1149  B_K1>;
1150 
1151  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
1152  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
1153  using Base::c_thread_desc_;
1154 };
1155 
1156 } // namespace ck