Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp Source File

blockwise_gemm_pipeline_wmmaops_v1.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp"
7 
8 namespace ck {
9 
10 // Naive pipeline with lowest resource request per WGP
11 // GlobalPrefetchStages: 1
12 // LocalPreFillStages: 1
13 // LocalPreFetchStages: 0
14 // LocalSharedMemoryBuffer: 1
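// A minimal sketch of the schedule that Run() below implements with this single
// LDS buffer (a summary of the code in this file, not additional stages):
//
//   global prefetch 1 : a/b_blockwise_copy.RunRead + MoveSrcSliceWindow
//   local prefill 1   : a/b_blockwise_copy.RunWrite into LDS
//   hot loop          : { RunRead(next tile); block_sync_lds();
//                         blockwise_gemm_func(); block_sync_lds();
//                         RunWrite(next tile); }  x (num_loop - 1)
//   tail              : block_sync_lds(); blockwise_gemm_func();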
15 
16 template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17  index_t BlockSize,
18  typename ADataType,
19  typename BDataType,
20  typename ComputeTypeA,
21  typename ComputeTypeB,
22  typename AccDataType,
23  typename AWmmaTileDesc,
24  typename BWmmaTileDesc,
25  index_t ABlockTransferSrcScalarPerVector,
26  index_t BBlockTransferSrcScalarPerVector,
27  index_t MPerBlock,
28  index_t NPerBlock,
29  index_t KPerBlock,
30  index_t MPerWmma,
31  index_t NPerWmma,
32  index_t MRepeat,
33  index_t NRepeat,
34  index_t KPack>
35 struct BlockwiseGemmWmmaops_pipeline_v1
36 {
37 };
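// The primary template is intentionally left empty; the scheduler-specific
// specializations below provide the actual pipeline implementations.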
38 
39 template <index_t BlockSize,
40  typename ADataType,
41  typename BDataType,
42  typename ComputeTypeA,
43  typename ComputeTypeB,
44  typename AccDataType,
45  typename AWmmaTileDesc,
46  typename BWmmaTileDesc,
47  index_t ABlockTransferSrcScalarPerVector,
48  index_t BBlockTransferSrcScalarPerVector,
49  index_t MPerBlock,
50  index_t NPerBlock,
51  index_t KPerBlock,
52  index_t MPerWmma,
53  index_t NPerWmma,
54  index_t MRepeat,
55  index_t NRepeat,
56  index_t KPack>
57 struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
58  BlockSize,
59  ADataType,
60  BDataType,
61  ComputeTypeA,
62  ComputeTypeB,
63  AccDataType,
64  AWmmaTileDesc,
65  BWmmaTileDesc,
66  ABlockTransferSrcScalarPerVector,
67  BBlockTransferSrcScalarPerVector,
68  MPerBlock,
69  NPerBlock,
70  KPerBlock,
71  MPerWmma,
72  NPerWmma,
73  MRepeat,
74  NRepeat,
75  KPack>
76  : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
77  ADataType,
78  BDataType,
79  ComputeTypeA,
80  ComputeTypeB,
81  AccDataType,
82  AWmmaTileDesc,
83  BWmmaTileDesc,
84  ABlockTransferSrcScalarPerVector,
85  BBlockTransferSrcScalarPerVector,
86  MPerBlock,
87  NPerBlock,
88  KPerBlock,
89  MPerWmma,
90  NPerWmma,
91  MRepeat,
92  NRepeat,
93  KPack>
94 
95 {
96  using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
97  ADataType,
98  BDataType,
99  ComputeTypeA,
100  ComputeTypeB,
101  AccDataType,
102  AWmmaTileDesc,
103  BWmmaTileDesc,
104  ABlockTransferSrcScalarPerVector,
105  BBlockTransferSrcScalarPerVector,
106  MPerBlock,
107  NPerBlock,
108  KPerBlock,
109  MPerWmma,
110  NPerWmma,
111  MRepeat,
112  NRepeat,
113  KPack>;
114  using Base::I0;
115 
116  using Base::A_K1;
117  using Base::A_KRow;
118  using Base::B_K1;
119  using Base::B_KRow;
120  using Base::KRepeat;
121  using Base::WmmaK;
122 
123  using Base::wmma_gemm;
124 
125  using Base::CalculateCThreadOriginDataIndex;
126  using Base::
127  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
128  using Base::GetCThreadBuffer;
129  using Base::
130  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
131 
132  using Base::a_block_desc_k0_m0_m1_m2_k1;
133  using Base::b_block_desc_k0_n0_n1_n2_k1;
134 
135  using typename Base::Empty;
136 
137  static constexpr index_t PrefetchStages = 1;
138  static constexpr index_t PrefillStages = 1;
139  static constexpr index_t GlobalBufferNum = 1;
140 
141  static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
142 
143  static TailNumber BlockLoopTailNum(index_t num_loop)
144  {
145  ignore = num_loop;
146  return TailNumber::Full;
147  }
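// Illustrative dispatch by a hypothetical caller (the real grid-wise GEMM kernels
// select HasMainLoop/TailNum at compile time from these two helpers; the pipeline
// object name and elided arguments below are placeholders):
//
//   if(BlockHasHotloop(num_loop))
//       pipeline.template Run<true, TailNumber::Full>(/* ... */, num_loop, num_loop_per_scale);
//   else
//       pipeline.template Run<false, TailNumber::Full>(/* ... */, num_loop, num_loop_per_scale);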
148 
149  template <bool HasMainLoop,
150  TailNumber TailNum,
151  typename AGridDesc,
152  typename ABlockDesc,
153  typename ABlockTransfer,
154  typename AGridBuffer,
155  typename ABlockBuffer,
156  typename ABlockTransferStep,
157  typename BGridDesc,
158  typename BBlockDesc,
159  typename BBlockTransfer,
160  typename BGridBuffer,
161  typename BBlockBuffer,
162  typename BBlockTransferStep,
163  typename CThreadBuffer,
164  typename BScaleStruct>
165  __device__ void Run(const AGridDesc& a_grid_desc,
166  const ABlockDesc& a_block_desc,
167  ABlockTransfer& a_blockwise_copy,
168  const AGridBuffer& a_grid_buf,
169  ABlockBuffer& a_block_buf,
170  const ABlockTransferStep& a_block_copy_step,
171  const BGridDesc& b_grid_desc,
172  const BBlockDesc& b_block_desc,
173  BBlockTransfer& b_blockwise_copy,
174  const BGridBuffer& b_grid_buf,
175  BBlockBuffer& b_block_buf,
176  const BBlockTransferStep& b_block_copy_step,
177  CThreadBuffer& c_thread_buf,
178  // BScaleThreadCopy
179  BScaleStruct& b_scale_struct,
180  index_t num_loop,
181  index_t num_loop_per_scale) const
182  {
183  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
184  a_thread_desc_.GetElementSpaceSize());
185  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
186  b_thread_desc_.GetElementSpaceSize());
187 
188  // Global prefetch 1
189  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
190  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
191 
192  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
193  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
194 
195  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
196 
197  // Local prefill 1
198  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
199  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
200 
201  // Initialize C
202  c_thread_buf.Clear();
203 
204  auto blockwise_gemm_func = [&]() {
205  static_for<0, KRepeat, 1>{}([&](auto k0) {
206  static_for<0, MRepeat, 1>{}([&](auto m0) {
207  a_thread_copy_.Run(
208  a_block_desc_k0_m0_m1_m2_k1,
209  make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
210  a_block_buf,
211  a_thread_desc_,
212  make_tuple(I0, m0, k0, I0, I0, I0),
213  a_thread_buf);
214  });
215  if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
216  {
217  static_for<0, NRepeat, 1>{}([&](auto n0) {
218  b_thread_copy_.Run(
219  b_block_desc_k0_n0_n1_n2_k1,
220  make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
221  b_block_buf,
222  b_thread_desc_,
223  make_tuple(I0, n0, k0, I0, I0, I0),
224  b_thread_buf);
225  });
226  }
227  else
228  {
229  static_for<0, NRepeat, 1>{}([&](auto n0) {
230  b_thread_copy_.Run(
231  b_block_desc_k0_n0_n1_n2_k1,
232  make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
233  b_block_buf,
234  b_scale_struct.b_scale_thread_bufs(
235  I0)[Number<n0 * BScaleStruct::num_scale_k_block +
236  k0 / BScaleStruct::num_scale_krepeat>{}],
237  b_thread_desc_,
238  make_tuple(I0, n0, k0, I0, I0, I0),
239  b_thread_buf);
240  });
241  }
242 
243  static_for<0, MRepeat, 1>{}([&](auto m0) {
244  static_for<0, NRepeat, 1>{}([&](auto n0) {
245  vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
246  vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
247 
248  static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
249  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
250  a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
251  Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
252  });
253  static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
254  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
255  b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
256  Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
257  });
258 
259  using wmma_input_type_a =
260  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
261  using wmma_input_type_b =
262  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
263 
264  constexpr index_t c_offset =
265  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
266 
267  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
268  b_thread_vec.template AsType<wmma_input_type_b>(),
269  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
270  });
271  });
272  });
273  };
274 
275  // main body
276  if constexpr(HasMainLoop)
277  {
278  index_t i = 0;
279  do
280  {
281  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
282  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
283 
284  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
285  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
286 
287  block_sync_lds();
288  blockwise_gemm_func();
289 
290  block_sync_lds();
291  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
292  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
293  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
294 
295  i += 1;
296  } while(i < (num_loop - 1));
297  }
298 
299  // tail
300  if constexpr(TailNum == TailNumber::Full)
301  {
302  block_sync_lds();
303  blockwise_gemm_func();
304  }
305  }
306 
307  protected:
308  using Base::a_thread_copy_;
309  using Base::a_thread_desc_;
310  using Base::b_thread_copy_;
311  using Base::b_thread_desc_;
312  using Base::c_thread_desc_;
313 };
314 
315 template <index_t BlockSize,
316  typename ADataType,
317  typename BDataType,
318  typename ComputeTypeA,
319  typename ComputeTypeB,
320  typename AccDataType,
321  typename AWmmaTileDesc,
322  typename BWmmaTileDesc,
323  index_t ABlockTransferSrcScalarPerVector,
324  index_t BBlockTransferSrcScalarPerVector,
325  index_t MPerBlock,
326  index_t NPerBlock,
327  index_t KPerBlock,
328  index_t MPerWmma,
329  index_t NPerWmma,
330  index_t MRepeat,
331  index_t NRepeat,
332  index_t KPack>
333 struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
334  BlockSize,
335  ADataType,
336  BDataType,
337  ComputeTypeA,
338  ComputeTypeB,
339  AccDataType,
340  AWmmaTileDesc,
341  BWmmaTileDesc,
342  ABlockTransferSrcScalarPerVector,
343  BBlockTransferSrcScalarPerVector,
344  MPerBlock,
345  NPerBlock,
346  KPerBlock,
347  MPerWmma,
348  NPerWmma,
349  MRepeat,
350  NRepeat,
351  KPack>
352  : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
353  ADataType,
354  BDataType,
355  ComputeTypeA,
356  ComputeTypeB,
357  AccDataType,
358  AWmmaTileDesc,
359  BWmmaTileDesc,
360  ABlockTransferSrcScalarPerVector,
361  BBlockTransferSrcScalarPerVector,
362  MPerBlock,
363  NPerBlock,
364  KPerBlock,
365  MPerWmma,
366  NPerWmma,
367  MRepeat,
368  NRepeat,
369  KPack>
370 
371 {
372  using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
373  ADataType,
374  BDataType,
375  ComputeTypeA,
376  ComputeTypeB,
377  AccDataType,
378  AWmmaTileDesc,
379  BWmmaTileDesc,
380  ABlockTransferSrcScalarPerVector,
381  BBlockTransferSrcScalarPerVector,
382  MPerBlock,
383  NPerBlock,
384  KPerBlock,
385  MPerWmma,
386  NPerWmma,
387  MRepeat,
388  NRepeat,
389  KPack>;
390  using Base::I0;
391  using Base::I1;
392 
393  using Base::A_K1;
394  using Base::A_KRow;
395  using Base::B_K1;
396  using Base::B_KRow;
397  using Base::KRepeat;
398  using Base::WmmaK;
399 
400  using Base::wmma_gemm;
401 
402  using Base::CalculateCThreadOriginDataIndex;
403  using Base::
404  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
405  using Base::GetCThreadBuffer;
406  using Base::
407  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
408 
409  using Base::a_block_desc_k0_m0_m1_m2_k1;
410  using Base::b_block_desc_k0_n0_n1_n2_k1;
411 
412  using typename Base::Empty;
413 
414  static constexpr index_t NumKClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
415  static constexpr index_t KRepeatPerCluster = math::max(KRepeat / NumKClusters, 1);
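// Worked example with hypothetical values: for KRepeat = 8 and NumKClusters = 2,
// KRepeatPerCluster = max(8 / 2, 1) = 4, so blockwise_gemm_func() below splits the
// K dimension into 2 MAC clusters of 4 K-repeats each; if NumKClusters exceeds
// KRepeat, the max() clamps the cluster size to a single K-repeat.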
416 
417  static constexpr index_t PrefetchStages = 1;
418  static constexpr index_t PrefillStages = 1;
419  static constexpr index_t GlobalBufferNum = 1;
420 
421  static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
422 
423  static TailNumber BlockLoopTailNum(index_t num_loop)
424  {
425  ignore = num_loop;
426  return TailNumber::Full;
427  }
428 
429  template <bool HasMainLoop,
430  TailNumber TailNum,
431  typename AGridDesc,
432  typename ABlockDesc,
433  typename ABlockTransfer,
434  typename AGridBuffer,
435  typename ABlockBuffer,
436  typename ABlockTransferStep,
437  typename BGridDesc,
438  typename BBlockDesc,
439  typename BBlockTransfer,
440  typename BGridBuffer,
441  typename BBlockBuffer,
442  typename BBlockTransferStep,
443  typename CThreadBuffer,
444  typename BScaleStruct>
445  __device__ void Run(const AGridDesc& a_grid_desc,
446  const ABlockDesc& a_block_desc,
447  ABlockTransfer& a_blockwise_copy,
448  const AGridBuffer& a_grid_buf,
449  ABlockBuffer& a_block_buf,
450  const ABlockTransferStep& a_block_copy_step,
451  const BGridDesc& b_grid_desc,
452  const BBlockDesc& b_block_desc,
453  BBlockTransfer& b_blockwise_copy,
454  const BGridBuffer& b_grid_buf,
455  BBlockBuffer& b_block_buf,
456  const BBlockTransferStep& b_block_copy_step,
457  CThreadBuffer& c_thread_buf,
458  // BScaleThreadCopy
459  BScaleStruct& b_scale_struct,
460  index_t num_loop,
461  index_t num_loop_per_scale) const
462  {
463  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
464  a_thread_desc_.GetElementSpaceSize());
465  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
466  b_thread_desc_.GetElementSpaceSize());
467 
468  // Global prefetch 1
469  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
470  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
471 
472  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
473  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
474 
475  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
476 
477  // Local prefill 1
478  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
479  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
480 
481  // Initialize C
482  c_thread_buf.Clear();
483 
484  auto blockwise_gemm_func = [&]() {
485  static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
486  static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
487  static_for<0, MRepeat, 1>{}([&](auto m0) {
488  a_thread_copy_.Run(
489  a_block_desc_k0_m0_m1_m2_k1,
490  make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{},
491  m0,
492  I0,
493  I0,
494  I0,
495  I0),
496  a_block_buf,
497  a_thread_desc_,
498  make_tuple(I0, m0, k0_inner, I0, I0, I0),
499  a_thread_buf);
500  });
501  if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
502  {
503  static_for<0, NRepeat, 1>{}([&](auto n0) {
504  b_thread_copy_.Run(
505  b_block_desc_k0_n0_n1_n2_k1,
506  make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
507  n0,
508  I0,
509  I0,
510  I0,
511  I0),
512  b_block_buf,
513  b_thread_desc_,
514  make_tuple(I0, n0, k0_inner, I0, I0, I0),
515  b_thread_buf);
516  });
517  }
518  else
519  {
520  static_for<0, NRepeat, 1>{}([&](auto n0) {
521  b_thread_copy_.Run(
522  b_block_desc_k0_n0_n1_n2_k1,
523  make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
524  n0,
525  I0,
526  I0,
527  I0,
528  I0),
529  b_block_buf,
530  b_scale_struct.b_scale_thread_bufs(I0)[Number<
531  n0 * BScaleStruct::num_scale_k_block +
532  (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
533  b_thread_desc_,
534  make_tuple(I0, n0, k0_inner, I0, I0, I0),
535  b_thread_buf);
536  });
537  }
538  });
539 
540  __builtin_amdgcn_sched_barrier(0);
541  // NOTE: Synchronize the threads in a workgroup at the start of each MAC cluster
542  // except the first; this shortens the non-MAC cluster slightly with no observable
543  // negative impact. The desired effect is that the waves in a workgroup execute
544  // their MAC clusters in sync, which prevents out-of-sync waves from hijacking MAC
545  // resources from other workgroups and reducing the chance of hiding latency while
546  // waiting for the rest of the workgroup at the eventual sync point.
547  if constexpr(k0_offset != 0 || KRepeat == 1)
548  {
549  __builtin_amdgcn_s_barrier();
550  __builtin_amdgcn_sched_barrier(0);
551  }
552  static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
553  static_for<0, MRepeat, 1>{}([&](auto m0) {
554  static_for<0, NRepeat, 1>{}([&](auto n0) {
555  vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
556  vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
557 
558  static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
559  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
560  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
561  make_tuple(Number<ik / A_K1>{},
562  m0,
563  k0_inner,
564  I0,
565  I0,
566  Number<ik % A_K1>{}))>{}];
567  });
568  static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
569  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
570  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
571  make_tuple(Number<ik / B_K1>{},
572  n0,
573  k0_inner,
574  I0,
575  I0,
576  Number<ik % B_K1>{}))>{}];
577  });
578 
579  using wmma_input_type_a =
580  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
581  using wmma_input_type_b =
582  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
583 
584  constexpr index_t c_offset =
585  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
586 
587  // The block_sync_lds() here performs double duty:
588  // A) it safeguards against data hazards;
589  // B) it reduces VMEM FIFO congestion by applying small delays to
590  // different wavefronts.
591  // It is placed near the end of the MAC cluster to minimize the lgkmcnt
592  // penalty.
593  if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
594  n0 == NRepeat - 1)
595  {
596  __builtin_amdgcn_sched_barrier(0);
597  block_sync_lds();
598  __builtin_amdgcn_sched_barrier(0);
599  }
600  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
601  b_thread_vec.template AsType<wmma_input_type_b>(),
602  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
603  if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
604  {
605  __builtin_amdgcn_sched_barrier(0);
606  __builtin_amdgcn_s_setprio(1);
607  __builtin_amdgcn_sched_barrier(0);
608  }
609  });
610  });
611  });
612  __builtin_amdgcn_sched_barrier(0);
613  __builtin_amdgcn_s_setprio(0);
614  __builtin_amdgcn_sched_barrier(0);
615  });
616  };
617 
618  // main body
619  if constexpr(HasMainLoop)
620  {
621  index_t i = 0;
622  do
623  {
624  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
625  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
626 
627  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
628  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
629 
630  block_sync_lds();
631  blockwise_gemm_func();
632 
633  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
634  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
635  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
636 
637  i += 1;
638  } while(i < (num_loop - 1));
639  }
640 
641  // tail
642  if constexpr(TailNum == TailNumber::Full)
643  {
644  block_sync_lds();
645  blockwise_gemm_func();
646  }
647  }
648 
649  protected:
650  static constexpr auto a_thread_desc_ =
651  make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
652  Number<MRepeat>{},
653  Number<KRepeatPerCluster>{},
654  I1,
655  I1,
656  Number<A_K1>{}),
657  make_tuple(Number<A_K1>{},
658  Number<KPack / A_KRow>{},
659  Number<KPack / A_KRow * MRepeat>{},
660  I0,
661  I0,
662  I1));
663 
664  static constexpr auto b_thread_desc_ =
665  make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
666  Number<NRepeat>{},
667  Number<KRepeatPerCluster>{},
668  I1,
669  I1,
670  Number<B_K1>{}),
671  make_tuple(Number<B_K1>{},
672  Number<KPack / B_KRow>{},
673  Number<KPack / B_KRow * NRepeat>{},
674  I0,
675  I0,
676  I1));
677 
678  using AThreadCopy =
679  ThreadwiseTensorSliceTransfer_v4<ADataType,
680  ComputeTypeA,
681  decltype(a_block_desc_k0_m0_m1_m2_k1),
682  decltype(a_thread_desc_),
683  Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
684  Sequence<0, 1, 2, 3, 4, 5>,
685  5,
686  A_K1,
687  A_K1>;
688 
689  using BThreadCopy =
690  ThreadwiseTensorSliceTransfer_v4<BDataType,
691  ComputeTypeB,
692  decltype(b_block_desc_k0_n0_n1_n2_k1),
693  decltype(b_thread_desc_),
694  Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
695  Sequence<0, 1, 2, 3, 4, 5>,
696  5,
697  B_K1,
698  B_K1>;
699 
700  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
701  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
702  using Base::c_thread_desc_;
703 };
704 
705 } // namespace ck