blockwise_gemm_pipeline_xdlops_v1.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"

namespace ck {

// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 1
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
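// In Run() below this means: one tile of A/B is prefetched from global memory into
// registers before the hot loop, it is written once into the single LDS buffer, and
// each hot-loop iteration issues the next global read while the MFMAs consume the
// tile already resident in LDS.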

template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_v1
{
};
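
// The primary template above is intentionally empty; the two partial specializations
// below provide the actual pipelines, selected by the BlockGemmPipelineScheduler
// (intra-wave scheduling first, inter-wave scheduling second).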

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
                                       BlockSize,
                                       ADataType,
                                       BDataType,
                                       ComputeDataType,
                                       AccDataType,
                                       ATileDesc,
                                       BTileDesc,
                                       AMmaTileDesc,
                                       BMmaTileDesc,
                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,
                                       MPerBlock,
                                       NPerBlock,
                                       KPerBlock,
                                       MPerXDL,
                                       NPerXDL,
                                       MRepeat,
                                       NRepeat,
                                       KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeDataType,
                                        AccDataType,
                                        ATileDesc,
                                        BTileDesc,
                                        AMmaTileDesc,
                                        BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerXDL,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>

{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
    using Base::I0;
    using Base::KRepeat;
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

    static constexpr index_t PrefetchStages = 1;
    static constexpr index_t PrefillStages = 1;
    static constexpr index_t GlobalBufferNum = 1;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }
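
    // With a single prefetch stage the hot loop is taken whenever num_loop > 1, and
    // the tail always processes exactly one remaining, already prefetched iteration,
    // hence TailNumber::Full.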

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Initialize C
        c_thread_buf.Clear();

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                // -------------------------------------------------------------------------------------------
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                block_sync_lds();
                static_for<0, KRepeat, 1>{}([&](auto k) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                           a_block_buf,
                                           a_thread_desc_,
                                           make_tuple(m0, I0, k, I0),
                                           a_thread_buf);
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                               b_block_buf,
                                               b_thread_desc_,
                                               make_tuple(n0, I0, k, I0),
                                               b_thread_buf);
                        });
                    });
                });
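                // At this point the A (MRepeat x KRepeat) and B (NRepeat x KRepeat)
                // fragments for this K-block have been staged from LDS into VGPRs, so
                // the MFMA loop below runs without further LDS traffic.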

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0, I0, k0, ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0, I0, k0, ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            xdlops_gemm.Run(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });

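                // Wait until all waves have finished reading LDS before the tile
                // prefetched for the next iteration is written into the same (single)
                // LDS buffer.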
                block_sync_lds();
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                i += 1;
            } while(i < (num_loop - 1));
        }

        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
            block_sync_lds();
            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf,
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_buf);
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                           b_block_buf,
                                           b_thread_desc_,
                                           make_tuple(n0, I0, k, I0),
                                           b_thread_buf);
                    });
                });
            });

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};
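
// The Intrawave specialization above relies on the compiler's default instruction
// scheduling within each iteration. The Interwave specialization below additionally
// splits the per-thread K range into MAC clusters and uses s_barrier / s_setprio
// scheduling hints to keep the waves of a workgroup executing their MFMA bursts in
// lock-step.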

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
                                       BlockSize,
                                       ADataType,
                                       BDataType,
                                       ComputeDataType,
                                       AccDataType,
                                       ATileDesc,
                                       BTileDesc,
                                       AMmaTileDesc,
                                       BMmaTileDesc,
                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,
                                       MPerBlock,
                                       NPerBlock,
                                       KPerBlock,
                                       MPerXDL,
                                       NPerXDL,
                                       MRepeat,
                                       NRepeat,
                                       KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeDataType,
                                        AccDataType,
                                        ATileDesc,
                                        BTileDesc,
                                        AMmaTileDesc,
                                        BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerXDL,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>

{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
    using Base::A_K1;
    using Base::B_K1;
    using Base::I0;
    using Base::I1;
    using Base::KPerThread;
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
    static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
    static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
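    // Worked example with illustrative (not prescribed) numbers: for KPerThread = 16,
    // KPack = 4 and CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS = 2, this gives
    // KPerInnerLoop = max(16 / 2, 4) = 8 and KRepeat = 16 / 8 = 2, i.e. the per-thread
    // K range is processed as two MAC clusters of 8 elements each.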
    static constexpr index_t PrefetchStages = 1;
    static constexpr index_t PrefillStages = 1;
    static constexpr index_t GlobalBufferNum = 1;
    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Initialize C
        c_thread_buf.Clear();

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                // -------------------------------------------------------------------------------------------
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                block_sync_lds();
                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
                                           a_block_buf,
                                           a_thread_desc_,
                                           make_tuple(m0, I0, k0, I0),
                                           a_thread_buf);
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                               make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
                                               b_block_buf,
                                               b_thread_desc_,
                                               make_tuple(n0, I0, k0, I0),
                                               b_thread_buf);
                        });
                    });
                    __builtin_amdgcn_sched_barrier(0);
                    // NOTE: Synchronize the threads of a workgroup at the start of each MAC
                    // cluster except the first; the non-MAC part can be shortened a bit with
                    // no observable negative impact. The desired effect is that the waves of a
                    // workgroup execute their MAC instructions in sync. This keeps out-of-sync
                    // waves from hijacking MAC resources from other workgroups and from reducing
                    // the chance of latency hiding by waiting for the rest of the workgroup at
                    // the eventual sync point.
                    if constexpr(k0.value != 0 || KRepeat == 1)
                    {
                        __builtin_amdgcn_s_barrier();
                        __builtin_amdgcn_sched_barrier(0);
                    }
                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            static_for<0, NRepeat, 1>{}([&](auto n0) {
                                vector_type<ComputeDataType, KPack> a_thread_vec;
                                vector_type<ComputeDataType, KPack> b_thread_vec;

                                static_for<0, KPack, 1>{}([&](auto ik) {
                                    a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                            make_tuple(m0, I0, k0, k_ + ik))>{}];
                                    b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                            make_tuple(n0, I0, k0, k_ + ik))>{}];
                                });

                                using mfma_input_type =
                                    typename vector_type<ComputeDataType,
                                                         xdlops_gemm.K1PerXdlops>::type;

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                // The block_sync_lds() here performs double duty:
                                // A) it safeguards against data hazards, because the barrier from
                                //    blockwise_gemm has been moved here;
                                // B) it reduces VMEM FIFO congestion by applying small delays to
                                //    different wavefronts.
                                // It is placed near the end of the MAC cluster to minimize the
                                // lgkmcnt penalty.
                                if constexpr(k0.value == KRepeat - 1 &&
                                             k_.value == KPerInnerLoop - KPack &&
                                             m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
                                {
                                    __builtin_amdgcn_sched_barrier(0);
                                    block_sync_lds();
                                    __builtin_amdgcn_sched_barrier(0);
                                }
                                xdlops_gemm.Run(
                                    a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec.template AsType<mfma_input_type>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                                if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
                                {
                                    __builtin_amdgcn_sched_barrier(0);
                                    __builtin_amdgcn_s_setprio(1);
                                    __builtin_amdgcn_sched_barrier(0);
                                }
                            });
                        });
                    });
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_setprio(0);
                    __builtin_amdgcn_sched_barrier(0);
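                    // The s_setprio(0) above drops the wave priority that was raised
                    // right after the first MFMA of this MAC cluster, so the priority
                    // boost stays confined to the MFMA burst of the cluster.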
                });

                // block_sync_lds();
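                // NOTE: no extra LDS barrier is issued here; the block_sync_lds() that
                // would normally guard the RunWrite below was moved into the last MFMA
                // iteration of the final MAC cluster above.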
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                i += 1;
            } while(i < (num_loop - 1));
        }

        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
            block_sync_lds();
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
                                       a_block_buf,
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k0, I0),
                                       a_thread_buf);
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
                                           b_block_buf,
                                           b_thread_desc_,
                                           make_tuple(n0, I0, k0, I0),
                                           b_thread_buf);
                    });
                });

                __builtin_amdgcn_sched_barrier(0);
                if constexpr(k0.value != 0 || KRepeat == 1)
                {
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                }
                static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0, I0, k0, k_ + ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0, I0, k0, k_ + ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            if constexpr(k0.value == KRepeat - 1 &&
                                         k_.value == KPerInnerLoop - KPack &&
                                         m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
                            {
                                __builtin_amdgcn_sched_barrier(0);
                                block_sync_lds();
                                __builtin_amdgcn_sched_barrier(0);
                            }
                            xdlops_gemm.Run(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                            if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
                            {
                                __builtin_amdgcn_sched_barrier(0);
                                __builtin_amdgcn_s_setprio(1);
                                __builtin_amdgcn_sched_barrier(0);
                            }
                        });
                    });
                });
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(0);
                __builtin_amdgcn_sched_barrier(0);
            });
        }
    }

    protected:
    // K->M loopover
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
        make_tuple(Number<KPerInnerLoop>{},
                   Number<KRepeat * MRepeat * KPerInnerLoop>{},
                   Number<MRepeat * KPerInnerLoop>{},
                   I1));

    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
        make_tuple(Number<KPerInnerLoop>{},
                   Number<KRepeat * NRepeat * KPerInnerLoop>{},
                   Number<NRepeat * KPerInnerLoop>{},
                   I1));
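
    // Both thread descriptors are laid out as (Repeat, 1, KRepeat, KPerInnerLoop) with
    // unit stride on the innermost KPerInnerLoop dimension, so every KPack-wide slice
    // fed to the MFMA is contiguous in the VGPR thread buffer.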

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
                                                         ComputeDataType,
                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, KPerInnerLoop>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         A_K1,
                                                         A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
                                                         ComputeDataType,
                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, KPerInnerLoop>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         B_K1,
                                                         B_K1>;

    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
    using Base::c_thread_desc_;
};

} // namespace ck
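
// Caller-side sketch (illustrative only, not part of this header): a gridwise GEMM
// kernel would pick the Run<> instantiation from the K loop count via the two helper
// functions above. The pipeline alias, the "args..." placeholder and the loop-count
// expression are assumptions made for this sketch.
//
//   using Pipeline = ck::BlockwiseGemmXdlops_pipeline_v1</* scheduler + tile params */>;
//
//   const ck::index_t num_loop = K / KPerBlock; // number of K-block iterations
//
//   if(Pipeline::BlockHasHotloop(num_loop))
//       pipeline.template Run<true, ck::TailNumber::Full>(args..., num_loop);
//   else
//       pipeline.template Run<false, ck::TailNumber::Full>(args..., num_loop);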