Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp Source File
blockwise_gemm_pipeline_xdlops_v1.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
7 
8 namespace ck {
9 
10 // Naive pipeline with lowest resource request per WGP
11 // GlobalPrefetchStages: 1
12 // LocalPreFillStages: 1
13 // LocalPreFetchStages: 0
14 // LocalSharedMemoryBuffer: 1
15 
16 template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17  index_t BlockSize,
18  typename ADataType,
19  typename BDataType,
20  typename ComputeDataType,
21  typename AccDataType,
22  typename ATileDesc,
23  typename BTileDesc,
24  typename AMmaTileDesc,
25  typename BMmaTileDesc,
26  index_t ABlockTransferSrcScalarPerVector,
27  index_t BBlockTransferSrcScalarPerVector,
28  index_t MPerBlock,
29  index_t NPerBlock,
30  index_t KPerBlock,
31  index_t MPerXDL,
32  index_t NPerXDL,
33  index_t MRepeat,
34  index_t NRepeat,
35  index_t KPacks>
36 struct BlockwiseGemmXdlops_pipeline_v1
37 {
38 };
39 
40 template <index_t BlockSize,
41  typename ADataType,
42  typename BDataType,
43  typename ComputeDataType,
44  typename AccDataType,
45  typename ATileDesc,
46  typename BTileDesc,
47  typename AMmaTileDesc,
48  typename BMmaTileDesc,
49  index_t ABlockTransferSrcScalarPerVector,
50  index_t BBlockTransferSrcScalarPerVector,
51  index_t MPerBlock,
52  index_t NPerBlock,
53  index_t KPerBlock,
54  index_t MPerXDL,
55  index_t NPerXDL,
56  index_t MRepeat,
57  index_t NRepeat,
58  index_t KPack
59  // ,bool TransposeC //disable transposec right now...
60  >
61 struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
62  BlockSize,
63  ADataType,
64  BDataType,
65  ComputeDataType,
66  AccDataType,
67  ATileDesc,
68  BTileDesc,
69  AMmaTileDesc,
70  BMmaTileDesc,
71  ABlockTransferSrcScalarPerVector,
72  BBlockTransferSrcScalarPerVector,
73  MPerBlock,
74  NPerBlock,
75  KPerBlock,
76  MPerXDL,
77  NPerXDL,
78  MRepeat,
79  NRepeat,
80  KPack>
81  : BlockwiseGemmXdlops_pipeline_base<BlockSize,
82  ADataType,
83  BDataType,
84  ComputeDataType,
85  AccDataType,
86  ATileDesc,
87  BTileDesc,
88  AMmaTileDesc,
89  BMmaTileDesc,
90  ABlockTransferSrcScalarPerVector,
91  BBlockTransferSrcScalarPerVector,
92  MPerBlock,
93  NPerBlock,
94  KPerBlock,
95  MPerXDL,
96  NPerXDL,
97  MRepeat,
98  NRepeat,
99  KPack>
100 
101 {
102  using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
103  ADataType,
104  BDataType,
105  ComputeDataType,
106  AccDataType,
107  ATileDesc,
108  BTileDesc,
109  AMmaTileDesc,
110  BMmaTileDesc,
111  ABlockTransferSrcScalarPerVector,
112  BBlockTransferSrcScalarPerVector,
113  MPerBlock,
114  NPerBlock,
115  KPerBlock,
116  MPerXDL,
117  NPerXDL,
118  MRepeat,
119  NRepeat,
120  KPack>;
121  using Base::I0;
122  using Base::KRepeat;
123  using Base::xdlops_gemm;
124 
125  using Base::CalculateCThreadOriginDataIndex;
126  using Base::CalculateCThreadOriginDataIndex8D;
127  using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
128  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
129  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
130  using Base::GetCThreadBuffer;
131  using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
132  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
133  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
134  using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
135  using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
136 
137  using Base::a_block_desc_m0_m1_m2_k;
138  using Base::b_block_desc_n0_n1_n2_k;
139 
140  using Base::AMmaKStride;
141  using Base::BMmaKStride;
142 
143  using typename Base::ComputeDataTypeBuf;
144 
145  static constexpr index_t PrefetchStages = 1;
146  static constexpr index_t PrefillStages = 1;
147  static constexpr index_t GlobalBufferNum = 1;
148 
149  __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
150  {
151  return num_loop > PrefetchStages;
152  }
153 
154  __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
155  {
156  ignore = num_loop;
157  return TailNumber::Full;
158  }
159 
160  template <bool HasMainLoop,
161  TailNumber TailNum,
162  typename AGridDesc,
163  typename ABlockDesc,
164  typename ABlockTransfer,
165  typename AGridBuffer,
166  typename ABlockBuffer,
167  typename ABlockTransferStep,
168  typename BGridDesc,
169  typename BBlockDesc,
170  typename BBlockTransfer,
171  typename BGridBuffer,
172  typename BBlockBuffer,
173  typename BBlockTransferStep,
174  typename CThreadBuffer>
175  __device__ void Run(const AGridDesc& a_grid_desc,
176  const ABlockDesc& a_block_desc,
177  ABlockTransfer& a_blockwise_copy,
178  const AGridBuffer& a_grid_buf,
179  ABlockBuffer& a_block_buf,
180  const ABlockTransferStep& a_block_copy_step,
181  const BGridDesc& b_grid_desc,
182  const BBlockDesc& b_block_desc,
183  BBlockTransfer& b_blockwise_copy,
184  const BGridBuffer& b_grid_buf,
185  BBlockBuffer& b_block_buf,
186  const BBlockTransferStep& b_block_copy_step,
187  CThreadBuffer& c_thread_buf,
188  index_t num_loop) const
189  {
190  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
191  a_thread_desc_.GetElementSpaceSize());
192  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
193  b_thread_desc_.GetElementSpaceSize());
194 
195  // Global prefetch 1
196  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
197  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
198 
199  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
200  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
201 
202  // Local prefill 1
203  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
204  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
205 
206  // Initialize C
207  c_thread_buf.Clear();
208 
209  // main body
210  if constexpr(HasMainLoop)
211  {
212  index_t i = 0;
213  do
214  {
215  // -------------------------------------------------------------------------------------------
216  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
217  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
218 
219  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
220  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
221 
222  block_sync_lds();
223  static_for<0, KRepeat, 1>{}([&](auto k) {
224  static_for<0, MRepeat, 1>{}([&](auto m0) {
225  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
226  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
227  a_block_buf,
228  a_thread_desc_,
229  make_tuple(m0, I0, k, I0),
230  a_thread_buf);
231  static_for<0, NRepeat, 1>{}([&](auto n0) {
232  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
233  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
234  b_block_buf,
235  b_thread_desc_,
236  make_tuple(n0, I0, k, I0),
237  b_thread_buf);
238  });
239  });
240  });
241 
242  static_for<0, KRepeat, 1>{}([&](auto k0) {
243  static_for<0, MRepeat, 1>{}([&](auto m0) {
244  static_for<0, NRepeat, 1>{}([&](auto n0) {
245  vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
246  vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
247 
248  static_for<0, KPack, 1>{}([&](auto ik) {
249  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
250  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
251  make_tuple(m0, I0, k0, ik))>{}];
252  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
253  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
254  make_tuple(n0, I0, k0, ik))>{}];
255  });
256 
257  using mfma_input_type =
258  typename vector_type<ComputeDataTypeBuf,
259  xdlops_gemm.K1PerXdlops>::type;
260 
261  constexpr index_t c_offset =
262  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
263 
264  xdlops_gemm.Run(
265  a_thread_vec.template AsType<mfma_input_type>(),
266  b_thread_vec.template AsType<mfma_input_type>(),
267  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
268  });
269  });
270  });
271 
272  block_sync_lds();
273  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
274  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
275 
276  i += 1;
277  } while(i < (num_loop - 1));
278  }
279 
280  // tail
281  if constexpr(TailNum == TailNumber::Full)
282  {
283  block_sync_lds();
284  static_for<0, KRepeat, 1>{}([&](auto k) {
285  static_for<0, MRepeat, 1>{}([&](auto m0) {
286  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
287  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
288  a_block_buf,
289  a_thread_desc_,
290  make_tuple(m0, I0, k, I0),
291  a_thread_buf);
292  static_for<0, NRepeat, 1>{}([&](auto n0) {
293  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
294  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
295  b_block_buf,
296  b_thread_desc_,
297  make_tuple(n0, I0, k, I0),
298  b_thread_buf);
299  });
300  });
301  });
302 
303  static_for<0, KRepeat, 1>{}([&](auto k0) {
304  static_for<0, MRepeat, 1>{}([&](auto m0) {
305  static_for<0, NRepeat, 1>{}([&](auto n0) {
306  vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
307  vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
308 
309  static_for<0, KPack, 1>{}([&](auto ik) {
310  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
311  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
312  make_tuple(m0, I0, k0, ik))>{}];
313  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
314  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
315  make_tuple(n0, I0, k0, ik))>{}];
316  });
317 
318  using mfma_input_type =
319  typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
320 
321  constexpr index_t c_offset =
322  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
323 
324  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
325  b_thread_vec.template AsType<mfma_input_type>(),
326  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
327  });
328  });
329  });
330  }
331  }
332 
333  protected:
334  using Base::a_thread_copy_;
335  using Base::a_thread_desc_;
336  using Base::b_thread_copy_;
337  using Base::b_thread_desc_;
338  using Base::c_thread_desc_;
339 };
340 
341 template <index_t BlockSize,
342  typename ADataType,
343  typename BDataType,
344  typename ComputeDataType,
345  typename AccDataType,
346  typename ATileDesc,
347  typename BTileDesc,
348  typename AMmaTileDesc,
349  typename BMmaTileDesc,
350  index_t ABlockTransferSrcScalarPerVector,
351  index_t BBlockTransferSrcScalarPerVector,
352  index_t MPerBlock,
353  index_t NPerBlock,
354  index_t KPerBlock,
355  index_t MPerXDL,
356  index_t NPerXDL,
357  index_t MRepeat,
358  index_t NRepeat,
359  index_t KPack
360  // ,bool TransposeC //disable transposec right now...
361  >
362 struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
363  BlockSize,
364  ADataType,
365  BDataType,
366  ComputeDataType,
367  AccDataType,
368  ATileDesc,
369  BTileDesc,
370  AMmaTileDesc,
371  BMmaTileDesc,
372  ABlockTransferSrcScalarPerVector,
373  BBlockTransferSrcScalarPerVector,
374  MPerBlock,
375  NPerBlock,
376  KPerBlock,
377  MPerXDL,
378  NPerXDL,
379  MRepeat,
380  NRepeat,
381  KPack>
382  : BlockwiseGemmXdlops_pipeline_base<BlockSize,
383  ADataType,
384  BDataType,
385  ComputeDataType,
386  AccDataType,
387  ATileDesc,
388  BTileDesc,
389  AMmaTileDesc,
390  BMmaTileDesc,
391  ABlockTransferSrcScalarPerVector,
392  BBlockTransferSrcScalarPerVector,
393  MPerBlock,
394  NPerBlock,
395  KPerBlock,
396  MPerXDL,
397  NPerXDL,
398  MRepeat,
399  NRepeat,
400  KPack>
401 
402 {
403  using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
404  ADataType,
405  BDataType,
406  ComputeDataType,
407  AccDataType,
408  ATileDesc,
409  BTileDesc,
410  AMmaTileDesc,
411  BMmaTileDesc,
412  ABlockTransferSrcScalarPerVector,
413  BBlockTransferSrcScalarPerVector,
414  MPerBlock,
415  NPerBlock,
416  KPerBlock,
417  MPerXDL,
418  NPerXDL,
419  MRepeat,
420  NRepeat,
421  KPack>;
422  using Base::A_K1;
423  using Base::B_K1;
424  using Base::I0;
425  using Base::I1;
426  using Base::KPerThread;
427  using Base::xdlops_gemm;
428 
429  using Base::CalculateCThreadOriginDataIndex;
430  using Base::CalculateCThreadOriginDataIndex8D;
431  using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
432  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
433  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
434  using Base::GetCThreadBuffer;
435  using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
436  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
437  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
438  using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
439  using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
440 
441  using Base::a_block_desc_m0_m1_m2_k;
442  using Base::b_block_desc_n0_n1_n2_k;
443 
444  using typename Base::ComputeDataTypeBuf;
445 
446  static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
447  static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
448  static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
449  static constexpr index_t PrefetchStages = 1;
450  static constexpr index_t PrefillStages = 1;
451  static constexpr index_t GlobalBufferNum = 1;
452  __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
453  {
454  return num_loop > PrefetchStages;
455  }
456 
457  __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
458  {
459  ignore = num_loop;
460  return TailNumber::Full;
461  }
462 
463  template <bool HasMainLoop,
464  TailNumber TailNum,
465  typename AGridDesc,
466  typename ABlockDesc,
467  typename ABlockTransfer,
468  typename AGridBuffer,
469  typename ABlockBuffer,
470  typename ABlockTransferStep,
471  typename BGridDesc,
472  typename BBlockDesc,
473  typename BBlockTransfer,
474  typename BGridBuffer,
475  typename BBlockBuffer,
476  typename BBlockTransferStep,
477  typename CThreadBuffer>
478  __device__ void Run(const AGridDesc& a_grid_desc,
479  const ABlockDesc& a_block_desc,
480  ABlockTransfer& a_blockwise_copy,
481  const AGridBuffer& a_grid_buf,
482  ABlockBuffer& a_block_buf,
483  const ABlockTransferStep& a_block_copy_step,
484  const BGridDesc& b_grid_desc,
485  const BBlockDesc& b_block_desc,
486  BBlockTransfer& b_blockwise_copy,
487  const BGridBuffer& b_grid_buf,
488  BBlockBuffer& b_block_buf,
489  const BBlockTransferStep& b_block_copy_step,
490  CThreadBuffer& c_thread_buf,
491  index_t num_loop) const
492  {
493  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
494  a_thread_desc_.GetElementSpaceSize());
495  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
496  b_thread_desc_.GetElementSpaceSize());
497 
498  // Global prefetch 1
499  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
500  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
501 
502  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
503  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
504 
505  // Local prefill 1
506  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
507  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
508 
509  // Initialize C
510  c_thread_buf.Clear();
511 
512  // main body
513  if constexpr(HasMainLoop)
514  {
515  index_t i = 0;
516  do
517  {
518  // -------------------------------------------------------------------------------------------
519  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
520  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
521 
522  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
523  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
524 
525  block_sync_lds();
526  static_for<0, KRepeat, 1>{}([&](auto k0) {
527  static_for<0, MRepeat, 1>{}([&](auto m0) {
528  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
529  make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
530  a_block_buf,
531  a_thread_desc_,
532  make_tuple(m0, I0, k0, I0),
533  a_thread_buf);
534  static_for<0, NRepeat, 1>{}([&](auto n0) {
535  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
536  make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
537  b_block_buf,
538  b_thread_desc_,
539  make_tuple(n0, I0, k0, I0),
540  b_thread_buf);
541  });
542  });
543  __builtin_amdgcn_sched_barrier(0);
544  // NOTE: Synchronize the threads in a workgroup at the start of each MAC
545  // cluster, except the first one, since the non-MAC cluster can be shortened
546  // a bit with no observable negative impact. The desired effect is that the
547  // waves in a workgroup execute MAC in sync; this keeps out-of-sync waves
548  // from hijacking MAC resources from other workgroups and from giving up
549  // latency hiding while they wait for the rest of the workgroup at the eventual sync point.
550  if constexpr(k0.value != 0 || KRepeat == 1)
551  {
552  __builtin_amdgcn_s_barrier();
553  __builtin_amdgcn_sched_barrier(0);
554  }
555  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
556  static_for<0, MRepeat, 1>{}([&](auto m0) {
557  static_for<0, NRepeat, 1>{}([&](auto n0) {
558  vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
559  vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
560 
561  static_for<0, KPack, 1>{}([&](auto ik) {
562  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
563  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
564  make_tuple(m0, I0, k0, k_ + ik))>{}];
565  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
566  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
567  make_tuple(n0, I0, k0, k_ + ik))>{}];
568  });
569 
570  using mfma_input_type =
571  typename vector_type<ComputeDataTypeBuf,
572  xdlops_gemm.K1PerXdlops>::type;
573 
574  constexpr index_t c_offset =
575  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
576 
577  // The block_sync_lds() here performs double duty:
578  // A) safeguard against a data hazard, because the barrier from blockwise_gemm
579  // has been moved here;
580  // B) reduce VMEM FIFO congestion by applying small delays to different wavefronts.
581  // It is performed near the end of the MAC cluster to minimize the lgkmcnt penalty.
582  if constexpr(k0.value == KRepeat - 1 &&
583  k_.value == KPerInnerLoop - KPack &&
584  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
585  {
586  __builtin_amdgcn_sched_barrier(0);
587  block_sync_lds();
588  __builtin_amdgcn_sched_barrier(0);
589  }
590  xdlops_gemm.Run(
591  a_thread_vec.template AsType<mfma_input_type>(),
592  b_thread_vec.template AsType<mfma_input_type>(),
593  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
594  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
595  {
596  __builtin_amdgcn_sched_barrier(0);
597  __builtin_amdgcn_s_setprio(1);
598  __builtin_amdgcn_sched_barrier(0);
599  }
600  });
601  });
602  });
603  __builtin_amdgcn_sched_barrier(0);
604  __builtin_amdgcn_s_setprio(0);
605  __builtin_amdgcn_sched_barrier(0);
606  });
607 
608  // block_sync_lds();
609  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
610  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
611 
612  i += 1;
613  } while(i < (num_loop - 1));
614  }
615 
616  // tail
617  if constexpr(TailNum == TailNumber::Full)
618  {
619  block_sync_lds();
620  static_for<0, KRepeat, 1>{}([&](auto k0) {
621  static_for<0, MRepeat, 1>{}([&](auto m0) {
622  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
623  make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
624  a_block_buf,
625  a_thread_desc_,
626  make_tuple(m0, I0, k0, I0),
627  a_thread_buf);
628  static_for<0, NRepeat, 1>{}([&](auto n0) {
629  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
630  make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
631  b_block_buf,
632  b_thread_desc_,
633  make_tuple(n0, I0, k0, I0),
634  b_thread_buf);
635  });
636  });
637 
638  __builtin_amdgcn_sched_barrier(0);
639  if constexpr(k0.value != 0 || KRepeat == 1)
640  {
641  __builtin_amdgcn_s_barrier();
642  __builtin_amdgcn_sched_barrier(0);
643  }
644  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
645  static_for<0, MRepeat, 1>{}([&](auto m0) {
646  static_for<0, NRepeat, 1>{}([&](auto n0) {
647  vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
648  vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
649 
650  static_for<0, KPack, 1>{}([&](auto ik) {
651  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
652  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
653  make_tuple(m0, I0, k0, k_ + ik))>{}];
654  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
655  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
656  make_tuple(n0, I0, k0, k_ + ik))>{}];
657  });
658 
659  using mfma_input_type =
660  typename vector_type<ComputeDataTypeBuf,
661  xdlops_gemm.K1PerXdlops>::type;
662 
663  constexpr index_t c_offset =
664  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
665 
666  if constexpr(k0.value == KRepeat - 1 &&
667  k_.value == KPerInnerLoop - KPack &&
668  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
669  {
670  __builtin_amdgcn_sched_barrier(0);
671  block_sync_lds();
672  __builtin_amdgcn_sched_barrier(0);
673  }
674  xdlops_gemm.Run(
675  a_thread_vec.template AsType<mfma_input_type>(),
676  b_thread_vec.template AsType<mfma_input_type>(),
677  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
678  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
679  {
680  __builtin_amdgcn_sched_barrier(0);
681  __builtin_amdgcn_s_setprio(1);
682  __builtin_amdgcn_sched_barrier(0);
683  }
684  });
685  });
686  });
687  __builtin_amdgcn_sched_barrier(0);
688  __builtin_amdgcn_s_setprio(0);
689  __builtin_amdgcn_sched_barrier(0);
690  });
691  }
692  }
693 
694  protected:
695  // K->M loopover
696  static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
697  make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
698  make_tuple(Number<KPerInnerLoop>{},
699  Number<KRepeat * MRepeat * KPerInnerLoop>{},
700  Number<MRepeat * KPerInnerLoop>{},
701  I1));
702 
703  static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
704  make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
705  make_tuple(Number<KPerInnerLoop>{},
706  Number<KRepeat * NRepeat * KPerInnerLoop>{},
707  Number<NRepeat * KPerInnerLoop>{},
708  I1));
709 
710  using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ComputeDataTypeBuf,
711  ComputeDataTypeBuf,
712  decltype(a_block_desc_m0_m1_m2_k),
713  decltype(a_thread_desc_),
714  Sequence<1, 1, 1, KPerInnerLoop>,
715  Sequence<0, 1, 2, 3>,
716  3,
717  A_K1,
718  A_K1>;
719 
720  using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<ComputeDataTypeBuf,
721  ComputeDataTypeBuf,
722  decltype(b_block_desc_n0_n1_n2_k),
723  decltype(b_thread_desc_),
724  Sequence<1, 1, 1, KPerInnerLoop>,
725  Sequence<0, 1, 2, 3>,
726  3,
727  B_K1,
728  B_K1>;
729 
730  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
731  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
732  using Base::c_thread_desc_;
733 };
734 
735 } // namespace ck
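
Both specializations above share one calling pattern: the caller computes the number of K-block iterations, asks BlockHasHotloop() whether the do/while main body is needed, asks BlockLoopTailNum() which tail variant applies (always TailNumber::Full for this v1 pipeline), and passes both answers as the compile-time arguments of Run(). The sketch below illustrates that dispatch; the wrapper name run_pipeline_v1, the num_loop computation, and the runtime branch are illustrative assumptions rather than part of this header (the real gridwise kernels make the HasMainLoop choice at compile time on the host side).

// Illustrative sketch only. Assumes the CK headers above are available and that
// the caller has already constructed the grid/block descriptors, blockwise
// copies, buffers and copy steps expected by Run().
template <typename BlockwiseGemmPipe, // a BlockwiseGemmXdlops_pipeline_v1 specialization
          typename... RunArgs>
__device__ void run_pipeline_v1(BlockwiseGemmPipe& pipe,
                                ck::index_t k,
                                ck::index_t k_per_block,
                                RunArgs&&... run_args)
{
    // Assumption: K is an integer multiple of KPerBlock.
    const ck::index_t num_loop = k / k_per_block;

    if(BlockwiseGemmPipe::BlockHasHotloop(num_loop))
    {
        // num_loop > PrefetchStages (== 1): the do/while body runs num_loop - 1
        // times, then the TailNumber::Full tail finishes the GEMM.
        pipe.template Run<true, ck::TailNumber::Full>(run_args..., num_loop);
    }
    else
    {
        // num_loop <= PrefetchStages: the single global prefetch / local prefill
        // issued before the loop already covers all K blocks; only the tail runs.
        pipe.template Run<false, ck::TailNumber::Full>(run_args..., num_loop);
    }
}

For the Interwave specialization, the per-thread K range is additionally split into KRepeat MAC clusters of KPerInnerLoop elements. With, say, KPerThread = 16, KPack = 4 and NumMacClusters = 2 (illustrative values), KPerInnerLoop = max(16 / 2, 4) = 8 and KRepeat = 16 / 8 = 2, so each k0 iteration issues one cluster of MFMAs, raising the wave priority at the start of the cluster (__builtin_amdgcn_s_setprio(1)) and lowering it again once the cluster is done (__builtin_amdgcn_s_setprio(0)).
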
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition: ck.hpp:209
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: ck.hpp:268
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:299
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:58
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v1.hpp:478
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v1.hpp:175
Definition: blockwise_gemm_pipeline_xdlops_v1.hpp:37
Definition: sequence.hpp:43
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:10