blockwise_gemm_pipeline_xdlops_v4.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"

namespace ck {

// Compute-optimal pipeline with the highest resource requirements
// GlobalPrefetchStages: 3
// LocalPreFillStages: 2
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 2

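// A sketch of the intended steady state implied by these stage counts: while
// the MFMAs of tile i are fed from one register buffer, tile i+1 is read from
// one LDS buffer into the other register buffer, tile i+2 is written into the
// other LDS buffer, and tile i+3 is being fetched from global memory.
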
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_v4
{
};

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // ,bool TransposeC // TransposeC is disabled for now
          >
struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                       BlockSize,
                                       ADataType,
                                       BDataType,
                                       ComputeDataType,
                                       AccDataType,
                                       ATileDesc,
                                       BTileDesc,
                                       AMmaTileDesc,
                                       BMmaTileDesc,
                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,
                                       MPerBlock,
                                       NPerBlock,
                                       KPerBlock,
                                       MPerXDL,
                                       NPerXDL,
                                       MRepeat,
                                       NRepeat,
                                       KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeDataType,
                                        AccDataType,
                                        ATileDesc,
                                        BTileDesc,
                                        AMmaTileDesc,
                                        BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerXDL,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>

{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
    using Base::I0;
    using Base::I1;
    using Base::KRepeat;
    using Base::xdlops_gemm;
    using typename Base::HotLoopInstList;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

    static constexpr index_t PrefetchStages  = 3;
    static constexpr index_t PrefillStages   = 2;
    static constexpr index_t GlobalBufferNum = 1;
    static constexpr index_t HotloopUnroll   = 2;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        if(num_loop % HotloopUnroll == 1)
        {
            return TailNumber::Odd;
        }
        else
        {
            return TailNumber::Even;
        }
    }
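
    // Example: with num_loop = 9, BlockHasHotloop(9) is true (9 > 3) and
    // BlockLoopTailNum(9) is TailNumber::Odd, so the unrolled main body in
    // Run() below retires 6 tiles (i = 0, 2, 4) and the Odd tail the
    // remaining 3; with num_loop = 8 the tail is Even and retires the
    // remaining 2.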

    __device__ static constexpr void HotLoopScheduler()
    {
        // TODO: take the data type into consideration, as pipeline v3 does
        // A/B split schedule
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_dswrite_per_issue_a =
            (HotLoopInstList::A_LDS_Write_Inst_Num + num_issue_a - 1) / num_issue_a;
        constexpr auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a;

        constexpr auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
        constexpr auto num_dswrite_per_issue_b =
            (HotLoopInstList::B_LDS_Write_Inst_Num + num_issue_b - 1) / num_issue_b;
        constexpr auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b;

        constexpr auto num_mfma_per_issue =
            HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b);

        static_for<0, num_issue_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) {
                ignore = idsread;
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(0x008,
                                                 num_mfma_per_issue - num_dsread_per_issue_a -
                                                     num_dswrite_per_issue_a,
                                                 0); // MFMA
        });

        static_for<0, num_issue_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) {
                ignore = idsread;
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(0x008,
                                                 num_mfma_per_issue - num_dsread_per_issue_b -
                                                     num_dswrite_per_issue_b,
                                                 0); // MFMA
        });
        __builtin_amdgcn_sched_barrier(0);
    }
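
    // A worked example of the schedule arithmetic above (illustrative values,
    // not taken from a specific tuning instance): with
    //   num_issue_a = num_issue_b = 4, A/B_LDS_Write_Inst_Num = 4,
    //   num_ds_read_inst_a = num_ds_read_inst_b = 8, C_MFMA_Inst_Num = 32,
    // we get num_mfma_per_issue = 32 / (4 + 4) = 4, num_dsread_per_issue_a = 2
    // and num_dswrite_per_issue_a = (4 + 3) / 4 = 1. Each of the four A issue
    // slots then interleaves 2 x (ds_read, mfma), 1 x (ds_write, mfma), one
    // buffer_load, and 4 - 2 - 1 = 1 trailing MFMA, so every slot retires
    // exactly num_mfma_per_issue MFMAs; the B slots follow the same pattern.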

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
        StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));

        // Local prefetch 1
        block_sync_lds();
        static_for<0, KRepeat, 1>{}([&](auto k) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                   a_block_buf.At(I0),
                                   a_thread_desc_,
                                   make_tuple(m0, I0, k, I0),
                                   a_thread_bufs(I0));
            });
            static_for<0, NRepeat, 1>{}([&](auto n0) {
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                   b_block_buf.At(I0),
                                   b_thread_desc_,
                                   make_tuple(n0, I0, k, I0),
                                   b_thread_bufs(I0));
            });
        });

        // Global prefetch 2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 2
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));

        // Global prefetch 3
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();
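
        // Prologue state (a sketch): LDS buffers 0 and 1 hold tiles 0 and 1,
        // register buffer 0 holds the fragments of tile 0, and tile 2 is in
        // flight from global memory; this realizes GlobalPrefetchStages = 3,
        // LocalPreFillStages = 2 and LocalPreFetchStages = 1 with the two
        // LocalSharedMemoryBuffer instances.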

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            // The hot loop is unrolled by two to implement the double
            // local-buffer strategy.
            do
            {
                auto LoopFunc = [&](auto lds_read_buf,
                                    auto lds_read_reg_buf,
                                    auto lds_write_buf,
                                    auto mfma_reg_buf) {
                    block_sync_lds();

                    static_for<0, KRepeat, 1>{}([&](auto k) {
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                               make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                               a_block_buf.At(lds_read_buf),
                                               a_thread_desc_,
                                               make_tuple(m0, I0, k, I0),
                                               a_thread_bufs(lds_read_reg_buf));
                        });
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                               b_block_buf.At(lds_read_buf),
                                               b_thread_desc_,
                                               make_tuple(n0, I0, k, I0),
                                               b_thread_bufs(lds_read_reg_buf));
                        });
                    });

                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
                    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));

                    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                    b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            static_for<0, NRepeat, 1>{}([&](auto n0) {
                                vector_type<ComputeDataType, KPack> a_thread_vec;
                                vector_type<ComputeDataType, KPack> b_thread_vec;

                                static_for<0, KPack, 1>{}([&](auto ik) {
                                    a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        a_thread_bufs[mfma_reg_buf]
                                                     [Number<a_thread_desc_.CalculateOffset(
                                                         make_tuple(m0, I0, k0, ik))>{}];
                                    b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        b_thread_bufs[mfma_reg_buf]
                                                     [Number<b_thread_desc_.CalculateOffset(
                                                         make_tuple(n0, I0, k0, ik))>{}];
                                });

                                using mfma_input_type =
                                    typename vector_type<ComputeDataType,
                                                         xdlops_gemm.K1PerXdlops>::type;

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                xdlops_gemm.Run(
                                    a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec.template AsType<mfma_input_type>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                            });
                        });
                    });

                    HotLoopScheduler();
                };

                LoopFunc(I1, I1, I0, I0);
                LoopFunc(I0, I0, I1, I1);

                i += HotloopUnroll;
            } while(i < (num_loop - PrefetchStages));
        }
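
        // Each pass of the do-while above calls LoopFunc twice with swapped
        // roles: LoopFunc(I1, I1, I0, I0) feeds the MFMAs from register buffer
        // 0 while refilling LDS buffer 0 and reading LDS buffer 1 into
        // register buffer 1; LoopFunc(I0, I0, I1, I1) swaps every role, so the
        // LDS and register copies ping-pong between their two buffers.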

        auto ReadWriteCompFunc = [&](auto lds_read_buf,
                                     auto lds_read_reg_buf,
                                     auto lds_write_buf,
                                     auto mfma_reg_buf) {
            block_sync_lds();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf.At(lds_read_buf),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_bufs(lds_read_reg_buf));
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf.At(lds_read_buf),
                                       b_thread_desc_,
                                       make_tuple(n0, I0, k, I0),
                                       b_thread_bufs(lds_read_reg_buf));
                });
            });

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            HotLoopScheduler();
        };

        auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) {
            block_sync_lds();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf.At(lds_read_buf),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_bufs(lds_read_reg_buf));
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf.At(lds_read_buf),
                                       b_thread_desc_,
                                       make_tuple(n0, I0, k, I0),
                                       b_thread_bufs(lds_read_reg_buf));
                });
            });

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            HotLoopScheduler();
        };

        auto CompFunc = [&](auto mfma_reg_buf) {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        };
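
        // Tail accounting: the main body leaves exactly 3 (Odd) or 2 (Even)
        // tiles unconsumed. An Odd tail drains them with ReadWriteCompFunc,
        // ReadCompFunc and CompFunc; an Even tail with ReadCompFunc and
        // CompFunc. No further global reads are issued here.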
        // tail
        if constexpr(TailNum == TailNumber::Odd)
        {
            ReadWriteCompFunc(I1, I1, I0, I0);
            ReadCompFunc(I0, I0, I1);
            CompFunc(I0);
        }
        else if constexpr(TailNum == TailNumber::Even)
        {
            ReadCompFunc(I1, I1, I0);
            CompFunc(I1);
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

} // namespace ck
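
The loop bookkeeping above (BlockHasHotloop, BlockLoopTailNum, HotloopUnroll, PrefetchStages) can be exercised host-side. A minimal standalone sketch, assuming nothing from CK; the local TailNumber enum and constants are hypothetical stand-ins that mirror the definitions in this file:

// tail_accounting_sketch.cpp -- illustrative, not part of the library
#include <cassert>

enum class TailNumber { Odd, Even };

constexpr int PrefetchStages = 3; // global prefetch depth, as in this pipeline
constexpr int HotloopUnroll  = 2; // the hot loop retires two tiles per pass

constexpr bool BlockHasHotloop(int num_loop) { return num_loop > PrefetchStages; }

constexpr TailNumber BlockLoopTailNum(int num_loop)
{
    return (num_loop % HotloopUnroll == 1) ? TailNumber::Odd : TailNumber::Even;
}

int main()
{
    static_assert(BlockLoopTailNum(9) == TailNumber::Odd, "9 % 2 == 1");
    static_assert(BlockLoopTailNum(8) == TailNumber::Even, "8 % 2 == 0");

    // The unrolled main body plus the tail retires exactly num_loop tiles:
    // the do-while mirrors the one in Run(); the tail adds 3 (Odd) or 2 (Even).
    for(int num_loop = PrefetchStages + 1; num_loop < 100; ++num_loop)
    {
        int i = 0;
        do
        {
            i += HotloopUnroll; // two LoopFunc calls per pass
        } while(i < num_loop - PrefetchStages);

        const int tail = (BlockLoopTailNum(num_loop) == TailNumber::Odd) ? 3 : 2;
        assert(i + tail == num_loop);
    }
    return 0;
}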