Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp Source File

blockwise_gemm_pipeline_wmmaops_v1.hpp
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp"
7 
8 namespace ck {
9 
10 // Naive pipeline with lowest resource request per WGP
11 
12 template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
13  index_t BlockSize,
14  typename ADataType,
15  typename BDataType,
16  typename ComputeTypeA,
17  typename ComputeTypeB,
18  typename AccDataType,
19  typename AWmmaTileDesc,
20  typename BWmmaTileDesc,
21  index_t ABlockTransferSrcScalarPerVector,
22  index_t BBlockTransferSrcScalarPerVector,
23  index_t MPerBlock,
24  index_t NPerBlock,
25  index_t KPerBlock,
26  index_t MPerWmma,
27  index_t NPerWmma,
28  index_t MRepeat,
29  index_t NRepeat,
30  index_t KPack,
31  index_t KInner,
32  bool TransposeC = false,
33  bool BSkipLDS = false>
34 struct BlockwiseGemmWmmaops_pipeline_v1
35 {
36 };
37 
38 template <index_t BlockSize,
39  typename ADataType,
40  typename BDataType,
41  typename ComputeTypeA,
42  typename ComputeTypeB,
43  typename AccDataType,
44  typename AWmmaTileDesc,
45  typename BWmmaTileDesc,
46  index_t ABlockTransferSrcScalarPerVector,
47  index_t BBlockTransferSrcScalarPerVector,
48  index_t MPerBlock,
49  index_t NPerBlock,
50  index_t KPerBlock,
51  index_t MPerWmma,
52  index_t NPerWmma,
53  index_t MRepeat,
54  index_t NRepeat,
55  index_t KPack,
56  index_t KInner,
57  bool TransposeC>
58 struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
59  BlockSize,
60  ADataType,
61  BDataType,
62  ComputeTypeA,
63  ComputeTypeB,
64  AccDataType,
65  AWmmaTileDesc,
66  BWmmaTileDesc,
67  ABlockTransferSrcScalarPerVector,
68  BBlockTransferSrcScalarPerVector,
69  MPerBlock,
70  NPerBlock,
71  KPerBlock,
72  MPerWmma,
73  NPerWmma,
74  MRepeat,
75  NRepeat,
76  KPack,
77  KInner,
78  TransposeC,
79  false>
80  : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
81  ADataType,
82  BDataType,
83  ComputeTypeA,
84  ComputeTypeB,
85  AccDataType,
86  AWmmaTileDesc,
87  BWmmaTileDesc,
88  ABlockTransferSrcScalarPerVector,
89  BBlockTransferSrcScalarPerVector,
90  MPerBlock,
91  NPerBlock,
92  KPerBlock,
93  MPerWmma,
94  NPerWmma,
95  MRepeat,
96  NRepeat,
97  KPack,
98  KInner,
99  TransposeC>
100 {
101  // GlobalPrefetchStages: 1
102  // LocalPreFillStages: 1
103  // LocalPreFetchStages: 0
104  // LocalSharedMemoryBuffer: 1
105  using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
106  ADataType,
107  BDataType,
108  ComputeTypeA,
109  ComputeTypeB,
110  AccDataType,
111  AWmmaTileDesc,
112  BWmmaTileDesc,
113  ABlockTransferSrcScalarPerVector,
114  BBlockTransferSrcScalarPerVector,
115  MPerBlock,
116  NPerBlock,
117  KPerBlock,
118  MPerWmma,
119  NPerWmma,
120  MRepeat,
121  NRepeat,
122  KPack,
123  KInner,
124  TransposeC>;
125  using Base::I0;
126  using Base::I1;
127  using typename Base::HotLoopInstList;
128 
129  using Base::A_K1;
130  using Base::A_KRow;
131  using Base::B_K1;
132  using Base::B_KRow;
133  using Base::KRepeat;
134  using Base::WmmaK;
135 
136  using Base::wmma_gemm;
137 
138  using Base::CalculateCThreadOriginDataIndex;
139  using Base::
140  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
141  using Base::GetCThreadBuffer;
142  using Base::
143  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
144 
145  using Base::a_block_desc_k0_m0_m1_m2_k1;
146  using Base::b_block_desc_k0_n0_n1_n2_k1;
147 
148  using typename Base::Empty;
149 
150  static constexpr index_t PrefetchStages = 1;
151  static constexpr index_t PrefillStages = 1;
152  static constexpr index_t GlobalBufferNum = 1;
153 
154  static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
155 
156  static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
157  {
158  ignore = num_loop;
159  return TailNumber::Full;
160  }
161 
162  template <bool HasMainLoop,
163  TailNumber TailNum,
164  typename AGridDesc,
165  typename ABlockDesc,
166  typename ABlockTransfer,
167  typename AGridBuffer,
168  typename ABlockBuffer,
169  typename ABlockTransferStep,
170  typename BGridDesc,
171  typename BBlockDesc,
172  typename BBlockTransfer,
173  typename BGridBuffer,
174  typename BBlockBuffer,
175  typename BBlockTransferStep,
176  typename CThreadBuffer,
177  typename AScaleStruct,
178  typename BScaleStruct,
179  typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
180  __device__ void Run(const AGridDesc& a_grid_desc,
181  const ABlockDesc& a_block_desc,
182  ABlockTransfer& a_blockwise_copy,
183  const AGridBuffer& a_grid_buf,
184  ABlockBuffer& a_block_buf,
185  const ABlockTransferStep& a_block_copy_step,
186  const BGridDesc& b_grid_desc,
187  const BBlockDesc& b_block_desc,
188  BBlockTransfer& b_blockwise_copy,
189  const BGridBuffer& b_grid_buf,
190  BBlockBuffer& b_block_buf,
191  const BBlockTransferStep& b_block_copy_step,
192  CThreadBuffer& c_thread_buf,
193  AScaleStruct&,
194  BScaleStruct& b_scale_struct,
195  index_t num_loop,
196  index_t num_loop_per_scale) const
197  {
198  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
199 
200  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
201  a_thread_desc_.GetElementSpaceSize());
202  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
203  b_thread_desc_.GetElementSpaceSize());
204 
205  // Global prefetch 1
206  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
207  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
208 
209  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
210  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
211 
212  // Scales global load
213  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
214 
215  // Local prefill 1
216  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
217  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
218 
219  // Initialize C
220  c_thread_buf.Clear();
221 
222  auto blockwise_gemm_func = [&]() {
223  // Local load
224  static_for<0, KRepeat, 1>{}([&](auto k0) {
225  static_for<0, MRepeat, 1>{}([&](auto m0) {
226  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
227  make_tuple(I0, m0, k0, I0, I0, I0, I0),
228  a_block_buf,
229  a_thread_desc_,
230  make_tuple(I0, I0, I0, I0, I0, I0, I0),
231  a_thread_buf);
232  if constexpr(m0 == I0)
233  {
234  if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
235  {
236  static_for<0, NRepeat, 1>{}([&](auto n0) {
237  b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
238  make_tuple(I0, n0, k0, I0, I0, I0, I0),
239  b_block_buf,
240  b_thread_desc_,
241  make_tuple(I0, n0, I0, I0, I0, I0, I0),
242  b_thread_buf);
243  });
244  }
245  else
246  {
247  static_for<0, NRepeat, 1>{}([&](auto n0) {
248  b_thread_copy_.Run(
249  b_block_desc_k0_n0_n1_n2_k1,
250  make_tuple(I0, n0, k0, I0, I0, I0, I0),
251  b_block_buf,
252  b_scale_struct.scale_thread_bufs(
253  I0)[Number<n0 * BScaleStruct::num_scale_k_block +
254  k0 / BScaleStruct::num_scale_krepeat>{}],
255  b_thread_desc_,
256  make_tuple(I0, n0, I0, I0, I0, I0, I0),
257  b_thread_buf);
258  });
259  }
260  }
261 
262  static_for<0, KInner, 1>{}([&](auto k_inner) {
263  static_for<0, NRepeat, 1>{}([&](auto n0) {
264  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
265  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
266 
267  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
268  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
269  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
270  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
271  make_tuple(Number<kk / A_K1>{},
272  I0,
273  I0,
274  I0,
275  I0,
276  I0,
277  Number<kk % A_K1>{}))>{}];
278  });
279  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
280  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
281  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
282  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
283  make_tuple(Number<kk / B_K1>{},
284  n0,
285  I0,
286  I0,
287  I0,
288  I0,
289  Number<kk % B_K1>{}))>{}];
290  });
291 
292  using wmma_input_type_a =
293  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
294  using wmma_input_type_b =
295  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
296 
297  constexpr index_t c_offset =
298  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
299 
300  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
301  b_thread_vec.template AsType<wmma_input_type_b>(),
302  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
303  });
304  });
305  });
306  });
307  };
308 
309  // main body
310  if constexpr(HasMainLoop)
311  {
312  index_t i = 0;
313  do
314  {
315  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
316  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
317 
318  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
319  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
320 
321  block_sync_lds();
322  blockwise_gemm_func();
323 
324  block_sync_lds();
325  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
326  if constexpr(ck::is_same<BScaleStruct, Empty>::value == false)
327  {
328  block_sync_lds();
329  }
330  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
331  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
332 
333  constexpr index_t num_ds_write_inst =
334  HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
335 
336  constexpr index_t num_buffer_load_inst = HotLoopInstList::A_Buffer_Load_Inst_Num +
337  HotLoopInstList::B_Buffer_Load_Inst_Num;
338  static_for<0, num_buffer_load_inst, 1>{}([&](auto) {
339  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
340  });
341  static_for<0, KRepeat, 1>{}([&](auto) {
342  static_for<0, MRepeat, 1>{}([&](auto m0) {
343  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
344  if constexpr(m0 == I0)
345  {
346  static_for<0, NRepeat, 1>{}([&](auto) {
347  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
348  });
349  }
350  static_for<0, KInner, 1>{}([&](auto) {
351  static_for<0, NRepeat, 1>{}([&](auto) {
352  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
353  });
354  });
355  });
356  });
357  static_for<0, num_ds_write_inst, 1>{}([&](auto) {
358  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
359  });
360 
361  i += 1;
362  } while(i < (num_loop - 1));
363  }
364 
365  // tail
366  if constexpr(TailNum == TailNumber::Full)
367  {
368  block_sync_lds();
369  blockwise_gemm_func();
370  }
371  }
372 
373  template <bool HasMainLoop,
374  TailNumber TailNum,
375  typename AGridDesc,
376  typename ABlockDesc,
377  typename ABlockTransfer,
378  typename AGridBuffer,
379  typename ABlockBuffer,
380  typename ABlockTransferStep,
381  typename BGridDesc,
382  typename BBlockDesc,
383  typename BBlockTransfer,
384  typename BGridBuffer,
385  typename BBlockBuffer,
386  typename BBlockTransferStep,
387  typename CThreadBuffer,
388  typename AScaleStruct,
389  typename BScaleStruct,
390  typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
391  !ck::is_same_v<BScaleStruct, Empty>,
392  bool>::type = false>
393  __device__ void Run(const AGridDesc& a_grid_desc,
394  const ABlockDesc& a_block_desc,
395  ABlockTransfer& a_blockwise_copy,
396  const AGridBuffer& a_grid_buf,
397  ABlockBuffer& a_block_buf,
398  const ABlockTransferStep& a_block_copy_step,
399  const BGridDesc& b_grid_desc,
400  const BBlockDesc& b_block_desc,
401  BBlockTransfer& b_blockwise_copy,
402  const BGridBuffer& b_grid_buf,
403  BBlockBuffer& b_block_buf,
404  const BBlockTransferStep& b_block_copy_step,
405  CThreadBuffer& c_thread_buf,
406  AScaleStruct& a_scale_struct,
407  BScaleStruct& b_scale_struct,
408  index_t num_loop,
409  index_t num_loop_per_scale) const
410  {
411  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
412  static constexpr auto NumScaleKBlock =
413  Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
414 
415  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
416  Base::a_thread_desc_.GetElementSpaceSize());
417  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
418  Base::b_thread_desc_.GetElementSpaceSize());
419 
420  using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
421  auto c_scale_struct = CScaleStruct{};
422 
423  // Global prefetch 1
424  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
425  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
426 
427  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
428  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
429 
430  // Scales global load
431  a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
432  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
433 
434  c_scale_struct.Load(a_scale_struct, b_scale_struct);
435 
436  // Local prefill 1
437  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
438  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
439 
440  // Initialize C
441  c_thread_buf.Clear();
442 
443  auto blockwise_gemm_func = [&]() {
444  // Local load
445  static_for<0, KRepeat, 1>{}([&](auto k0) {
446  static_for<0, MRepeat, 1>{}([&](auto m0) {
447  Base::a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
448  make_tuple(I0, m0, k0, I0, I0, I0, I0),
449  a_block_buf,
450  Base::a_thread_desc_,
451  make_tuple(I0, m0, k0, I0, I0, I0, I0),
452  a_thread_buf);
453  });
454  static_for<0, NRepeat, 1>{}([&](auto n0) {
455  Base::b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
456  make_tuple(I0, n0, k0, I0, I0, I0, I0),
457  b_block_buf,
458  Base::b_thread_desc_,
459  make_tuple(I0, n0, k0, I0, I0, I0, I0),
460  b_thread_buf);
461  });
462  });
463 
464  static_for<0, MRepeat, 1>{}([&](auto m0) {
465  static_for<0, NRepeat, 1>{}([&](auto n0) {
466  static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
467  c_scale_struct.Clear();
468  static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
469  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
470  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
471 
472  static_for<0, KInner, 1>{}([&](auto k_inner) {
473  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
474  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
475  constexpr index_t k_index =
476  kscale0 * (KRepeat / NumScaleKBlock) + k0;
477  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
478  a_thread_buf[Number<Base::a_thread_desc_.CalculateOffset(
479  make_tuple(Number<kk / A_K1>{},
480  m0,
481  k_index,
482  I0,
483  I0,
484  I0,
485  Number<kk % A_K1>{}))>{}];
486  });
487  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
488  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
489  constexpr index_t k_index =
490  kscale0 * (KRepeat / NumScaleKBlock) + k0;
491  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
492  b_thread_buf[Number<Base::b_thread_desc_.CalculateOffset(
493  make_tuple(Number<kk / B_K1>{},
494  n0,
495  k_index,
496  I0,
497  I0,
498  I0,
499  Number<kk % B_K1>{}))>{}];
500  });
501 
502  using wmma_input_type_a =
503  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
504  using wmma_input_type_b =
505  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
506 
507  wmma_gemm.Run(
508  a_thread_vec.template AsType<wmma_input_type_a>(),
509  b_thread_vec.template AsType<wmma_input_type_b>(),
510  c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
511  Number<0>{}));
512  });
513  });
514  c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
515  });
516  });
517  });
518  };
519 
520  // main body
521  if constexpr(HasMainLoop)
522  {
523  index_t i = 0;
524  do
525  {
526  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
527  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
528 
529  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
530  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
531 
532  a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
533  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
534 
535  block_sync_lds();
536  blockwise_gemm_func();
537 
538  block_sync_lds();
539  c_scale_struct.Load(a_scale_struct, b_scale_struct);
540 
541  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
542  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
543 
544  i += 1;
545  } while(i < (num_loop - 1));
546  }
547 
548  // tail
549  if constexpr(TailNum == TailNumber::Full)
550  {
551  block_sync_lds();
552  blockwise_gemm_func();
553  }
554  }
555 
556  protected:
557  // A[MRepeat, I1, I1, KPack]
558  static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
559  make_tuple(Number<KPack / A_K1 / A_KRow>{}, I1, I1, I1, I1, I1, Number<A_K1>{}));
560 
561  // B[NRepeat, N1, N2, KPack]
562  static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(
563  Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, I1, Number<B_K1>{}));
564 
565  using AThreadCopy =
566  ThreadwiseTensorSliceTransfer_v4<ADataType,
567  ComputeTypeA,
568  decltype(a_block_desc_k0_m0_m1_m2_k1),
569  decltype(a_thread_desc_),
570  Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
571  Sequence<0, 1, 2, 3, 4, 5, 6>,
572  6,
573  A_K1,
574  A_K1>;
575 
576  using BThreadCopy =
577  ThreadwiseTensorSliceTransfer_v4<BDataType,
578  ComputeTypeB,
579  decltype(b_block_desc_k0_n0_n1_n2_k1),
580  decltype(b_thread_desc_),
581  Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
582  Sequence<0, 1, 2, 3, 4, 5, 6>,
583  6,
584  B_K1,
585  B_K1>;
586 
587  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
588  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
589  using Base::c_thread_desc_;
590 };
591 
592 template <index_t BlockSize,
593  typename ADataType,
594  typename BDataType,
595  typename ComputeTypeA,
596  typename ComputeTypeB,
597  typename AccDataType,
598  typename AWmmaTileDesc,
599  typename BWmmaTileDesc,
600  index_t ABlockTransferSrcScalarPerVector,
601  index_t BBlockTransferSrcScalarPerVector,
602  index_t MPerBlock,
603  index_t NPerBlock,
604  index_t KPerBlock,
605  index_t MPerWmma,
606  index_t NPerWmma,
607  index_t MRepeat,
608  index_t NRepeat,
609  index_t KPack,
610  index_t KInner,
611  bool TransposeC>
612 struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
613  BlockSize,
614  ADataType,
615  BDataType,
616  ComputeTypeA,
617  ComputeTypeB,
618  AccDataType,
619  AWmmaTileDesc,
620  BWmmaTileDesc,
621  ABlockTransferSrcScalarPerVector,
622  BBlockTransferSrcScalarPerVector,
623  MPerBlock,
624  NPerBlock,
625  KPerBlock,
626  MPerWmma,
627  NPerWmma,
628  MRepeat,
629  NRepeat,
630  KPack,
631  KInner,
632  TransposeC,
633  false>
634  : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
635  ADataType,
636  BDataType,
637  ComputeTypeA,
638  ComputeTypeB,
639  AccDataType,
640  AWmmaTileDesc,
641  BWmmaTileDesc,
642  ABlockTransferSrcScalarPerVector,
643  BBlockTransferSrcScalarPerVector,
644  MPerBlock,
645  NPerBlock,
646  KPerBlock,
647  MPerWmma,
648  NPerWmma,
649  MRepeat,
650  NRepeat,
651  KPack,
652  KInner,
653  TransposeC>
654 {
655  // GlobalPrefetchStages: 1
656  // LocalPreFillStages: 1
657  // LocalPreFetchStages: 0
658  // LocalSharedMemoryBuffer: 1
659  using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
660  ADataType,
661  BDataType,
662  ComputeTypeA,
663  ComputeTypeB,
664  AccDataType,
665  AWmmaTileDesc,
666  BWmmaTileDesc,
667  ABlockTransferSrcScalarPerVector,
668  BBlockTransferSrcScalarPerVector,
669  MPerBlock,
670  NPerBlock,
671  KPerBlock,
672  MPerWmma,
673  NPerWmma,
674  MRepeat,
675  NRepeat,
676  KPack,
677  KInner,
678  TransposeC>;
679  using Base::I0;
680  using Base::I1;
681 
682  using Base::A_K1;
683  using Base::A_KRow;
684  using Base::B_K1;
685  using Base::B_KRow;
686  using Base::KRepeat;
687  using Base::WmmaK;
688 
689  using Base::wmma_gemm;
690 
691  using Base::CalculateCThreadOriginDataIndex;
692  using Base::
693  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
694  using Base::GetCThreadBuffer;
695  using Base::
696  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
697 
698  using Base::a_block_desc_k0_m0_m1_m2_k1;
699  using Base::b_block_desc_k0_n0_n1_n2_k1;
700 
701  using typename Base::Empty;
702 
703  static constexpr index_t NumKClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
704  static constexpr index_t KRepeatPerCluster = math::max(KRepeat / NumKClusters, 1);
705 
706  static constexpr index_t PrefetchStages = 1;
707  static constexpr index_t PrefillStages = 1;
708  static constexpr index_t GlobalBufferNum = 1;
709 
710  static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
711 
712  static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
713  {
714  ignore = num_loop;
715  return TailNumber::Full;
716  }
717 
718  template <typename AScaleStruct, typename BScaleStruct>
719  struct KLoopParams
720  {
721  static constexpr auto KRepeatNoScale = 1;
722  static constexpr auto NumScaleKBlock =
723  Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
724  static constexpr auto KRepeatPerNumScaleKBlock = KRepeatPerCluster / NumScaleKBlock;
725  };
726 
727  template <>
728  struct KLoopParams<Empty, Empty>
729  {
730  static constexpr index_t KRepeatNoScale = KRepeatPerCluster;
731  static constexpr index_t NumScaleKBlock = 1;
732  static constexpr index_t KRepeatPerNumScaleKBlock = 1;
733  };
734 
735  template <bool HasMainLoop,
736  TailNumber TailNum,
737  typename AGridDesc,
738  typename ABlockDesc,
739  typename ABlockTransfer,
740  typename AGridBuffer,
741  typename ABlockBuffer,
742  typename ABlockTransferStep,
743  typename BGridDesc,
744  typename BBlockDesc,
745  typename BBlockTransfer,
746  typename BGridBuffer,
747  typename BBlockBuffer,
748  typename BBlockTransferStep,
749  typename CThreadBuffer,
750  typename AScaleStruct,
751  typename BScaleStruct,
752  typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
753  __device__ void Run(const AGridDesc& a_grid_desc,
754  const ABlockDesc& a_block_desc,
755  ABlockTransfer& a_blockwise_copy,
756  const AGridBuffer& a_grid_buf,
757  ABlockBuffer& a_block_buf,
758  const ABlockTransferStep& a_block_copy_step,
759  const BGridDesc& b_grid_desc,
760  const BBlockDesc& b_block_desc,
761  BBlockTransfer& b_blockwise_copy,
762  const BGridBuffer& b_grid_buf,
763  BBlockBuffer& b_block_buf,
764  const BBlockTransferStep& b_block_copy_step,
765  CThreadBuffer& c_thread_buf,
766  AScaleStruct&,
767  BScaleStruct& b_scale_struct,
768  index_t num_loop,
769  index_t num_loop_per_scale) const
770  {
771  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
772 
773  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
774  a_thread_desc_.GetElementSpaceSize());
775  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
776  b_thread_desc_.GetElementSpaceSize());
777 
778  // Global prefetch 1
779  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
780  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
781 
782  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
783  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
784 
785  // Scales global load
786  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
787 
788  // Local prefill 1
789  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
790  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
791 
792  // Initialize C
793  c_thread_buf.Clear();
794 
795  auto blockwise_gemm_func = [&]() {
796  static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
797  static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
798  static_for<0, MRepeat, 1>{}([&](auto m0) {
799  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
800  make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0),
801  a_block_buf,
802  a_thread_desc_,
803  make_tuple(I0, m0, k0_inner, I0, I0, I0, I0),
804  a_thread_buf);
805  });
806  if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
807  {
808  static_for<0, NRepeat, 1>{}([&](auto n0) {
809  b_thread_copy_.Run(
810  b_block_desc_k0_n0_n1_n2_k1,
811  make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
812  b_block_buf,
813  b_thread_desc_,
814  make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
815  b_thread_buf);
816  });
817  }
818  else
819  {
820  static_for<0, NRepeat, 1>{}([&](auto n0) {
821  b_thread_copy_.Run(
822  b_block_desc_k0_n0_n1_n2_k1,
823  make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
824  b_block_buf,
825  b_scale_struct.scale_thread_bufs(I0)[Number<
826  n0 * BScaleStruct::num_scale_k_block +
827  (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
828  b_thread_desc_,
829  make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
830  b_thread_buf);
831  });
832  }
833  });
834 
835  __builtin_amdgcn_sched_barrier(0);
836  // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
837  // except the first: the non-MAC cluster can be shortened a bit with no
838  // observable negative impact. The desired effect is that waves in a workgroup
839  // execute their MAC clusters in sync. This keeps out-of-sync waves from
840  // hijacking MAC resources from other workgroups and from reducing the chance of
841  // latency hiding while waiting for the rest of the workgroup at the eventual sync point.
842  if constexpr(k0_offset != 0 || KRepeat == 1)
843  {
844  __builtin_amdgcn_s_barrier();
845  __builtin_amdgcn_sched_barrier(0);
846  }
847  static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
848  static_for<0, KInner, 1>{}([&](auto k_inner) {
849  static_for<0, MRepeat, 1>{}([&](auto m0) {
850  static_for<0, NRepeat, 1>{}([&](auto n0) {
851  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
852  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
853 
854  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
855  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
856  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
857  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
858  make_tuple(Number<kk / A_K1>{},
859  m0,
860  k0_inner,
861  I0,
862  I0,
863  I0,
864  Number<kk % A_K1>{}))>{}];
865  });
866  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
867  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
868  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
869  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
870  make_tuple(Number<kk / B_K1>{},
871  n0,
872  k0_inner,
873  I0,
874  I0,
875  I0,
876  Number<kk % B_K1>{}))>{}];
877  });
878 
879  using wmma_input_type_a =
880  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
881  using wmma_input_type_b =
882  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
883 
884  constexpr index_t c_offset =
885  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
886 
887  // The block_sync_lds() here performs double duty:
888  // A) safeguards against data hazards.
889  // B) reduces VMEM FIFO congestion by applying small delays to
890  // different wavefronts.
891  // It is placed near the end of the MAC cluster to minimize the lgkmcnt
892  // penalty.
893  if constexpr(k0_offset + k0_inner == KRepeat - 1 &&
894  m0 == MRepeat - 1 && n0 == NRepeat - 1)
895  {
896  __builtin_amdgcn_sched_barrier(0);
897  block_sync_lds();
898  __builtin_amdgcn_sched_barrier(0);
899  }
900  wmma_gemm.Run(
901  a_thread_vec.template AsType<wmma_input_type_a>(),
902  b_thread_vec.template AsType<wmma_input_type_b>(),
903  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
904  if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
905  {
906  __builtin_amdgcn_sched_barrier(0);
907  __builtin_amdgcn_s_setprio(1);
908  __builtin_amdgcn_sched_barrier(0);
909  }
910  });
911  });
912  });
913  });
914 
915  __builtin_amdgcn_sched_barrier(0);
916  __builtin_amdgcn_s_setprio(0);
917  __builtin_amdgcn_sched_barrier(0);
918  });
919  };
920 
921  // main body
922  if constexpr(HasMainLoop)
923  {
924  index_t i = 0;
925  do
926  {
927  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
928  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
929 
930  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
931  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
932 
933  block_sync_lds();
934  blockwise_gemm_func();
935 
936  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
937  if constexpr(ck::is_same<BScaleStruct, Empty>::value == false)
938  {
939  block_sync_lds();
940  }
941  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
942  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
943 
944  i += 1;
945  } while(i < (num_loop - 1));
946  }
947 
948  // tail
949  if constexpr(TailNum == TailNumber::Full)
950  {
951  block_sync_lds();
952  blockwise_gemm_func();
953  }
954  }
955 
956  protected:
957  static constexpr auto a_thread_desc_ =
958  make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
959  Number<MRepeat>{},
960  Number<KRepeatPerCluster>{},
961  I1,
962  I1,
963  I1,
964  Number<A_K1>{}),
965  make_tuple(Number<A_K1>{},
966  Number<KPack / A_KRow>{},
967  Number<KPack / A_KRow * MRepeat>{},
968  I0,
969  I0,
970  I0,
971  I1));
972 
973  static constexpr auto b_thread_desc_ =
974  make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
975  Number<NRepeat>{},
976  Number<KRepeatPerCluster>{},
977  I1,
978  I1,
979  I1,
980  Number<B_K1>{}),
981  make_tuple(Number<B_K1>{},
982  Number<KPack / B_KRow>{},
983  Number<KPack / B_KRow * NRepeat>{},
984  I0,
985  I0,
986  I0,
987  I1));
988 
989  using AThreadCopy =
990  ThreadwiseTensorSliceTransfer_v4<ADataType,
991  ComputeTypeA,
992  decltype(a_block_desc_k0_m0_m1_m2_k1),
993  decltype(a_thread_desc_),
994  Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
995  Sequence<0, 1, 2, 3, 4, 5, 6>,
996  6,
997  A_K1,
998  A_K1>;
999 
1000  using BThreadCopy =
1001  ThreadwiseTensorSliceTransfer_v4<BDataType,
1002  ComputeTypeB,
1003  decltype(b_block_desc_k0_n0_n1_n2_k1),
1004  decltype(b_thread_desc_),
1005  Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
1006  Sequence<0, 1, 2, 3, 4, 5, 6>,
1007  6,
1008  B_K1,
1009  B_K1>;
1010 
1011  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
1012  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
1013  using Base::c_thread_desc_;
1014 };
1015 
1016 template <index_t BlockSize,
1017  typename ADataType,
1018  typename BDataType,
1019  typename ComputeTypeA,
1020  typename ComputeTypeB,
1021  typename AccDataType,
1022  typename AWmmaTileDesc,
1023  typename BWmmaTileDesc,
1024  index_t ABlockTransferSrcScalarPerVector,
1025  index_t BBlockTransferSrcScalarPerVector,
1026  index_t MPerBlock,
1027  index_t NPerBlock,
1028  index_t KPerBlock,
1029  index_t MPerWmma,
1030  index_t NPerWmma,
1031  index_t MRepeat,
1032  index_t NRepeat,
1033  index_t KPack,
1034  index_t KInner,
1035  bool TransposeC>
1036 struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
1037  BlockSize,
1038  ADataType,
1039  BDataType,
1040  ComputeTypeA,
1041  ComputeTypeB,
1042  AccDataType,
1043  AWmmaTileDesc,
1044  BWmmaTileDesc,
1045  ABlockTransferSrcScalarPerVector,
1046  BBlockTransferSrcScalarPerVector,
1047  MPerBlock,
1048  NPerBlock,
1049  KPerBlock,
1050  MPerWmma,
1051  NPerWmma,
1052  MRepeat,
1053  NRepeat,
1054  KPack,
1055  KInner,
1056  TransposeC,
1057  true>
1058  : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
1059  ADataType,
1060  BDataType,
1061  ComputeTypeA,
1062  ComputeTypeB,
1063  AccDataType,
1064  AWmmaTileDesc,
1065  BWmmaTileDesc,
1066  ABlockTransferSrcScalarPerVector,
1067  BBlockTransferSrcScalarPerVector,
1068  MPerBlock,
1069  NPerBlock,
1070  KPerBlock,
1071  MPerWmma,
1072  NPerWmma,
1073  MRepeat,
1074  NRepeat,
1075  KPack,
1076  KInner,
1077  TransposeC>
1078 {
1079  // GlobalPrefetchStages: 2
1080  // LocalPreFillStages: 1
1081  // LocalPreFetchStages: 1
1082  // LocalSharedMemoryBuffer: 1
1083  using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
1084  ADataType,
1085  BDataType,
1086  ComputeTypeA,
1087  ComputeTypeB,
1088  AccDataType,
1089  AWmmaTileDesc,
1090  BWmmaTileDesc,
1091  ABlockTransferSrcScalarPerVector,
1092  BBlockTransferSrcScalarPerVector,
1093  MPerBlock,
1094  NPerBlock,
1095  KPerBlock,
1096  MPerWmma,
1097  NPerWmma,
1098  MRepeat,
1099  NRepeat,
1100  KPack,
1101  KInner,
1102  TransposeC>;
1103  using Base::I0;
1104  using Base::I1;
1105  using Base::MWaves;
1106  using Base::WaveSize;
1107  using typename Base::HotLoopInstList;
1108 
1109  using Base::A_K1;
1110  using Base::A_KRow;
1111  using Base::B_K1;
1112  using Base::B_KRow;
1113  using Base::KRepeat;
1114  using Base::WmmaK;
1115 
1116  using Base::wmma_gemm;
1117 
1118  using Base::CalculateCThreadOriginDataIndex;
1119  using Base::
1120  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
1121  using Base::GetCThreadBuffer;
1122  using Base::
1123  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
1124 
1125  using Base::a_block_desc_k0_m0_m1_m2_k1;
1126  using Base::b_block_desc_k0_n0_n1_n2_k1;
1127 
1128  using typename Base::Empty;
1129 
1130  static constexpr index_t PrefetchStages = 2;
1131  static constexpr index_t PrefillStages = 1;
1132  static constexpr index_t GlobalBufferNum = 2;
1133 
1134  static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
1135 
1136  static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
1137  {
1138  return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
1139  }
1140 
1141  __device__ static constexpr auto HotLoopScheduler()
1142  {
1143  constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
1144  constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
1145  constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
1146  constexpr auto wmma_interleave = 2;
1147  // B global
1148  static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
1149  ignore = i;
1150  if constexpr(MPerBlock >= 128 && NPerBlock >= 128)
1151  {
1152  __builtin_amdgcn_sched_group_barrier(0x008, 2 * wmma_interleave, 0);
1153  }
1154  else
1155  {
1156  __builtin_amdgcn_sched_group_barrier(0x008, wmma_interleave, 0);
1157  }
1158  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
1159  });
1160 
1161  // A global
1162  static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
1163  ignore = i;
1164  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
1165  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
1166  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
1167  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
1168  });
1169 
1170  // A local
1171  static_for<0, num_ds_read_inst_a, 1>{}([&](auto i) {
1172  ignore = i;
1173  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
1174  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
1175  });
1176  }
1177 
1178  template <bool HasMainLoop,
1179  TailNumber TailNum,
1180  typename AGridDesc,
1181  typename ABlockDesc,
1182  typename ABlockTransfer,
1183  typename AGridBuffer,
1184  typename ABlockBuffer,
1185  typename ABlockTransferStep,
1186  typename BGridDesc,
1187  typename BBlockDesc,
1188  typename BBlockTransfer,
1189  typename BGridBuffer,
1190  typename BBlockBuffer,
1191  typename BBlockTransferStep,
1192  typename CThreadBuffer,
1193  typename AScaleStruct,
1194  typename BScaleStruct,
1195  typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
1196  __device__ void Run(const AGridDesc& a_grid_desc,
1197  const ABlockDesc& a_block_desc,
1198  ABlockTransfer& a_blockwise_copy,
1199  const AGridBuffer& a_grid_buf,
1200  ABlockBuffer& a_block_buf,
1201  const ABlockTransferStep& a_block_copy_step,
1202  const BGridDesc& b_grid_desc,
1203  const BBlockDesc&,
1204  BBlockTransfer& b_blockwise_copy,
1205  const BGridBuffer& b_grid_buf,
1206  BBlockBuffer&,
1207  const BBlockTransferStep& b_block_copy_step,
1208  CThreadBuffer& c_thread_buf,
1209  AScaleStruct&,
1210  BScaleStruct&,
1211  index_t num_loop,
1212  index_t) const
1213  {
1214  __builtin_amdgcn_sched_barrier(0);
1215  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
1216 
1217  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
1218  a_thread_desc_.GetElementSpaceSize());
1219  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
1220  b_thread_desc_.GetElementSpaceSize());
1221 
1222  StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
1223  constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
1224 
1225  // Global prefetch A1 B1
1226  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
1227  b_blockwise_copy.Run(b_grid_desc,
1228  b_grid_buf,
1229  b_block_desc_k0_n0_n1_n2_k1,
1230  b_block_origin_idx,
1231  b_thread_bufs(I0));
1232 
1233  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
1234  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
1235  __builtin_amdgcn_sched_barrier(0);
1236 
1237  // Local prefill A1
1238  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
1239 
1240  // Global prefetch A2
1241  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
1242  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
1243 
1244  // Local prefetch A1
1245  block_sync_lds();
1246  static_for<0, MRepeat, 1>{}([&](auto m0) {
1247  static_for<0, KRepeat, 1>{}([&](auto k0) {
1248  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
1249  make_tuple(I0, m0, k0, I0, I0, I0, I0),
1250  a_block_buf,
1251  a_thread_desc_,
1252  make_tuple(I0, m0, k0, I0, I0, I0, I0),
1253  a_thread_buf);
1254  });
1255  });
1256 
1257  // Initialize C
1258  c_thread_buf.Clear();
1259 
1260  __builtin_amdgcn_sched_barrier(0);
1261 
1262  // main body
1263  if constexpr(HasMainLoop)
1264  {
1265  index_t i = 0;
1266  do
1267  {
1268  auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) {
1269  b_blockwise_copy.Run(b_grid_desc,
1270  b_grid_buf,
1271  b_block_desc_k0_n0_n1_n2_k1,
1272  b_block_origin_idx,
1273  b_thread_bufs(local_read_buf));
1274 
1275  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
1276 
1277  block_sync_lds();
1278 
1279  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf);
1280 
1281  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
1282  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
1283 
1284  static_for<0, MRepeat, 1>{}([&](auto m0) {
1285  static_for<0, NRepeat, 1>{}([&](auto n0) {
1286  static_for<0, KRepeat, 1>{}([&](auto k0) {
1287  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
1288  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
1289  static_for<0, KInner, 1>{}([&](auto k_inner) {
1290  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
1291  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1292  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1293  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1294  make_tuple(Number<kk / A_K1>{},
1295  m0,
1296  k0,
1297  I0,
1298  I0,
1299  I0,
1300  Number<kk % A_K1>{}))>{}];
1301  });
1302  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
1303  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1304  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1305  b_thread_bufs[wmma_reg_buf]
1306  [Number<b_thread_desc_.CalculateOffset(
1307  make_tuple(Number<kk / B_K1>{},
1308  I0,
1309  I0,
1310  n0,
1311  I0,
1312  k0,
1313  Number<kk % B_K1>{}))>{}];
1314  });
1315  using wmma_input_type_a =
1316  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
1317  using wmma_input_type_b =
1318  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
1319 
1320  constexpr index_t c_offset =
1321  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
1322 
1323  wmma_gemm.Run(
1324  a_thread_vec.template AsType<wmma_input_type_a>(),
1325  b_thread_vec.template AsType<wmma_input_type_b>(),
1326  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1327  });
1328  });
1329  });
1330  });
1331 
1332  block_sync_lds();
1333 
1334  // loop prefetch copy
1335  static_for<0, MRepeat, 1>{}([&](auto m0) {
1336  static_for<0, KRepeat, 1>{}([&](auto k0) {
1337  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
1338  make_tuple(I0, m0, k0, I0, I0, I0, I0),
1339  a_block_buf,
1340  a_thread_desc_,
1341  make_tuple(I0, m0, k0, I0, I0, I0, I0),
1342  a_thread_buf);
1343  });
1344  });
1345 
1346  HotLoopScheduler();
1347  __builtin_amdgcn_sched_barrier(0);
1348  };
1349 
1350  LoopFunc(I0, I1);
1351  LoopFunc(I1, I0);
1352 
1353  i += 2;
1354  } while(i < (num_loop - 2));
1355  }
1356 
1357  // tail
1358  if constexpr(TailNum == TailNumber::Even)
1359  {
1360  b_blockwise_copy.Run(b_grid_desc,
1361  b_grid_buf,
1362  b_block_desc_k0_n0_n1_n2_k1,
1363  b_block_origin_idx,
1364  b_thread_bufs(I1));
1365 
1366  block_sync_lds();
1367 
1368  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
1369 
1370  static_for<0, MRepeat, 1>{}([&](auto m0) {
1371  static_for<0, NRepeat, 1>{}([&](auto n0) {
1372  static_for<0, KRepeat, 1>{}([&](auto k0) {
1373  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
1374  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
1375  static_for<0, KInner, 1>{}([&](auto k_inner) {
1376  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
1377  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1378  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1379  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1380  make_tuple(Number<kk / A_K1>{},
1381  m0,
1382  k0,
1383  I0,
1384  I0,
1385  I0,
1386  Number<kk % A_K1>{}))>{}];
1387  });
1388  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
1389  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1390  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1391  b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
1392  make_tuple(Number<kk / B_K1>{},
1393  I0,
1394  I0,
1395  n0,
1396  I0,
1397  k0,
1398  Number<kk % B_K1>{}))>{}];
1399  });
1400 
1401  using wmma_input_type_a =
1402  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
1403  using wmma_input_type_b =
1404  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
1405 
1406  constexpr index_t c_offset =
1407  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
1408 
1409  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
1410  b_thread_vec.template AsType<wmma_input_type_b>(),
1411  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1412  });
1413  });
1414  });
1415  });
1416 
1417  block_sync_lds();
1418 
1419  // tail Local Prefetch A1
1420  static_for<0, MRepeat, 1>{}([&](auto m0) {
1421  static_for<0, KRepeat, 1>{}([&](auto k0) {
1422  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
1423  make_tuple(I0, m0, k0, I0, I0, I0, I0),
1424  a_block_buf,
1425  a_thread_desc_,
1426  make_tuple(I0, m0, k0, I0, I0, I0, I0),
1427  a_thread_buf);
1428  });
1429  });
1430 
1431  __builtin_amdgcn_sched_barrier(0);
1432 
1433  static_for<0, MRepeat, 1>{}([&](auto m0) {
1434  static_for<0, NRepeat, 1>{}([&](auto n0) {
1435  static_for<0, KRepeat, 1>{}([&](auto k0) {
1436  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
1437  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
1438  static_for<0, KInner, 1>{}([&](auto k_inner) {
1439  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
1440  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1441  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1442  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1443  make_tuple(Number<kk / A_K1>{},
1444  m0,
1445  k0,
1446  I0,
1447  I0,
1448  I0,
1449  Number<kk % A_K1>{}))>{}];
1450  });
1451  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
1452  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1453  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1454  b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
1455  make_tuple(Number<kk / B_K1>{},
1456  I0,
1457  I0,
1458  n0,
1459  I0,
1460  k0,
1461  Number<kk % B_K1>{}))>{}];
1462  });
1463  using wmma_input_type_a =
1464  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
1465  using wmma_input_type_b =
1466  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
1467 
1468  constexpr index_t c_offset =
1469  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
1470 
1471  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
1472  b_thread_vec.template AsType<wmma_input_type_b>(),
1473  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1474  });
1475  });
1476  });
1477  });
1478  // Leak the last WMMA block into the epilogue region to cover the potential
1479  // lds-shuffle latency
1480  // __builtin_amdgcn_sched_barrier(0);
1481  }
1482  else if constexpr(TailNum == TailNumber::Odd)
1483  {
1484  static_for<0, MRepeat, 1>{}([&](auto m0) {
1485  static_for<0, NRepeat, 1>{}([&](auto n0) {
1486  static_for<0, KRepeat, 1>{}([&](auto k0) {
1487  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
1488  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
1489  static_for<0, KInner, 1>{}([&](auto k_inner) {
1490  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
1491  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1492  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1493  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1494  make_tuple(Number<kk / A_K1>{},
1495  m0,
1496  k0,
1497  I0,
1498  I0,
1499  I0,
1500  Number<kk % A_K1>{}))>{}];
1501  });
1502  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
1503  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1504  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1505  b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
1506  make_tuple(Number<kk / B_K1>{},
1507  I0,
1508  I0,
1509  n0,
1510  I0,
1511  k0,
1512  Number<kk % B_K1>{}))>{}];
1513  });
1514  using wmma_input_type_a =
1515  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
1516  using wmma_input_type_b =
1517  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
1518 
1519  constexpr index_t c_offset =
1520  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
1521 
1522  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
1523  b_thread_vec.template AsType<wmma_input_type_b>(),
1524  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1525  });
1526  });
1527  });
1528  });
1529  }
1530  }
1531 
1532  template <bool HasMainLoop,
1533  TailNumber TailNum,
1534  typename AGridDesc,
1535  typename ABlockDesc,
1536  typename ABlockTransfer,
1537  typename AGridBuffer,
1538  typename ABlockBuffer,
1539  typename ABlockTransferStep,
1540  typename BGridDesc,
1541  typename BBlockDesc,
1542  typename BBlockTransfer,
1543  typename BGridBuffer,
1544  typename BBlockBuffer,
1545  typename BBlockTransferStep,
1546  typename CThreadBuffer,
1547  typename AScaleStruct,
1548  typename BScaleStruct,
1549  typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
1550  !ck::is_same_v<BScaleStruct, Empty>,
1551  bool>::type = false>
1552  __device__ void Run(const AGridDesc& a_grid_desc,
1553  const ABlockDesc& a_block_desc,
1554  ABlockTransfer& a_blockwise_copy,
1555  const AGridBuffer& a_grid_buf,
1556  ABlockBuffer& a_block_buf,
1557  const ABlockTransferStep& a_block_copy_step,
1558  const BGridDesc& b_grid_desc,
1559  const BBlockDesc&,
1560  BBlockTransfer& b_blockwise_copy,
1561  const BGridBuffer& b_grid_buf,
1562  BBlockBuffer&,
1563  const BBlockTransferStep& b_block_copy_step,
1564  CThreadBuffer& c_thread_buf,
1565  AScaleStruct& a_scale_struct,
1566  BScaleStruct& b_scale_struct,
1567  index_t num_loop,
1568  index_t num_loop_per_scale) const
1569  {
1570  __builtin_amdgcn_sched_barrier(0);
1571  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
1572  static constexpr auto NumScaleKBlock =
1573  Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
1574 
1575  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
1576  a_thread_desc_.GetElementSpaceSize());
1577  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
1578  b_thread_desc_.GetElementSpaceSize());
1579 
1580  StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
1581  constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
1582 
1583  using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
1584  auto c_scale_struct = CScaleStruct{};
1585 
1586  auto gemm_core_func = [&](auto reg_buf) {
1587  static_for<0, MRepeat, 1>{}([&](auto m0) {
1588  static_for<0, NRepeat, 1>{}([&](auto n0) {
1589  static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
1590  c_scale_struct.Clear();
1591  static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
1592  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
1593  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
1594  static_for<0, KInner, 1>{}([&](auto k_inner) {
1595  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
1596  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1597  constexpr index_t k_index =
1598  kscale0 * (KRepeat / NumScaleKBlock) + k0;
1599  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1600  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1601  make_tuple(Number<kk / A_K1>{},
1602  m0,
1603  k_index,
1604  I0,
1605  I0,
1606  I0,
1607  Number<kk % A_K1>{}))>{}];
1608  });
1609  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
1610  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1611  constexpr index_t k_index =
1612  kscale0 * (KRepeat / NumScaleKBlock) + k0;
1613  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1614  b_thread_bufs[reg_buf]
1615  [Number<b_thread_desc_.CalculateOffset(
1616  make_tuple(Number<kk / B_K1>{},
1617  I0,
1618  I0,
1619  n0,
1620  I0,
1621  k_index,
1622  Number<kk % B_K1>{}))>{}];
1623  });
1624  using wmma_input_type_a =
1625  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
1626  using wmma_input_type_b =
1627  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
1628  wmma_gemm.Run(
1629  a_thread_vec.template AsType<wmma_input_type_a>(),
1630  b_thread_vec.template AsType<wmma_input_type_b>(),
1631  c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
1632  Number<0>{}));
1633  });
1634  });
1635  c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
1636  });
1637  });
1638  });
1639  };
1640 
1641  auto a_local_prefetch_func = [&]() {
1642  static_for<0, MRepeat, 1>{}([&](auto m0) {
1643  static_for<0, KRepeat, 1>{}([&](auto k0) {
1644  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
1645  make_tuple(I0, m0, k0, I0, I0, I0, I0),
1646  a_block_buf,
1647  a_thread_desc_,
1648  make_tuple(I0, m0, k0, I0, I0, I0, I0),
1649  a_thread_buf);
1650  });
1651  });
1652  };
1653 
1654  // Global prefetch A1 B1
1655  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
1656  b_blockwise_copy.Run(b_grid_desc,
1657  b_grid_buf,
1658  b_block_desc_k0_n0_n1_n2_k1,
1659  b_block_origin_idx,
1660  b_thread_bufs(I0));
1661 
1662  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
1663  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
1664 
1665  // Scales global load
1666  a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
1667  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
1668 
1669  __builtin_amdgcn_sched_barrier(0);
1670 
1671  c_scale_struct.Load(a_scale_struct, b_scale_struct);
1672 
1673  // Local prefill A1
1674  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
1675 
1676  // Global prefetch A2
1677  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
1678  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
1679 
1680  // Local prefetch A1
1681  block_sync_lds();
1682  a_local_prefetch_func();
1683 
1684  // Initialize C
1685  c_thread_buf.Clear();
1686 
1687  __builtin_amdgcn_sched_barrier(0);
1688 
1689  // main body
1690  if constexpr(HasMainLoop)
1691  {
1692  index_t i = 0;
1693  do
1694  {
1695  auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) {
1696  b_blockwise_copy.Run(b_grid_desc,
1697  b_grid_buf,
1698  b_block_desc_k0_n0_n1_n2_k1,
1699  b_block_origin_idx,
1700  b_thread_bufs(local_read_buf));
1701 
1702  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
1703 
1704  block_sync_lds();
1705 
1706  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf);
1707 
1708  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
1709  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
1710 
1711  a_scale_struct.template GlobalLoad<0>(
1712  (i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
1713  b_scale_struct.template GlobalLoad<0>(
1714  (i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
1715 
1716  gemm_core_func(wmma_reg_buf);
1717 
1718  block_sync_lds();
1719 
1720  // loop prefetch copy
1721  a_local_prefetch_func();
1722 
1723  c_scale_struct.Load(a_scale_struct, b_scale_struct);
1724 
1725  // HotLoopScheduler();
1726  __builtin_amdgcn_sched_barrier(0);
1727  };
1728 
1729  LoopFunc(I0, I1);
1730  LoopFunc(I1, I0);
1731 
1732  i += 2;
1733  } while(i < (num_loop - 2));
1734  }
1735 
1736  // tail
1737  if constexpr(TailNum == TailNumber::Even)
1738  {
1739  b_blockwise_copy.Run(b_grid_desc,
1740  b_grid_buf,
1741  b_block_desc_k0_n0_n1_n2_k1,
1742  b_block_origin_idx,
1743  b_thread_bufs(I1));
1744 
1745  block_sync_lds();
1746 
1747  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
1748 
1749  a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
1750  b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
1751 
1752  gemm_core_func(I0);
1753 
1754  block_sync_lds();
1755 
1756  // tail Local Prefetch A1
1757  a_local_prefetch_func();
1758 
1759  c_scale_struct.Load(a_scale_struct, b_scale_struct);
1760 
1761  __builtin_amdgcn_sched_barrier(0);
1762 
1763  gemm_core_func(I1);
1764  // Leak the last WMMA block into the epilogue region to cover the potential
1765  // lds-shuffle latency
1766  // __builtin_amdgcn_sched_barrier(0);
1767  }
1768  else if constexpr(TailNum == TailNumber::Odd)
1769  {
1770  gemm_core_func(I0);
1771  }
1772  }
1773 
1774  protected:
1775  static constexpr auto b_thread_desc_ =
1776  make_naive_tensor_descriptor_packed(make_tuple(Number<KPack / B_K1 / B_KRow>{},
1777  I1,
1778  I1,
1779  Number<NRepeat>{},
1780  I1,
1781  Number<KRepeat>{},
1782  Number<B_K1>{}));
1783 
1784  using Base::a_thread_copy_;
1785  using Base::a_thread_desc_;
1786  using Base::c_thread_desc_;
1787 };
1788 
1789 } // namespace ck
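
The specializations above share the single-LDS-buffer control flow described by the "GlobalPrefetchStages: 1 / LocalPreFillStages: 1" comments: prefetch one tile from global memory, prefill LDS, then alternate compute and refill inside the hot loop, and finish with the tail iteration selected by BlockLoopTailNum. The sketch below is illustrative only and is not part of this header; PipelineV1Sketch and its callback members are hypothetical stand-ins for the blockwise copies, block_sync_lds(), and the WMMA inner loop, and only the ordering of the stages mirrors the code above.

#include <functional>

// Hypothetical stand-ins for the real pipeline pieces (a/b_blockwise_copy, LDS
// barriers, and the WMMA accumulation loop). Illustrative sketch, not CK API.
struct PipelineV1Sketch
{
    std::function<void()> global_prefetch; // RunRead + MoveSrcSliceWindow for A and B
    std::function<void()> local_prefill;   // RunWrite of the prefetched tile into LDS
    std::function<void()> lds_barrier;     // block_sync_lds()
    std::function<void()> blockwise_gemm;  // LDS/VGPR reads + wmma_gemm.Run accumulation

    void Run(int num_loop) const
    {
        global_prefetch();                        // Global prefetch 1
        local_prefill();                          // Local prefill 1
        for(int i = 0; i < num_loop - 1; ++i)     // hot loop, taken when BlockHasHotloop(num_loop)
        {
            global_prefetch();                    // overlap the next tile's global load with compute
            lds_barrier();                        // make the previous RunWrite visible to all waves
            blockwise_gemm();                     // consume the tile currently resident in LDS
            lds_barrier();                        // compute done; the single LDS buffer may be reused
            local_prefill();                      // stage the prefetched tile into LDS
        }
        lds_barrier();                            // tail (TailNumber::Full): last tile is already in LDS
        blockwise_gemm();
    }
};

Because only one LDS buffer is used, each hot-loop iteration needs two barriers: one before compute so every wave sees the freshly written tile, and one after compute so the buffer can be overwritten with the next tile.
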
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition: ck.hpp:211
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: ck.hpp:270
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
TailNumber
Tail number enumeration for pipeline buffering.
Definition: scheduler_enum.hpp:49
@ Even
Even number of iterations.
@ Odd
Odd number of iterations.
@ Full
Full tail iterations.
@ Empty
No tail iterations.
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
BlockGemmPipelineScheduler
Block GEMM pipeline scheduler enumeration.
Definition: scheduler_enum.hpp:33
@ Intrawave
Schedule within a single wavefront.
@ Interwave
Schedule across multiple wavefronts.
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:301
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
integral_constant< index_t, N > Number
Definition: number.hpp:12
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:36
Definition: blockwise_gemm_pipeline_wmmaops.hpp:26
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, AScaleStruct &, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:753
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, AScaleStruct &, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:180
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, AScaleStruct &a_scale_struct, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:393
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, AScaleStruct &, BScaleStruct &, index_t num_loop, index_t) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:1196
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, AScaleStruct &a_scale_struct, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:1552
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:35
Definition: sequence.hpp:43
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:11