Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"

namespace ck {

// Compute optimized pipeline
// GlobalPrefetchStages: 3
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 2
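// Illustrative timing (a sketch, not normative): while the MFMAs for K-block
// iteration i run out of LDS, the last K-slice of the loop body also writes
// iteration i+1 from its register prefetch buffer into LDS and issues the
// global load for iteration i+3 into that same register buffer; iteration
// i+2 is already held in the other of the GlobalBufferNum = 2 buffers. The
// hot loop below is unrolled by HotloopUnroll = 2 so the I0/I1 buffer
// alternation is resolved at compile time.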
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_v5
{
};
template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                       BlockSize,
                                       ADataType,
                                       BDataType,
                                       ComputeDataType,
                                       AccDataType,
                                       ATileDesc,
                                       BTileDesc,
                                       AMmaTileDesc,
                                       BMmaTileDesc,
                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,
                                       MPerBlock,
                                       NPerBlock,
                                       KPerBlock,
                                       MPerXDL,
                                       NPerXDL,
                                       MRepeat,
                                       NRepeat,
                                       KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeDataType,
                                        AccDataType,
                                        ATileDesc,
                                        BTileDesc,
                                        AMmaTileDesc,
                                        BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerXDL,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>

{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
    using Base::A_K1;
    using Base::B_K1;
    using Base::I0;
    using Base::I1;
    using Base::KRepeat;
    using Base::xdlops_gemm;
    using typename Base::HotLoopInstList;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

    using typename Base::ComputeDataTypeBuf;

    static constexpr index_t PrefetchStages  = 3;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 2;
    static constexpr index_t HotloopUnroll   = 2;

    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        if(num_loop % HotloopUnroll == 1)
        {
            return TailNumber::Odd;
        }
        else
        {
            return TailNumber::Even;
        }
    }
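    // Worked example (illustrative): with PrefetchStages = 3 and
    // HotloopUnroll = 2, num_loop = 8 gives BlockHasHotloop(8) = true and
    // TailNumber::Even; the do-while in Run() executes 3 passes (6 K-block
    // iterations) and the Even tail (ReadWriteCompFunc + ReadCompFunc)
    // covers the remaining 2, totaling 8. For num_loop = 7 the tail is Odd:
    // 2 passes (4 iterations) plus 3 tail iterations.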

    __device__ static constexpr auto HotLoopScheduler()
    {
        // TODO: Take the data type into consideration, as pipeline v3 does.
        // A/B split schedule
        // The compiler is likely to use ds_read2 when the instruction width
        // is smaller than 16 bytes.
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;

        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;

        constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;

        constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
        constexpr auto ds_read_a_issue_cycle =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
        constexpr auto ds_read_b_issue_cycle =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
        constexpr auto ds_read_a_mfma_rate =
            (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
        constexpr auto ds_read_b_mfma_rate =
            (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
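
        // Worked example (illustrative, assuming mfma_cycle = 32 and a
        // 16-byte ds_read, i.e. issue cycle 8): the ceiling division above
        // gives (32 - 4 + 16 - 1) / 16 = 2, so up to two ds_reads are
        // co-issued per MFMA while staying inside the MFMA's latency shadow
        // (minus 4 issue cycles).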
202 
203  constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
204  constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
205  constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat;
206  constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat;
207 
208  constexpr auto num_dsread_stage1_a_mfma =
209  (num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
210  constexpr auto num_dsread_stage1_b_mfma =
211  (num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
212  constexpr auto num_dsread_stage3_a_mfma =
213  (num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
214  constexpr auto num_dsread_stage3_b_mfma =
215  (num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
216 
217  constexpr auto num_mfma_stage2 = num_mfma_inst - num_ds_read_inst_a / ds_read_a_mfma_rate -
218  num_ds_read_inst_b / ds_read_b_mfma_rate;
219  constexpr auto num_mfma_per_issue =
220  num_mfma_stage2 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
221  constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
222  constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
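
        // Budget sketch (illustrative, hypothetical numbers): with
        // num_mfma_inst = 64 and 8 A plus 8 B ds_reads at rate 2, stage 2
        // keeps 64 - 4 - 4 = 56 MFMAs. With 4 A + 4 B buffer_loads that is
        // 7 MFMAs per load issue; each issue slot first interleaves its
        // share of ds_writes (num_dswrite_per_issue_a/b), then the VMEM
        // read, then the remaining MFMAs.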
223 
224  // stage 1
226  ignore = i;
227  if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >=
228  ds_read_a_mfma_rate)
229  {
230  __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
231  }
232  else
233  {
234  __builtin_amdgcn_sched_group_barrier(
235  0x100,
236  num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate,
237  0); // DS read
238  }
239  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
240  });
242  ignore = i;
243  if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >=
244  ds_read_b_mfma_rate)
245  {
246  __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
247  }
248  else
249  {
250  __builtin_amdgcn_sched_group_barrier(
251  0x100,
252  num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate,
253  0); // DS read
254  }
255  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
256  });
257 
258  // stage 2
260  ignore = i;
261  static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
262  ignore = idswrite;
263  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
264  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
265  });
266  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
267  __builtin_amdgcn_sched_group_barrier(
268  0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
269  });
271  ignore = i;
272  static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
273  ignore = idswrite;
274  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
275  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
276  });
277  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
278  __builtin_amdgcn_sched_group_barrier(
279  0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
280  });
281 
282  // stage 3
284  ignore = i;
285  if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >=
286  ds_read_a_mfma_rate)
287  {
288  __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
289  }
290  else
291  {
292  __builtin_amdgcn_sched_group_barrier(
293  0x100,
294  num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate,
295  0); // DS read
296  }
297  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
298  });
300  ignore = i;
301  if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >=
302  ds_read_b_mfma_rate)
303  {
304  __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
305  }
306  else
307  {
308  __builtin_amdgcn_sched_group_barrier(
309  0x100,
310  num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate,
311  0); // DS read
312  }
313  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
314  });
315 
316  // IGLP COMPILER BUG:
317  // If comment out following scheduler barrier would cause sanity fail.
318  __builtin_amdgcn_sched_barrier(0);
319  }

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);

        // Global prefetch 2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Global prefetch 3
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        // Local prefetch 1
        block_sync_lds();
        static_for<0, MRepeat, 1>{}([&](auto m0) {
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               make_tuple(m0, I0, I0, I0),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(m0, I0, I0, I0),
                               a_thread_buf);
        });
        static_for<0, NRepeat, 1>{}([&](auto n0) {
            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                               make_tuple(n0, I0, I0, I0),
                               b_block_buf,
                               b_thread_desc_,
                               make_tuple(n0, I0, I0, I0),
                               b_thread_buf);
        });

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                auto LoopFunc = [&](auto vmem_buf) {
                    vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
                    vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        if constexpr(k0 == (KRepeat - 1))
                        {
                            block_sync_lds();

                            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf);
                            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf);

                            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf);
                            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf);

                            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                            block_sync_lds();
                        }
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            static_for<0, NRepeat, 1>{}([&](auto n0) {
                                static_for<0, KPack, 1>{}([&](auto ik) {
                                    a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                            make_tuple(m0, I0, I0, ik))>{}];
                                });
                                static_for<0, KPack, 1>{}([&](auto ik) {
                                    b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                            make_tuple(n0, I0, I0, ik))>{}];
                                });

                                using mfma_input_type =
                                    typename vector_type<ComputeDataTypeBuf,
                                                         xdlops_gemm.K1PerXdlops>::type;

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                xdlops_gemm.Run(
                                    a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec.template AsType<mfma_input_type>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                            });

                            a_thread_copy_.Run(
                                a_block_desc_m0_m1_m2_k,
                                make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
                                a_block_buf,
                                a_thread_desc_,
                                make_tuple(m0, I0, I0, I0),
                                a_thread_buf);
                        });

                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            b_thread_copy_.Run(
                                b_block_desc_n0_n1_n2_k,
                                make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
                                b_block_buf,
                                b_thread_desc_,
                                make_tuple(n0, I0, I0, I0),
                                b_thread_buf);
                        });
                    });

                    HotLoopScheduler();
                };

                LoopFunc(I0);
                LoopFunc(I1);

                i += HotloopUnroll;
            } while(i < (num_loop - PrefetchStages));
        }
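        // Note: the unroll-by-2 keeps the exit parity fixed, so the tail
        // below always has either 3 (Odd) or 2 (Even) K-block iterations
        // left, and the I0/I1 register-buffer alternation stays consistent
        // with the three prefetches issued above.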
        // tail
        auto ReadWriteCompFunc = [&](auto vmem_buf) {
            vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
            vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                if constexpr(k0 == (KRepeat - 1))
                {
                    block_sync_lds();

                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf);
                    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf);

                    block_sync_lds();
                }
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, I0, ik))>{}];
                        });
                        static_for<0, KPack, 1>{}([&](auto ik) {
                            b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, I0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                    a_thread_copy_.Run(
                        a_block_desc_m0_m1_m2_k,
                        make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
                        a_block_buf,
                        a_thread_desc_,
                        make_tuple(m0, I0, I0, I0),
                        a_thread_buf);
                });

                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(
                        b_block_desc_n0_n1_n2_k,
                        make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
                        b_block_buf,
                        b_thread_desc_,
                        make_tuple(n0, I0, I0, I0),
                        b_thread_buf);
                });
            });

            HotLoopScheduler();
        };
        auto ReadCompFunc = [&]() {
            vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
            vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

            static_for<0, KRepeat - 1, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, I0, ik))>{}];
                        });
                        static_for<0, KPack, 1>{}([&](auto ik) {
                            b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, I0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });

                    a_thread_copy_.Run(
                        a_block_desc_m0_m1_m2_k,
                        make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
                        a_block_buf,
                        a_thread_desc_,
                        make_tuple(m0, I0, I0, I0),
                        a_thread_buf);
                });

                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(
                        b_block_desc_n0_n1_n2_k,
                        make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
                        b_block_buf,
                        b_thread_desc_,
                        make_tuple(n0, I0, I0, I0),
                        b_thread_buf);
                });
            });

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, KPack, 1>{}([&](auto ik) {
                        a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
                            [Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, I0, ik))>{}];
                    });
                    static_for<0, KPack, 1>{}([&](auto ik) {
                        b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
                            [Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, I0, ik))>{}];
                    });

                    using mfma_input_type =
                        typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec.template AsType<mfma_input_type>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });
            });

            HotLoopScheduler();
        };

        if constexpr(TailNum == TailNumber::Odd)
        {
            ReadWriteCompFunc(I0);
            ReadWriteCompFunc(I1);
            ReadCompFunc();
        }
        else if constexpr(TailNum == TailNumber::Even)
        {
            ReadWriteCompFunc(I0);
            ReadCompFunc();
        }
    }

    protected:
    // A[MRepeat, I1, I1, KPack]
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, I1, I1, Number<KPack>{}));

    // B[NRepeat, N1, N2, KPack]
    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<NRepeat>{}, I1, I1, Number<KPack>{}));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
                                                         ComputeDataTypeBuf,
                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, KPack>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         A_K1,
                                                         A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
                                                         ComputeDataTypeBuf,
                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, KPack>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         B_K1,
                                                         B_K1>;

    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
    using Base::c_thread_desc_;
};

} // namespace ck
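
A minimal call-site sketch (illustrative, not part of this file; the descriptor, copy, and buffer objects are hypothetical stand-ins for what a gridwise GEMM kernel constructs):

// The host computes num_loop = K / KPerBlock, then selects the kernel
// instantiation with HasMainLoop = BlockHasHotloop(num_loop) and
// TailNum = BlockLoopTailNum(num_loop); the device side then calls, e.g.,
//
//     pipeline.template Run<true, TailNumber::Even>(a_grid_desc,
//                                                   a_block_desc,
//                                                   a_blockwise_copy,
//                                                   /* ...remaining A/B args... */
//                                                   c_thread_buf,
//                                                   num_loop);
//
// where "pipeline" is an instance of the Intrawave specialization above.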