blockwise_gemm_pipeline_xdlops_v4.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"

namespace ck {

// Compute optimal pipeline with highest resource request
// GlobalPrefetchStages: 3
// LocalPreFillStages: 2
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 2

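// The stage counts above describe the software pipeline: three rounds of global
// loads are issued before the first MFMA (GlobalPrefetchStages), the first two
// are written into the two LDS buffers up front (LocalPreFillStages), one
// LDS-to-register fetch is kept in flight ahead of the math (LocalPreFetchStages),
// and the two LDS buffers (LocalSharedMemoryBuffer) are used in ping-pong
// fashion, which is why the hot loop below is unrolled by two.
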
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_v4
{
};

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                       BlockSize,
                                       ADataType,
                                       BDataType,
                                       ComputeDataType,
                                       AccDataType,
                                       ATileDesc,
                                       BTileDesc,
                                       AMmaTileDesc,
                                       BMmaTileDesc,
                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,
                                       MPerBlock,
                                       NPerBlock,
                                       KPerBlock,
                                       MPerXDL,
                                       NPerXDL,
                                       MRepeat,
                                       NRepeat,
                                       KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeDataType,
                                        AccDataType,
                                        ATileDesc,
                                        BTileDesc,
                                        AMmaTileDesc,
                                        BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerXDL,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>

{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
    using Base::I0;
    using Base::I1;
    using Base::KRepeat;
    using Base::xdlops_gemm;
    using typename Base::HotLoopInstList;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

    using typename Base::ComputeDataTypeBuf;

    static constexpr index_t PrefetchStages  = 3;
    static constexpr index_t PrefillStages   = 2;
    static constexpr index_t GlobalBufferNum = 1;
    static constexpr index_t HotloopUnroll   = 2;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        if(num_loop % HotloopUnroll == 1)
        {
            return TailNumber::Odd;
        }
        else
        {
            return TailNumber::Even;
        }
    }
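    // Worked example: with num_loop = 7, BlockHasHotloop(7) is true (7 > 3) and
    // 7 % 2 == 1 selects TailNumber::Odd, so Run() executes the unrolled hot loop
    // twice (i = 0, 2; consuming 4 K-block iterations) and drains the last three
    // through ReadWriteCompFunc, ReadCompFunc, and CompFunc. With num_loop = 8,
    // the tail is Even and only ReadCompFunc plus CompFunc remain.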

    __device__ static constexpr void HotLoopScheduler()
    {
        // TODO: Take data type into consideration as in pipeline v3
        // A/B split schedule
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_dswrite_per_issue_a =
            (HotLoopInstList::A_LDS_Write_Inst_Num + num_issue_a - 1) / num_issue_a;
        constexpr auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a;

        constexpr auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
        constexpr auto num_dswrite_per_issue_b =
            (HotLoopInstList::B_LDS_Write_Inst_Num + num_issue_b - 1) / num_issue_b;
        constexpr auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b;

        constexpr auto num_mfma_per_issue =
            HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b);

        static_for<0, num_issue_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) {
                ignore = idsread;
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(0x008,
                                                 num_mfma_per_issue - num_dsread_per_issue_a -
                                                     num_dswrite_per_issue_a,
                                                 0); // MFMA
        });

        static_for<0, num_issue_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) {
                ignore = idsread;
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(0x008,
                                                 num_mfma_per_issue - num_dsread_per_issue_a -
                                                     num_dswrite_per_issue_b,
                                                 0); // MFMA
        });
        __builtin_amdgcn_sched_barrier(0);
    }
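    // Note on the scheduling intrinsics above: __builtin_amdgcn_sched_group_barrier
    // (mask, size, sync_id) asks the scheduler to place `size` instructions of the
    // masked class at that point; the masks are the LLVM AMDGPU scheduling groups
    // 0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS (LDS) read, 0x200 = DS (LDS)
    // write. The closing __builtin_amdgcn_sched_barrier(0) keeps the compiler from
    // moving other instructions across this hand-built schedule.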

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
            b_thread_desc_.GetElementSpaceSize());

        StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
        StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));

        // Local prefetch 1
        block_sync_lds();
        static_for<0, KRepeat, 1>{}([&](auto k) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                   a_block_buf.At(I0),
                                   a_thread_desc_,
                                   make_tuple(m0, I0, k, I0),
                                   a_thread_bufs(I0));
            });
            static_for<0, NRepeat, 1>{}([&](auto n0) {
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                   b_block_buf.At(I0),
                                   b_thread_desc_,
                                   make_tuple(n0, I0, k, I0),
                                   b_thread_bufs(I0));
            });
        });

        // Global prefetch 2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 2
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));

        // Global prefetch 3
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

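        // Pipeline state at loop entry: both LDS buffers hold tiles (iterations 0
        // and 1), the register buffers hold the LDS reads for iteration 0, and the
        // global loads for iteration 2 are staged in the blockwise copies, matching
        // the GlobalPrefetchStages / LocalPreFillStages / LocalPreFetchStages
        // counts in the header comment.
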
        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            // The hot loop is unrolled twice to implement the double local (LDS)
            // buffer strategy: each unrolled half reads one LDS buffer while the
            // other is being refilled.
            do
            {
                auto LoopFunc = [&](auto lds_read_buf,
                                    auto lds_read_reg_buf,
                                    auto lds_write_buf,
                                    auto mfma_reg_buf) {
                    block_sync_lds();

                    static_for<0, KRepeat, 1>{}([&](auto k) {
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                               make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                               a_block_buf.At(lds_read_buf),
                                               a_thread_desc_,
                                               make_tuple(m0, I0, k, I0),
                                               a_thread_bufs(lds_read_reg_buf));
                        });
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                               b_block_buf.At(lds_read_buf),
                                               b_thread_desc_,
                                               make_tuple(n0, I0, k, I0),
                                               b_thread_bufs(lds_read_reg_buf));
                        });
                    });

                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
                    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));

                    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                    b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            static_for<0, NRepeat, 1>{}([&](auto n0) {
                                vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
                                vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

                                static_for<0, KPack, 1>{}([&](auto ik) {
                                    a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                        a_thread_bufs[mfma_reg_buf]
                                                     [Number<a_thread_desc_.CalculateOffset(
                                                         make_tuple(m0, I0, k0, ik))>{}];
                                    b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                        b_thread_bufs[mfma_reg_buf]
                                                     [Number<b_thread_desc_.CalculateOffset(
                                                         make_tuple(n0, I0, k0, ik))>{}];
                                });

                                using mfma_input_type =
                                    typename vector_type<ComputeDataTypeBuf,
                                                         xdlops_gemm.K1PerXdlops>::type;

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                xdlops_gemm.Run(
                                    a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec.template AsType<mfma_input_type>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                            });
                        });
                    });

                    HotLoopScheduler();
                };

                // Ping-pong between the two LDS buffers and the two register buffers:
                // arguments are (lds_read_buf, lds_read_reg_buf, lds_write_buf, mfma_reg_buf)
                LoopFunc(I1, I1, I0, I0);
                LoopFunc(I0, I0, I1, I1);

                i += HotloopUnroll;
            } while(i < (num_loop - PrefetchStages));
        }

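        // Tail: the helpers below drain the pipeline. ReadWriteCompFunc still
        // refills one LDS buffer, ReadCompFunc only moves LDS data into registers,
        // and CompFunc issues the final MFMAs from registers alone.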
        auto ReadWriteCompFunc = [&](auto lds_read_buf,
                                     auto lds_read_reg_buf,
                                     auto lds_write_buf,
                                     auto mfma_reg_buf) {
            block_sync_lds();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf.At(lds_read_buf),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_bufs(lds_read_reg_buf));
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf.At(lds_read_buf),
                                       b_thread_desc_,
                                       make_tuple(n0, I0, k, I0),
                                       b_thread_bufs(lds_read_reg_buf));
                });
            });

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
                        vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            HotLoopScheduler();
        };

        auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) {
            block_sync_lds();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf.At(lds_read_buf),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_bufs(lds_read_reg_buf));
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf.At(lds_read_buf),
                                       b_thread_desc_,
                                       make_tuple(n0, I0, k, I0),
                                       b_thread_bufs(lds_read_reg_buf));
                });
            });

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
                        vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            HotLoopScheduler();
        };

        auto CompFunc = [&](auto mfma_reg_buf) {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
                        vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        };
        // tail
        if constexpr(TailNum == TailNumber::Odd)
        {
            ReadWriteCompFunc(I1, I1, I0, I0);
            ReadCompFunc(I0, I0, I1);
            CompFunc(I0);
        }
        else if constexpr(TailNum == TailNumber::Even)
        {
            ReadCompFunc(I1, I1, I0);
            CompFunc(I1);
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

// Compute optimal pipeline with highest resource request
// Implementation with direct load
// GlobalPrefetchStages: 2
// LocalPreFillStages: 2
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 2

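// "Direct load" means the blockwise copies move tiles from global memory straight
// into LDS (a_blockwise_copy.Run writes a_block_buf directly), so there is no
// RunRead/RunWrite split, no VGPR staging of the tile, and no DS-write
// instructions for the scheduler (num_dswrite_per_issue_a/b are 0 below);
// synchronization accordingly uses block_sync_lds_direct_load() instead of
// block_sync_lds().
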
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_v4_direct_load
{
};

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v4_direct_load<BlockGemmPipelineScheduler::Intrawave,
                                                   BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeDataType,
                                        AccDataType,
                                        ATileDesc,
                                        BTileDesc,
                                        AMmaTileDesc,
                                        BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerXDL,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>

{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
    using Base::I0;
    using Base::I1;
    using Base::KRepeat;
    using Base::xdlops_gemm;
    using typename Base::HotLoopInstList;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

    using typename Base::ComputeDataTypeBuf;

    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 2;
    static constexpr index_t GlobalBufferNum = 1;
    static constexpr index_t HotloopUnroll   = 2;

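    // Only two global prefetch stages are needed in this variant: with direct
    // load, the global read is itself the LDS fill, so the generic pipeline's
    // separate prefill stage collapses into the prefetch.
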
    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        if(num_loop % HotloopUnroll == 1)
        {
            return TailNumber::Odd;
        }
        else
        {
            return TailNumber::Even;
        }
    }

    __device__ static constexpr void HotLoopScheduler()
    {
        // TODO: Take data type into consideration as in pipeline v3
        // A/B split schedule
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_issue_a             = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_dswrite_per_issue_a = 0;
        constexpr auto num_dsread_per_issue_a  = num_ds_read_inst_a / num_issue_a;

        constexpr auto num_issue_b             = HotLoopInstList::B_Buffer_Load_Inst_Num;
        constexpr auto num_dswrite_per_issue_b = 0;
        constexpr auto num_dsread_per_issue_b  = num_ds_read_inst_b / num_issue_b;

        constexpr auto num_mfma_per_issue =
            HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b);

        static_for<0, num_issue_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) {
                ignore = idsread;
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(0x008,
                                                 num_mfma_per_issue - num_dsread_per_issue_a -
                                                     num_dswrite_per_issue_a,
                                                 0); // MFMA
        });

        static_for<0, num_issue_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) {
                ignore = idsread;
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(0x008,
                                                 num_mfma_per_issue - num_dsread_per_issue_a -
                                                     num_dswrite_per_issue_b,
                                                 0); // MFMA
        });
        __builtin_amdgcn_sched_barrier(0);
    }

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
            b_thread_desc_.GetElementSpaceSize());

        StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
        StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;

        // Global prefetch 1
        a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(I0));
        b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(I0));

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        block_sync_lds_direct_load();

        // Local prefetch 1
        static_for<0, KRepeat, 1>{}([&](auto k) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                   a_block_buf.At(I0),
                                   a_thread_desc_,
                                   make_tuple(m0, I0, k, I0),
                                   a_thread_bufs(I0));
            });
            static_for<0, NRepeat, 1>{}([&](auto n0) {
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                   b_block_buf.At(I0),
                                   b_thread_desc_,
                                   make_tuple(n0, I0, k, I0),
                                   b_thread_bufs(I0));
            });
        });

        // Global prefetch 2
        a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(I1));
        b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(I1));

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            // The hot loop is unrolled twice to implement the double local (LDS)
            // buffer strategy: each unrolled half reads one LDS buffer while the
            // other is being refilled.
            do
            {
                auto LoopFunc = [&](auto lds_read_buf,
                                    auto lds_read_reg_buf,
                                    auto lds_write_buf,
                                    auto mfma_reg_buf) {
                    block_sync_lds_direct_load();

                    static_for<0, KRepeat, 1>{}([&](auto k) {
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                               make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                               a_block_buf.At(lds_read_buf),
                                               a_thread_desc_,
                                               make_tuple(m0, I0, k, I0),
                                               a_thread_bufs(lds_read_reg_buf));
                        });
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                               b_block_buf.At(lds_read_buf),
                                               b_thread_desc_,
                                               make_tuple(n0, I0, k, I0),
                                               b_thread_bufs(lds_read_reg_buf));
                        });
                    });

                    a_blockwise_copy.Run(
                        a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(lds_write_buf));
                    b_blockwise_copy.Run(
                        b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(lds_write_buf));

                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            static_for<0, NRepeat, 1>{}([&](auto n0) {
                                vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
                                vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

                                static_for<0, KPack, 1>{}([&](auto ik) {
                                    a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                        a_thread_bufs[mfma_reg_buf]
                                                     [Number<a_thread_desc_.CalculateOffset(
                                                         make_tuple(m0, I0, k0, ik))>{}];
                                    b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                        b_thread_bufs[mfma_reg_buf]
                                                     [Number<b_thread_desc_.CalculateOffset(
                                                         make_tuple(n0, I0, k0, ik))>{}];
                                });

                                using mfma_input_type =
                                    typename vector_type<ComputeDataTypeBuf,
                                                         xdlops_gemm.K1PerXdlops>::type;

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                xdlops_gemm.Run(
                                    a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec.template AsType<mfma_input_type>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                            });
                        });
                    });

                    HotLoopScheduler();
                };

                // Ping-pong between the two LDS buffers and the two register buffers:
                // arguments are (lds_read_buf, lds_read_reg_buf, lds_write_buf, mfma_reg_buf)
                LoopFunc(I1, I1, I0, I0);
                LoopFunc(I0, I0, I1, I1);

                i += HotloopUnroll;
            } while(i < (num_loop - PrefetchStages));
        }

        auto ReadWriteCompFunc = [&](auto lds_read_buf,
                                     auto lds_read_reg_buf,
                                     auto lds_write_buf,
                                     auto mfma_reg_buf) {
            block_sync_lds_direct_load();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf.At(lds_read_buf),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_bufs(lds_read_reg_buf));
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf.At(lds_read_buf),
                                       b_thread_desc_,
                                       make_tuple(n0, I0, k, I0),
                                       b_thread_bufs(lds_read_reg_buf));
                });
            });

            a_blockwise_copy.Run(
                a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(lds_write_buf));
            b_blockwise_copy.Run(
                b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(lds_write_buf));

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
                        vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            HotLoopScheduler();
        };

        auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) {
            block_sync_lds_direct_load();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf.At(lds_read_buf),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_bufs(lds_read_reg_buf));
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf.At(lds_read_buf),
                                       b_thread_desc_,
                                       make_tuple(n0, I0, k, I0),
                                       b_thread_bufs(lds_read_reg_buf));
                });
            });

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
                        vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            HotLoopScheduler();
        };

        auto CompFunc = [&](auto mfma_reg_buf) {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
                        vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        };
        // tail
        if constexpr(TailNum == TailNumber::Odd)
        {
            ReadWriteCompFunc(I1, I1, I0, I0);
            ReadCompFunc(I0, I0, I1);
            CompFunc(I0);
        }
        else if constexpr(TailNum == TailNumber::Even)
        {
            ReadCompFunc(I1, I1, I0);
            CompFunc(I1);
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

} // namespace ck
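
// Usage sketch (illustrative only; the tile-descriptor types, the blockwise-copy
// objects, and the concrete sizes below are assumptions, not part of this header).
// A gridwise GEMM kernel that owns the copies and LDS buffers would drive the
// pipeline roughly like this:
//
//   using Pipe = BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
//                                                256,            // BlockSize
//                                                half_t, half_t, // A/B data
//                                                half_t, float,  // compute / acc
//                                                ATileDesc, BTileDesc,
//                                                AMmaTileDesc, BMmaTileDesc,
//                                                8, 8,           // scalars per vector
//                                                128, 128, 64,   // M/N/KPerBlock
//                                                32, 32,         // M/NPerXDL
//                                                2, 2,           // M/NRepeat
//                                                8>;             // KPack
//
//   Pipe  pipe;
//   auto  c_thread_buf = pipe.GetCThreadBuffer();
//   const index_t num_loop = K / KPerBlock; // number of K-block iterations
//
//   if(Pipe::BlockHasHotloop(num_loop) &&
//      Pipe::BlockLoopTailNum(num_loop) == TailNumber::Odd)
//   {
//       pipe.template Run<true, TailNumber::Odd>(a_grid_desc, a_block_desc,
//                                                a_blockwise_copy, a_grid_buf,
//                                                a_block_buf, a_block_copy_step,
//                                                b_grid_desc, b_block_desc,
//                                                b_blockwise_copy, b_grid_buf,
//                                                b_block_buf, b_block_copy_step,
//                                                c_thread_buf, num_loop);
//   }
//   // ...and analogously for the other HasMainLoop / TailNumber combinations,
//   // which must be dispatched at compile time.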