/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp Source File
blockwise_gemm_pipeline_xdlops_v4.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck {
9 
10 // Compute optimal pipeline with highest resource request
11 // GlobalPrefetchStages: 3
12 // LocalPreFillStages: 2
13 // LocalPreFetchStages: 1
14 // LocalSharedMemoryBuffer: 2
15 
// Primary template of the v4 blockwise GEMM XDLOPS pipeline.
// Intentionally left empty: the actual implementation is supplied by
// partial specializations (selected via BlkGemmPipelineVer) elsewhere
// in this file.
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_v4
{
};
39 
// Partial specialization implementing the v4 pipeline (3 global prefetch
// stages, 2 LDS prefill stages, double LDS buffering — see the file-top
// comment).
// NOTE(review): the documentation extractor that produced this listing
// dropped several source lines in this region (e.g. original lines 61, 81
// and 102, which carry the specialized struct head, its base-class
// introducer and the head of a `using Base =` alias). The three repeated
// template-argument lists below belong to those dropped heads — confirm
// any edit here against the repository source.
40 template <index_t BlockSize,
41  typename ADataType,
42  typename BDataType,
43  typename ComputeDataType,
44  typename AccDataType,
45  typename ATileDesc,
46  typename BTileDesc,
47  typename AMmaTileDesc,
48  typename BMmaTileDesc,
49  index_t ABlockTransferSrcScalarPerVector,
50  index_t BBlockTransferSrcScalarPerVector,
51  index_t MPerBlock,
52  index_t NPerBlock,
53  index_t KPerBlock,
54  index_t MPerXDL,
55  index_t NPerXDL,
56  index_t MRepeat,
57  index_t NRepeat,
58  index_t KPack
59  // ,bool TransposeC //disable transposec right now...
60  >
// Argument list of the specialized struct head (head line dropped by the
// extractor, original line 61).
62  BlockSize,
63  ADataType,
64  BDataType,
65  ComputeDataType,
66  AccDataType,
67  ATileDesc,
68  BTileDesc,
69  AMmaTileDesc,
70  BMmaTileDesc,
71  ABlockTransferSrcScalarPerVector,
72  BBlockTransferSrcScalarPerVector,
73  MPerBlock,
74  NPerBlock,
75  KPerBlock,
76  MPerXDL,
77  NPerXDL,
78  MRepeat,
79  NRepeat,
80  KPack>
// Argument list of the base class (base-class introducer dropped by the
// extractor, original line 81).
82  ADataType,
83  BDataType,
84  ComputeDataType,
85  AccDataType,
86  ATileDesc,
87  BTileDesc,
88  AMmaTileDesc,
89  BMmaTileDesc,
90  ABlockTransferSrcScalarPerVector,
91  BBlockTransferSrcScalarPerVector,
92  MPerBlock,
93  NPerBlock,
94  KPerBlock,
95  MPerXDL,
96  NPerXDL,
97  MRepeat,
98  NRepeat,
99  KPack>
100 
101 {
// Argument list of the `using Base = ...` alias (alias head dropped by
// the extractor, original line 102).
103  ADataType,
104  BDataType,
105  ComputeDataType,
106  AccDataType,
107  ATileDesc,
108  BTileDesc,
109  AMmaTileDesc,
110  BMmaTileDesc,
111  ABlockTransferSrcScalarPerVector,
112  BBlockTransferSrcScalarPerVector,
113  MPerBlock,
114  NPerBlock,
115  KPerBlock,
116  MPerXDL,
117  NPerXDL,
118  MRepeat,
119  NRepeat,
120  KPack>;
// Constants, K-repeat count and the xdlops engine inherited from the base
// pipeline implementation.
121  using Base::I0;
122  using Base::I1;
123  using Base::KRepeat;
124  using Base::xdlops_gemm;
125  using typename Base::HotLoopInstList;
126 
// C-tile coordinate helpers and descriptor factories forwarded unchanged
// from the base.
127  using Base::CalculateCThreadOriginDataIndex;
128  using Base::CalculateCThreadOriginDataIndex8D;
129  using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
130  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
131  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
132  using Base::GetCThreadBuffer;
133  using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
134  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
135  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
136  using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
137  using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
138 
// LDS block descriptors for the A and B tiles.
139  using Base::a_block_desc_m0_m1_m2_k;
140  using Base::b_block_desc_n0_n1_n2_k;
141 
// K-stride constants from the base.
142  using Base::AMmaKStride;
143  using Base::BMmaKStride;
144 
// NOTE(review): original line 145 was dropped by the extractor here.
// Run() below uses ComputeDataTypeBuf, which the base class defines, so
// presumably a using-declaration for it lived on that line — confirm.
146 
// Pipeline capacity bookkeeping (matches the stage counts in the file-top
// comment): three in-flight global prefetches, two LDS prefills before
// the hot loop starts, a single global staging buffer, and a hot loop
// unrolled by a factor of two.
147  static constexpr index_t PrefetchStages = 3;
148  static constexpr index_t PrefillStages = 2;
149  static constexpr index_t GlobalBufferNum = 1;
150  static constexpr index_t HotloopUnroll = 2;
151 
151 
// Whether Run() needs a main ("hot") loop at all: only when the total
// number of K-loop iterations exceeds the PrefetchStages iterations that
// the prologue prefetch/prefill sequence already consumes.
152  __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
153  {
154  return num_loop > PrefetchStages;
155  }
156 
// Tail policy for the unrolled hot loop: Run() advances HotloopUnroll (2)
// iterations per pass, so the parity of the total loop count decides
// whether the tail drains an odd or an even number of pipeline stages.
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
    return (num_loop % HotloopUnroll == 1) ? TailNumber::Odd : TailNumber::Even;
}
168 
// Emits compiler scheduling hints that interleave LDS reads/writes and
// buffer (VMEM) loads with MFMA instructions across one unrolled hot-loop
// iteration. Group-barrier masks: 0x008 = MFMA, 0x020 = VMEM read,
// 0x100 = DS read, 0x200 = DS write.
// NOTE(review): the extractor dropped the result expressions of the two
// ternaries initializing num_ds_read_inst_a/b (original lines 175-176 and
// 179-180); they presumably derive the DS-read instruction counts from
// A/B_LDS_Read_Inst_Num depending on the 16-byte read-width test — confirm
// against the repository source.
169  __device__ static constexpr void HotLoopScheduler()
170  {
171  // TODO: Take data type into consideration as pipe ver 3
172  // A-B split schedule
173  constexpr auto num_ds_read_inst_a =
174  HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
177  constexpr auto num_ds_read_inst_b =
178  HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
181 
// Per-VMEM-issue budgets: how many DS writes/reads to slot between
// consecutive A-side buffer loads (ceil-divided over the issue count).
182  constexpr auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
183  constexpr auto num_dswrite_per_issue_a =
184  (HotLoopInstList::A_LDS_Write_Inst_Num + num_issue_a - 1) / num_issue_a;
185  constexpr auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a;
186 
187  constexpr auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
188  constexpr auto num_dswrite_per_issue_b =
189  (HotLoopInstList::B_LDS_Write_Inst_Num + num_issue_b - 1) / num_issue_b;
190  constexpr auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b;
191 
// MFMAs available to hide each VMEM issue (A and B issues share the pool).
192  constexpr auto num_mfma_per_issue =
193  HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b);
194 
// A-side schedule: pair each DS read / DS write with one MFMA, then issue
// the VMEM load followed by the remaining MFMAs of this slot.
195  static_for<0, num_issue_a, 1>{}([&](auto i) {
196  ignore = i;
197  static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) {
198  ignore = idsread;
199  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
200  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
201  });
202 
203  static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
204  ignore = idswrite;
205  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
206  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
207  });
208 
209  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
210  __builtin_amdgcn_sched_group_barrier(0x008,
211  num_mfma_per_issue - num_dsread_per_issue_a -
212  num_dswrite_per_issue_a,
213  0); // MFMA
214  });
215 
// B-side schedule, mirroring the A-side slot structure.
216  static_for<0, num_issue_b, 1>{}([&](auto i) {
217  ignore = i;
218  static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) {
219  ignore = idsread;
220  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
221  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
222  });
223 
224  static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
225  ignore = idswrite;
226  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
227  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
228  });
229 
230  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// NOTE(review): this subtracts num_dsread_per_issue_a (not _b) inside the
// B-side loop — looks like a copy/paste slip; confirm intended behavior.
231  __builtin_amdgcn_sched_group_barrier(0x008,
232  num_mfma_per_issue - num_dsread_per_issue_a -
233  num_dswrite_per_issue_b,
234  0); // MFMA
235  });
// Fence: keep the compiler from moving instructions across the scheduled
// region.
236  __builtin_amdgcn_sched_barrier(0);
237  }
238 
// Executes the full software-pipelined GEMM for this workgroup:
//   * prologue: 3 global prefetches and 2 LDS prefills (double-buffered
//     LDS, I0/I1), plus one LDS->register prefetch;
//   * hot loop: unrolled by HotloopUnroll (2), ping-ponging both the LDS
//     buffers and the two register buffers so LDS reads for the next
//     iteration overlap the MFMAs of the current one;
//   * tail: drains the in-flight stages (3 drain steps for
//     TailNumber::Odd, 2 for TailNumber::Even).
// Accumulation happens into c_thread_buf, which is cleared before the
// main body.
// NOTE(review): the documentation extractor dropped a number of lines in
// this listing (the a_thread_copy_/b_thread_copy_ `.Run(...)` call heads,
// the `a_thread_vec`/`b_thread_vec` declarations, parts of the
// mfma_input_type aliases, and — presumably — the per-stage
// HotLoopScheduler() calls); confirm any edit against the repository
// source before relying on this text.
239  template <bool HasMainLoop,
240  TailNumber TailNum,
241  typename AGridDesc,
242  typename ABlockDesc,
243  typename ABlockTransfer,
244  typename AGridBuffer,
245  typename ABlockBuffer,
246  typename ABlockTransferStep,
247  typename BGridDesc,
248  typename BBlockDesc,
249  typename BBlockTransfer,
250  typename BGridBuffer,
251  typename BBlockBuffer,
252  typename BBlockTransferStep,
253  typename CThreadBuffer>
254  __device__ void Run(const AGridDesc& a_grid_desc,
255  const ABlockDesc& a_block_desc,
256  ABlockTransfer& a_blockwise_copy,
257  const AGridBuffer& a_grid_buf,
258  ABlockBuffer& a_block_buf,
259  const ABlockTransferStep& a_block_copy_step,
260  const BGridDesc& b_grid_desc,
261  const BBlockDesc& b_block_desc,
262  BBlockTransfer& b_blockwise_copy,
263  const BGridBuffer& b_grid_buf,
264  BBlockBuffer& b_block_buf,
265  const BBlockTransferStep& b_block_copy_step,
266  CThreadBuffer& c_thread_buf,
267  index_t num_loop) const
268  {
// Per-thread register staging buffers for the A and B fragments.
269  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
270  a_thread_desc_.GetElementSpaceSize());
271  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
272  b_thread_desc_.GetElementSpaceSize());
273 
// Two register buffers per operand: one feeds the MFMAs while the other
// is refilled from LDS.
274  StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
275  StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
276 
277  // Global prefetch 1
278  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
279  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
280 
281  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
282  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
283 
284  // Local prefill 1
285  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
286  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));
287 
// Local prefetch 1: LDS buffer 0 -> register buffer 0 (the thread-copy
// call heads were dropped by the extractor here).
288  // Local prefetch 1
289  block_sync_lds();
290  static_for<0, KRepeat, 1>{}([&](auto k) {
291  static_for<0, MRepeat, 1>{}([&](auto m0) {
294  a_block_buf.At(I0),
296  make_tuple(m0, I0, k, I0),
297  a_thread_bufs(I0));
298  });
299  static_for<0, NRepeat, 1>{}([&](auto n0) {
302  b_block_buf.At(I0),
304  make_tuple(n0, I0, k, I0),
305  b_thread_bufs(I0));
306  });
307  });
308 
309  // Global prefetch 2
310  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
311  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
312 
313  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
314  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
315 
316  // Local prefill 2
317  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
318  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));
319 
320  // Global prefetch 3
321  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
322  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
323 
324  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
325  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
326 
327  // Initialize C
328  c_thread_buf.Clear();
329 
330  // main body
331  if constexpr(HasMainLoop)
332  {
333  index_t i = 0;
334  // This hot loop has two legacy loopover, to implement the double local buffer strategy
335  do
336  {
// One unrolled sub-iteration: read LDS buf `lds_read_buf` into register
// buf `lds_read_reg_buf`, write the staged global tile into LDS buf
// `lds_write_buf`, issue the next global prefetch, then run the MFMAs on
// register buf `mfma_reg_buf`.
337  auto LoopFunc = [&](auto lds_read_buf,
338  auto lds_read_reg_buf,
339  auto lds_write_buf,
340  auto mfma_reg_buf) {
341  block_sync_lds();
342 
343  static_for<0, KRepeat, 1>{}([&](auto k) {
344  static_for<0, MRepeat, 1>{}([&](auto m0) {
347  a_block_buf.At(lds_read_buf),
349  make_tuple(m0, I0, k, I0),
350  a_thread_bufs(lds_read_reg_buf));
351  });
352  static_for<0, NRepeat, 1>{}([&](auto n0) {
355  b_block_buf.At(lds_read_buf),
357  make_tuple(n0, I0, k, I0),
358  b_thread_bufs(lds_read_reg_buf));
359  });
360  });
361 
362  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
363  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
364 
365  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
366  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
367 
368  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
369  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
370 
// MFMA sweep over (K, M, N) repeats, accumulating into c_thread_buf.
371  static_for<0, KRepeat, 1>{}([&](auto k0) {
372  static_for<0, MRepeat, 1>{}([&](auto m0) {
373  static_for<0, NRepeat, 1>{}([&](auto n0) {
376 
377  static_for<0, KPack, 1>{}([&](auto ik) {
378  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
379  a_thread_bufs[mfma_reg_buf]
380  [Number<a_thread_desc_.CalculateOffset(
381  make_tuple(m0, I0, k0, ik))>{}];
382  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
383  b_thread_bufs[mfma_reg_buf]
384  [Number<b_thread_desc_.CalculateOffset(
385  make_tuple(n0, I0, k0, ik))>{}];
386  });
387 
388  using mfma_input_type =
390  xdlops_gemm.K1PerXdlops>::type;
391 
392  constexpr index_t c_offset =
393  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
394 
395  xdlops_gemm.Run(
396  a_thread_vec.template AsType<mfma_input_type>(),
397  b_thread_vec.template AsType<mfma_input_type>(),
398  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
399  });
400  });
401  });
402 
404  };
405 
// Ping-pong both the LDS buffers and the register buffers across the two
// unrolled sub-iterations.
406  LoopFunc(I1, I1, I0, I0);
407  LoopFunc(I0, I0, I1, I1);
408 
409  i += HotloopUnroll;
410  } while(i < (num_loop - PrefetchStages));
411  }
412 
// Tail stage 1 (Odd tail only): last LDS read + last LDS write, plus the
// MFMAs of the previous register buffer. No further global reads.
413  auto ReadWriteCompFunc = [&](auto lds_read_buf,
414  auto lds_read_reg_buf,
415  auto lds_write_buf,
416  auto mfma_reg_buf) {
417  block_sync_lds();
418 
419  static_for<0, KRepeat, 1>{}([&](auto k) {
420  static_for<0, MRepeat, 1>{}([&](auto m0) {
423  a_block_buf.At(lds_read_buf),
425  make_tuple(m0, I0, k, I0),
426  a_thread_bufs(lds_read_reg_buf));
427  });
428  static_for<0, NRepeat, 1>{}([&](auto n0) {
431  b_block_buf.At(lds_read_buf),
433  make_tuple(n0, I0, k, I0),
434  b_thread_bufs(lds_read_reg_buf));
435  });
436  });
437 
438  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
439  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
440 
441  static_for<0, KRepeat, 1>{}([&](auto k0) {
442  static_for<0, MRepeat, 1>{}([&](auto m0) {
443  static_for<0, NRepeat, 1>{}([&](auto n0) {
446 
447  static_for<0, KPack, 1>{}([&](auto ik) {
448  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
449  a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
450  make_tuple(m0, I0, k0, ik))>{}];
451  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
452  b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
453  make_tuple(n0, I0, k0, ik))>{}];
454  });
455 
456  using mfma_input_type =
457  typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
458 
459  constexpr index_t c_offset =
460  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
461 
462  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
463  b_thread_vec.template AsType<mfma_input_type>(),
464  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
465  });
466  });
467  });
468 
470  };
471 
// Tail stage 2: LDS read into registers plus MFMAs; no more LDS writes.
472  auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) {
473  block_sync_lds();
474 
475  static_for<0, KRepeat, 1>{}([&](auto k) {
476  static_for<0, MRepeat, 1>{}([&](auto m0) {
479  a_block_buf.At(lds_read_buf),
481  make_tuple(m0, I0, k, I0),
482  a_thread_bufs(lds_read_reg_buf));
483  });
484  static_for<0, NRepeat, 1>{}([&](auto n0) {
487  b_block_buf.At(lds_read_buf),
489  make_tuple(n0, I0, k, I0),
490  b_thread_bufs(lds_read_reg_buf));
491  });
492  });
493 
494  static_for<0, KRepeat, 1>{}([&](auto k0) {
495  static_for<0, MRepeat, 1>{}([&](auto m0) {
496  static_for<0, NRepeat, 1>{}([&](auto n0) {
499 
500  static_for<0, KPack, 1>{}([&](auto ik) {
501  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
502  a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
503  make_tuple(m0, I0, k0, ik))>{}];
504  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
505  b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
506  make_tuple(n0, I0, k0, ik))>{}];
507  });
508 
509  using mfma_input_type =
510  typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
511 
512  constexpr index_t c_offset =
513  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
514 
515  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
516  b_thread_vec.template AsType<mfma_input_type>(),
517  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
518  });
519  });
520  });
521 
523  };
524 
// Tail stage 3: MFMAs only, draining the final register buffer.
525  auto CompFunc = [&](auto mfma_reg_buf) {
526  static_for<0, KRepeat, 1>{}([&](auto k0) {
527  static_for<0, MRepeat, 1>{}([&](auto m0) {
528  static_for<0, NRepeat, 1>{}([&](auto n0) {
531 
532  static_for<0, KPack, 1>{}([&](auto ik) {
533  a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
534  a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
535  make_tuple(m0, I0, k0, ik))>{}];
536  b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
537  b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
538  make_tuple(n0, I0, k0, ik))>{}];
539  });
540 
541  using mfma_input_type =
542  typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
543 
544  constexpr index_t c_offset =
545  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
546 
547  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
548  b_thread_vec.template AsType<mfma_input_type>(),
549  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
550  });
551  });
552  });
553  };
// Drain: Odd tails need three stages, Even tails two (see
// BlockLoopTailNum above).
554  // tail
555  if constexpr(TailNum == TailNumber::Odd)
556  {
557  ReadWriteCompFunc(I1, I1, I0, I0);
558  ReadCompFunc(I0, I0, I1);
559  CompFunc(I0);
560  }
561  else if constexpr(TailNum == TailNumber::Even)
562  {
563  ReadCompFunc(I1, I1, I0);
564  CompFunc(I1);
565  }
566  }
567 
// Thread-level copy engines and the per-thread A/B/C buffer descriptors,
// all inherited from the base pipeline implementation.
568  protected:
569  using Base::a_thread_copy_;
570  using Base::a_thread_desc_;
571  using Base::b_thread_copy_;
572  using Base::b_thread_desc_;
573  using Base::c_thread_desc_;
574 };
575 
576 } // namespace ck
Definition: ck.hpp:268
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:299
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:58
Definition: blockwise_gemm_pipeline_xdlops.hpp:34
static constexpr index_t B_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:46
static constexpr index_t A_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:49
static constexpr index_t A_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:82
static constexpr index_t B_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:51
static constexpr index_t A_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:44
static constexpr index_t C_MFMA_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:54
static constexpr index_t A_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:39
static constexpr index_t B_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:41
static constexpr index_t B_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:83
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:254
Definition: blockwise_gemm_pipeline_xdlops.hpp:103
static constexpr auto I1
Definition: blockwise_gemm_pipeline_xdlops.hpp:105
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:961
static constexpr __device__ auto HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops.hpp:373
static constexpr auto c_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:967
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops.hpp:991
static constexpr auto I0
Definition: blockwise_gemm_pipeline_xdlops.hpp:104
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition: blockwise_gemm_pipeline_xdlops.hpp:453
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops.hpp:990
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition: blockwise_gemm_pipeline_xdlops.hpp:454
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:955
static constexpr auto xdlops_gemm
Definition: blockwise_gemm_pipeline_xdlops.hpp:120
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition: threadwise_tensor_slice_transfer.hpp:1293
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:10