Composable Kernel: blockwise_gemm_pipeline_xdlops.hpp Source File
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
11 
12 // Double LDS buffer
13 // Prefetch 2 stages
14 // Local prefetch 1 stage
15 
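// ---------------------------------------------------------------------------------------------
// Hedged sketch of the pipeline named above (illustrative only, not part of this header).
// Two LDS buffers ("ping"/"pong") and two register buffers let the MFMAs of one K-block
// overlap with the LDS reads feeding the next K-block, the LDS writes refilling the buffer
// just consumed, and a buffer_load running two K-blocks ahead:
//
//   K-block i : mfma(reg[i % 2]) | ds_read LDS[(i + 1) % 2] -> reg[(i + 1) % 2]
//               ds_write -> LDS[i % 2] | buffer_load for K-block i + 2
//
// The helpers below are hypothetical and only encode which buffer each K-block touches.
static constexpr int pipeline_lds_read_buf(int i) { return (i + 1) % 2; }  // hypothetical helper
static constexpr int pipeline_lds_write_buf(int i) { return i % 2; }       // hypothetical helper
static_assert(pipeline_lds_read_buf(0) == 1 && pipeline_lds_write_buf(0) == 0,
              "K-block 0 consumes ping registers, reads pong LDS, and rewrites ping LDS");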
16 namespace ck {
17 
18 template <index_t BlockSize,
19  index_t MPerBlock,
20  index_t NPerBlock,
21  index_t KPerBlock,
22  index_t ABufferLoadWidth,
23  index_t BBufferLoadWidth,
24  index_t ALDSWriteWidth,
25  index_t BLDSWriteWidth,
26  index_t ALDSReadWidth,
27  index_t BLDSReadWidth,
28  index_t MRepeat,
29  index_t NRepeat,
30  index_t MPerXDL,
31  index_t NPerXDL,
32  index_t KPerXDL>
34 {
35  static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
36  static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
37  static constexpr index_t WaveSize = BlockSize / (WaveNumM * WaveNumN);
38 
39  static constexpr index_t A_Buffer_Load_Inst_Num =
40  MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
41  static constexpr index_t B_Buffer_Load_Inst_Num =
42  NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
43 
44  static constexpr index_t A_LDS_Write_Inst_Num =
45  MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
46  static constexpr index_t B_LDS_Write_Inst_Num =
47  NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
48 
49  static constexpr index_t A_LDS_Read_Inst_Num =
50  WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
51  static constexpr index_t B_LDS_Read_Inst_Num =
52  WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
53 
54  static constexpr index_t C_MFMA_Inst_Num =
55  MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
56 
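// Hedged numeric check of the counts above for one hypothetical configuration (not taken
// from this header): BlockSize = 256, M/N/KPerBlock = 128/128/32, buffer-load and LDS-write
// width 4, LDS-read width 8, M/NRepeat = 2, M/N/KPerXDL = 32/32/8. Each thread then issues
// 4 buffer loads, 4 LDS writes and 4 LDS reads per operand, and each wave runs 16 MFMAs per
// K-block.
static_assert(128 / (2 * 32) == 2 && 256 / (2 * 2) == 64, "WaveNumM/N = 2, WaveSize = 64");
static_assert(128 * 32 / (256 * 4) == 4, "A/B buffer-load and LDS-write instruction count");
static_assert(2 * 128 * 32 / (256 * 8) == 4, "A/B LDS-read instruction count");
static_assert(128 * 128 * 32 / (256 / 64) / (32 * 32 * 8) == 16, "MFMA instruction count");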
57  static constexpr auto Print()
58  {
59  printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
60  BlockSize,
61  WaveSize,
62  MPerBlock,
63  NPerBlock,
64  KPerBlock,
65  MPerXDL,
66  NPerXDL,
67  KPerXDL);
68 
69  printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
70  "%d, %d\n C MFMA inst: %d\n",
71  A_Buffer_Load_Inst_Num,
72  B_Buffer_Load_Inst_Num,
73  A_LDS_Write_Inst_Num,
74  B_LDS_Write_Inst_Num,
75  A_LDS_Read_Inst_Num,
76  B_LDS_Read_Inst_Num,
77  C_MFMA_Inst_Num);
78  }
79 };
80 
81 template <
82  index_t BlockSize,
83  typename FloatAB,
84  typename FloatAcc,
85  typename ATileDesc,
86  typename BTileDesc,
87  typename AMmaTileDesc,
88  typename BMmaTileDesc,
89  index_t MPerBlock,
90  index_t NPerBlock,
91  index_t KPerBlock,
92  index_t MPerXDL,
93  index_t NPerXDL,
94  index_t MRepeat,
95  index_t NRepeat,
96  index_t KPack,
97  bool TransposeC = false,
98  index_t AMmaKStride =
99  KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
100  index_t BMmaKStride =
101  KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
102 struct BlockwiseGemmXdlops_pipeline_v4
103 {
104  static constexpr auto I0 = Number<0>{};
105  static constexpr auto I1 = Number<1>{};
106  static constexpr auto I2 = Number<2>{};
107  static constexpr auto I3 = Number<3>{};
108 
109  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
110 
111  static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
112  static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
113  static_assert(MWaves > 0);
114  static_assert(NWaves > 0);
115  static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
116 
117  static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
118  static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
119  static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
120  static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
121 
122  static constexpr auto xdlops_gemm =
123  XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{};
124 
125  static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
126  static constexpr index_t KRepeat = KPerThread / KPack;
127 
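// Hedged numeric example (hypothetical values, not taken from this header): with
// KPerBlock = 32, xdlops_gemm.K0PerXdlops = 2 and KPack = 8, KPerThread = 32 / 2 = 16 and
// KRepeat = 16 / 8 = 2, i.e. every thread walks two KPack-wide slices per K-block, which is
// exactly the static_for<0, KRepeat, 1> loop used in Run() below.
static_assert(32 / 2 == 16 && 16 / 8 == 2, "KPerThread and KRepeat for the example above");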
129  MPerBlock,
130  NPerBlock,
131  KPerBlock,
132  A_K1,
133  B_K1,
134  A_K1,
135  B_K1,
136  KPack,
137  KPack,
138  MRepeat,
139  NRepeat,
140  MPerXDL,
141  NPerXDL,
142  xdlops_gemm.KPerXdlops>;
143 
144  static_assert(KPerThread % KPack == 0,
145  "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
146 
147  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
148  FloatAcc,
149  MRepeat * NRepeat,
150  xdlops_gemm.GetRegSizePerXdlops(),
151  true>
152  c_thread_buf_;
153 
154  __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
155 
156  __device__ static auto GetWaveIdx()
157  {
158  const index_t thread_id = ThisThreadBlock::GetThreadId();
159 
160  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
164 
165  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
166  }
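// Hedged worked example of the decomposition GetWaveIdx() performs (assuming the usual
// merge over (MWaves, NWaves, WaveSize); the transform arguments are elided in this
// listing). Hypothetical numbers: MWaves = 2, NWaves = 2, WaveSize = 64, thread_id = 173:
//   waveId_m = 173 / (2 * 64) = 1, waveId_n = (173 / 64) % 2 = 0, lane = 173 % 64 = 45.
static_assert(173 / (2 * 64) == 1 && (173 / 64) % 2 == 0 && 173 % 64 == 45,
              "thread 173 of a 256-thread block sits in wave (1, 0), lane 45");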
167 
168  __device__ static auto CalculateAThreadOriginDataIndex()
169  {
170  const auto wave_idx = GetWaveIdx();
171 
172  const auto waveId_m = wave_idx[I0];
173 
174  const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
175 
176  return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
177  }
178 
179  __device__ static auto CalculateBThreadOriginDataIndex()
180  {
181  const auto wave_idx = GetWaveIdx();
182 
183  const auto waveId_n = wave_idx[I1];
184 
185  const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
186 
187  return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
188  }
189 
190  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
191  __device__ static auto
192  CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
193  {
194  const auto wave_idx = GetWaveIdx();
195 
196  const auto waveId_m = wave_idx[I0];
197  const auto waveId_n = wave_idx[I1];
198 
199  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
200 
201  constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
202  make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
205 
206  constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
207  make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
210 
211  const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
212  make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
213  const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
214  make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
215 
216  return make_tuple(c_thread_m, c_thread_n);
217  }
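// Hedged worked example of the unmerge mapping above (assuming a row-major unmerge of
// (MRepeat, MWaves, MPerXDL); all numbers are hypothetical): with MWaves = 2, MPerXDL = 32,
// m0 = 1, waveId_m = 0 and blk_idx[I0] = 5, the block-local row is
//   c_thread_m = (m0 * MWaves + waveId_m) * MPerXDL + blk_m = (1 * 2 + 0) * 32 + 5 = 69,
// and c_thread_n follows the same pattern over (NRepeat, NWaves, NPerXDL).
static_assert((1 * 2 + 0) * 32 + 5 == 69, "unmerge of (m0, waveId_m, blk_m) into m");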
218 
219  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
220  __device__ static auto
221  CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
222  {
223  const auto wave_idx = GetWaveIdx();
224 
225  const auto waveId_m = wave_idx[I0];
226  const auto waveId_n = wave_idx[I1];
227 
228  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
229 
230  return make_tuple(
231  m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
232  }
233 
234  using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
235 
236  __host__ __device__
237  BlockwiseGemmXdlops_pipeline_v4(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
238  Tuple4 b_origin = CalculateBThreadOriginDataIndex())
239  : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
240  {
241 #if defined(__HIP_DEVICE_COMPILE__)
242  static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
243  "wrong! Desc should be known at compile-time");
244 
245  static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
246  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
247 
248  static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
249  "wrong!");
250 #endif
251  // HotLoopInstList::Print();
252  }
253 
254  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
255  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
256  {
257  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
258 
259  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
260  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
261  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
262  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
263 
265  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
266  }
267 
268  // XDL output supporting C_xdl = A_xdl * B_xdl
269  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
270  {
271  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
272 
273  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
274  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
275  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
276  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
277 
279  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
280  }
281 
282  __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
283  {
284  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
285 
286  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
287  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
288  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
289  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
290 
292  make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
293  }
294 
295  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
296  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
297  {
298  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
300  Number<NRepeat>{},
301  Number<MWaves>{},
302  Number<NWaves>{},
303  Number<MPerXDL>{},
304  Number<NPerXDL>{}));
305 
306  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
307  }
308 
309  // XDL output supporting C_xdl = A_xdl * B_xdl
310  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
311  {
312  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
314  Number<NRepeat>{},
315  Number<MWaves>{},
316  Number<NWaves>{},
317  Number<MPerXDL>{},
318  Number<NPerXDL>{}));
319 
320  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
321  }
322 
323  __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
324  {
325  constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
327  Number<MRepeat>{},
328  Number<NRepeat>{},
329  Number<MWaves>{},
330  Number<NWaves>{},
331  Number<MPerXDL>{},
332  Number<NPerXDL>{}));
333 
334  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
335  c_block_desc_g_m0_n0_m1_n1_m2_n2);
336  }
337 
338  template <typename CGridDesc_M_N>
339  __host__ __device__ static constexpr auto
340  MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
341  {
342  const auto M = c_grid_desc_m_n.GetLength(I0);
343  const auto N = c_grid_desc_m_n.GetLength(I1);
344 
345  const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
346  c_grid_desc_m_n,
347  make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
348  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
351 
352  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
353  }
354 
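// Hedged numeric example of the unmerge above (hypothetical sizes, not taken from this
// header): with M = 1024, MWaves = 2 and MPerXDL = 32, the M axis splits into
// (M / (MWaves * MPerXDL), MWaves, MPerXDL) = (16, 2, 32); N splits analogously before
// xdlops_gemm further subdivides the per-XDL dimensions into M2/M3/M4/N2.
static_assert(1024 / (2 * 32) == 16, "M0 length of the unmerged grid descriptor above");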
355  template <typename CGridDesc_G_M_N>
356  __host__ __device__ static constexpr auto
357  MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
358  {
359  const auto G = c_grid_desc_g_m_n.GetLength(I0);
360  const auto M = c_grid_desc_g_m_n.GetLength(I1);
361  const auto N = c_grid_desc_g_m_n.GetLength(I2);
362 
363  const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
364  c_grid_desc_g_m_n,
366  make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
367  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
370 
371  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
372  c_grid_desc_g_m0_n0_m1_n1_m2_n2);
373  }
374 
375  __device__ static constexpr auto HotLoopScheduler()
376  {
377  // schedule
378  constexpr auto num_ds_read_inst =
380  constexpr auto num_ds_write_inst =
382  ;
383  constexpr auto num_buffer_load_inst =
385  ;
386  constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
387 
388  constexpr auto num_issue = num_buffer_load_inst;
389 
390  static_for<0, num_issue, 1>{}([&](auto i) {
391  ignore = i;
392  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
393  __builtin_amdgcn_sched_group_barrier(
394  0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read
395  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
396  __builtin_amdgcn_sched_group_barrier(
397  0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write
398  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
399  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
400  __builtin_amdgcn_sched_group_barrier(
401  0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA
402  });
403  }
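// Hedged worked example of the per-iteration issue pattern built above, with hypothetical
// per-thread counts of 8 buffer loads, 16 DS reads, 8 DS writes and 64 MFMAs. Each of the
// 8 issue groups then interleaves 1 MFMA, 2 DS reads, 1 MFMA, 1 DS write, 1 MFMA, 1 VMEM
// read and (64 / 8 - 3) = 5 trailing MFMAs, so every instruction class is spread evenly
// across the loop body.
static_assert(8 * (1 + 1 + 1 + (64 / 8 - 3)) == 64, "MFMA issues add up for the example above");
static_assert(8 * (16 / 8) == 16 && 8 * (8 / 8) == 8, "DS read/write issues add up as well");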
404 
405  template <index_t stage>
406  __device__ static constexpr auto TailScheduler()
407  {
408  }
409 
410  template <>
411  __device__ constexpr auto TailScheduler<1>()
412  {
413  // schedule
414  constexpr auto num_ds_read_inst =
416  constexpr auto num_ds_write_inst =
418  ;
419  constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
420 
421  constexpr auto num_issue = num_ds_write_inst;
422 
423  static_for<0, num_issue, 1>{}([&](auto i) {
424  ignore = i;
425  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
426  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
427  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
428  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
429  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
430  __builtin_amdgcn_sched_group_barrier(
431  0x100, num_ds_read_inst / num_ds_write_inst - 1, 0); // DS read
432  __builtin_amdgcn_sched_group_barrier(
433  0x008, num_mfma_inst / num_ds_write_inst - 3, 0); // MFMA
434  });
435  }
436 
437  template <>
438  __device__ constexpr auto TailScheduler<2>()
439  {
440  // schedule
441  constexpr auto num_ds_read_inst =
443  constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
444 
445  constexpr auto num_issue = num_ds_read_inst;
446 
447  static_for<0, num_issue, 1>{}([&](auto i) {
448  ignore = i;
449  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
450  __builtin_amdgcn_sched_group_barrier(
451  0x008, num_mfma_inst / num_ds_read_inst, 0); // MFMA
452  });
453  }
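// Hedged worked example for the drain schedule above (hypothetical counts: 16 DS reads and
// 64 MFMAs left to issue): each of the 16 groups pairs one ds_read with 64 / 16 = 4 MFMAs,
// so the remaining LDS reads are spread evenly across the remaining MFMAs.
static_assert(16 * (64 / 16) == 64, "MFMA issues add up for the example above");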
454 
455  static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
456  static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
457 
458  template <bool HasMainLoop,
459  index_t TailNum,
460  typename AGridDesc,
461  typename ABlockDesc,
462  typename ABlockTransfer,
463  typename AGridBuffer,
464  typename ABlockBuffer,
465  typename ABlockTransferStep,
466  typename BGridDesc,
467  typename BBlockDesc,
468  typename BBlockTransfer,
469  typename BGridBuffer,
470  typename BBlockBuffer,
471  typename BBlockTransferStep,
472  typename CThreadBuffer>
473  __device__ void Run(const AGridDesc& a_grid_desc,
474  const ABlockDesc& a_block_desc,
475  ABlockTransfer& a_blockwise_copy,
476  const AGridBuffer& a_grid_buf,
477  ABlockBuffer& a_block_buf,
478  const ABlockTransferStep& a_block_copy_step,
479  const BGridDesc& b_grid_desc,
480  const BBlockDesc& b_block_desc,
481  BBlockTransfer& b_blockwise_copy,
482  const BGridBuffer& b_grid_buf,
483  BBlockBuffer& b_block_buf,
484  const BBlockTransferStep& b_block_copy_step,
485  CThreadBuffer& c_thread_buf,
486  index_t num_loop) const
487  {
488  __builtin_amdgcn_sched_barrier(0);
489  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
490  a_thread_desc_.GetElementSpaceSize());
491  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
492  b_thread_desc_.GetElementSpaceSize());
493 
494  StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
495  StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
496  // Inst List:
497  // ds_read_b128: 16
498  // ds_write_b128: 8
499  // buffer_load_dwordx4: 16
500  // v_mfma: 0
501  // -------------------------------------------------------------------------------------------
502 
503  // 1st global prefetch: fill Ping LDS
504  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
505  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
506 
507  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
508  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
509 
510  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
511  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));
512 
513  // 1st local prefetch: fill Ping Reg
514  block_sync_lds();
515  static_for<0, KRepeat, 1>{}([&](auto k) {
516  static_for<0, MRepeat, 1>{}([&](auto m0) {
517  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
518  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
519  a_block_buf.At(I0),
520  a_thread_desc_,
521  make_tuple(m0, I0, k, I0),
522  a_thread_bufs(I0));
523  static_for<0, NRepeat, 1>{}([&](auto n0) {
524  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
525  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
526  b_block_buf.At(I0),
527  b_thread_desc_,
528  make_tuple(n0, I0, k, I0),
529  b_thread_bufs(I0));
530  });
531  });
532  });
533 
534  // 2nd global prefetch: fill Pong LDS
535  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
536  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
537 
538  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
539  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
540 
541  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
542  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));
543 
544  // 3rd global prefetch
545  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
546  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
547 
548  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
549  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
550 
551  // Initialize C
552  c_thread_buf.Clear();
553 
554  // main body
555  if constexpr(HasMainLoop)
556  {
557  index_t i = 0;
558  // The hot loop body is unrolled into two interleaved passes to implement the double local-buffer (ping/pong) strategy
559  do
560  {
561  // -------------------------------------------------------------------------------------------
562  using PingP1 = Number<0>;
563  using PongP1 = Number<1>;
564  // MFMA: Ping Reg
565  // DS_WRITE: To Ping LDS
566  // DS_READ: Pong LDS to Pong Reg
567  block_sync_lds();
568 
569  static_for<0, KRepeat, 1>{}([&](auto k) {
570  static_for<0, MRepeat, 1>{}([&](auto m0) {
571  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
572  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
573  a_block_buf.At(PongP1{}),
574  a_thread_desc_,
575  make_tuple(m0, I0, k, I0),
576  a_thread_bufs(PongP1{}));
577  static_for<0, NRepeat, 1>{}([&](auto n0) {
578  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
579  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
580  b_block_buf.At(PongP1{}),
581  b_thread_desc_,
582  make_tuple(n0, I0, k, I0),
583  b_thread_bufs(PongP1{}));
584  });
585  });
586  });
587 
588  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
589  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
590 
591  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
592  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
593 
594  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
595  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
596 
597  static_for<0, KRepeat, 1>{}([&](auto k0) {
598  static_for<0, MRepeat, 1>{}([&](auto m0) {
599  static_for<0, NRepeat, 1>{}([&](auto n0) {
600  vector_type<FloatAB, KPack> a_thread_vec;
601  vector_type<FloatAB, KPack> b_thread_vec;
602 
603  static_for<0, KPack, 1>{}([&](auto ik) {
604  a_thread_vec.template AsType<FloatAB>()(ik) =
605  a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
606  make_tuple(m0, I0, k0, ik))>{}];
607  b_thread_vec.template AsType<FloatAB>()(ik) =
608  b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
609  make_tuple(n0, I0, k0, ik))>{}];
610  });
611 
612  using mfma_input_type =
613  typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
614 
615  constexpr index_t c_offset =
616  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
617 
618  xdlops_gemm.Run(
619  a_thread_vec.template AsType<mfma_input_type>(),
620  b_thread_vec.template AsType<mfma_input_type>(),
621  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
622  });
623  });
624  });
625 
626  HotLoopScheduler();
627  __builtin_amdgcn_sched_barrier(0);
628 
629  // -------------------------------------------------------------------------------------------
630  using PingP2 = Number<1>;
631  using PongP2 = Number<0>;
632  // MFMA: Pong Reg
633  // DS_WRITE: To Pong LDS
634  // DS_READ: Ping LDS to Ping Reg
635  block_sync_lds();
636 
637  static_for<0, KRepeat, 1>{}([&](auto k) {
638  static_for<0, MRepeat, 1>{}([&](auto m0) {
639  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
640  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
641  a_block_buf.At(PongP2{}),
642  a_thread_desc_,
643  make_tuple(m0, I0, k, I0),
644  a_thread_bufs(PongP2{}));
645  static_for<0, NRepeat, 1>{}([&](auto n0) {
646  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
647  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
648  b_block_buf.At(PongP2{}),
649  b_thread_desc_,
650  make_tuple(n0, I0, k, I0),
651  b_thread_bufs(PongP2{}));
652  });
653  });
654  });
655 
656  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP2{}));
657  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP2{}));
658 
659  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
660  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
661 
662  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
663  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
664 
665  static_for<0, KRepeat, 1>{}([&](auto k0) {
666  static_for<0, MRepeat, 1>{}([&](auto m0) {
667  static_for<0, NRepeat, 1>{}([&](auto n0) {
668  vector_type<FloatAB, KPack> a_thread_vec;
669  vector_type<FloatAB, KPack> b_thread_vec;
670 
671  static_for<0, KPack, 1>{}([&](auto ik) {
672  a_thread_vec.template AsType<FloatAB>()(ik) =
673  a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
674  make_tuple(m0, I0, k0, ik))>{}];
675  b_thread_vec.template AsType<FloatAB>()(ik) =
676  b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
677  make_tuple(n0, I0, k0, ik))>{}];
678  });
679 
680  using mfma_input_type =
681  typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
682 
683  constexpr index_t c_offset =
684  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
685 
686  xdlops_gemm.Run(
687  a_thread_vec.template AsType<mfma_input_type>(),
688  b_thread_vec.template AsType<mfma_input_type>(),
689  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
690  });
691  });
692  });
693 
694  HotLoopScheduler();
695  __builtin_amdgcn_sched_barrier(0);
696 
697  i += 2;
698  } while(i < (num_loop - 3));
699  }
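// Hedged trip-count arithmetic for the do-while above (hypothetical num_loop values):
// i advances by 2 per pass and the loop continues while i < num_loop - 3, so the body runs
// ceil((num_loop - 3) / 2) times (and at least once), e.g. 3 passes for num_loop = 8 and
// 4 passes for num_loop = 10; the remaining K-blocks are drained by the TailNum branches below.
static_assert((8 - 3 + 1) / 2 == 3 && (10 - 3 + 1) / 2 == 4,
              "main-loop pass counts for num_loop = 8 and 10");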
700 
701  // tail
702  if constexpr(TailNum == 3)
703  {
704  using PingP1 = Number<0>;
705  using PongP1 = Number<1>;
706  // MFMA: Ping Reg
707  // DS_WRITE: To Ping LDS
708  // DS_READ: Pong LDS to Pong Reg
709  block_sync_lds();
710 
711  static_for<0, KRepeat, 1>{}([&](auto k) {
712  static_for<0, MRepeat, 1>{}([&](auto m0) {
713  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
714  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
715  a_block_buf.At(PongP1{}),
716  a_thread_desc_,
717  make_tuple(m0, I0, k, I0),
718  a_thread_bufs(PongP1{}));
719  static_for<0, NRepeat, 1>{}([&](auto n0) {
720  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
721  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
722  b_block_buf.At(PongP1{}),
723  b_thread_desc_,
724  make_tuple(n0, I0, k, I0),
725  b_thread_bufs(PongP1{}));
726  });
727  });
728  });
729 
730  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
731  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
732 
733  static_for<0, KRepeat, 1>{}([&](auto k0) {
734  static_for<0, MRepeat, 1>{}([&](auto m0) {
735  static_for<0, NRepeat, 1>{}([&](auto n0) {
736  vector_type<FloatAB, KPack> a_thread_vec;
737  vector_type<FloatAB, KPack> b_thread_vec;
738 
739  static_for<0, KPack, 1>{}([&](auto ik) {
740  a_thread_vec.template AsType<FloatAB>()(ik) =
741  a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
742  make_tuple(m0, I0, k0, ik))>{}];
743  b_thread_vec.template AsType<FloatAB>()(ik) =
744  b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
745  make_tuple(n0, I0, k0, ik))>{}];
746  });
747 
748  using mfma_input_type =
749  typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
750 
751  constexpr index_t c_offset =
752  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
753 
754  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
755  b_thread_vec.template AsType<mfma_input_type>(),
756  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
757  });
758  });
759  });
760 
761  TailScheduler<1>();
762  __builtin_amdgcn_sched_barrier(0);
763 
764  // -------------------------------------------------------------------------------------------
765  using PingP2 = Number<1>;
766  using PongP2 = Number<0>;
767  // MFMA: Pong Reg
768  // DS_WRITE: To Pong LDS
769  // DS_READ: Ping LDS to Ping Reg
770  block_sync_lds();
771 
772  static_for<0, KRepeat, 1>{}([&](auto k) {
773  static_for<0, MRepeat, 1>{}([&](auto m0) {
774  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
775  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
776  a_block_buf.At(PongP2{}),
777  a_thread_desc_,
778  make_tuple(m0, I0, k, I0),
779  a_thread_bufs(PongP2{}));
780  static_for<0, NRepeat, 1>{}([&](auto n0) {
781  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
782  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
783  b_block_buf.At(PongP2{}),
784  b_thread_desc_,
785  make_tuple(n0, I0, k, I0),
786  b_thread_bufs(PongP2{}));
787  });
788  });
789  });
790 
791  static_for<0, KRepeat, 1>{}([&](auto k0) {
792  static_for<0, MRepeat, 1>{}([&](auto m0) {
793  static_for<0, NRepeat, 1>{}([&](auto n0) {
794  vector_type<FloatAB, KPack> a_thread_vec;
795  vector_type<FloatAB, KPack> b_thread_vec;
796 
797  static_for<0, KPack, 1>{}([&](auto ik) {
798  a_thread_vec.template AsType<FloatAB>()(ik) =
799  a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
800  make_tuple(m0, I0, k0, ik))>{}];
801  b_thread_vec.template AsType<FloatAB>()(ik) =
802  b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
803  make_tuple(n0, I0, k0, ik))>{}];
804  });
805 
806  using mfma_input_type =
807  typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
808 
809  constexpr index_t c_offset =
810  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
811 
812  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
813  b_thread_vec.template AsType<mfma_input_type>(),
814  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
815  });
816  });
817  });
818 
819  TailScheduler<2>();
820  __builtin_amdgcn_sched_barrier(0);
821 
822  static_for<0, KRepeat, 1>{}([&](auto k) {
823  static_for<0, MRepeat, 1>{}([&](auto m0) {
824  static_for<0, NRepeat, 1>{}([&](auto n0) {
825  vector_type<FloatAB, KPack> a_thread_vec;
826  vector_type<FloatAB, KPack> b_thread_vec;
827 
828  static_for<0, KPack, 1>{}([&](auto ik) {
829  a_thread_vec.template AsType<FloatAB>()(ik) =
830  a_thread_bufs[PongP2{}][Number<a_thread_desc_.CalculateOffset(
831  make_tuple(m0, I0, k, ik))>{}];
832  b_thread_vec.template AsType<FloatAB>()(ik) =
833  b_thread_bufs[PongP2{}][Number<b_thread_desc_.CalculateOffset(
834  make_tuple(n0, I0, k, ik))>{}];
835  });
836 
837  using mfma_input_type =
838  typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
839 
840  constexpr index_t c_offset =
841  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
842 
843  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
844  b_thread_vec.template AsType<mfma_input_type>(),
845  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
846  });
847  });
848  });
849 
850  // 64 v_mfma
851  __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
852  __builtin_amdgcn_sched_barrier(0);
853  }
854  else if constexpr(TailNum == 2)
855  {
856  using PingP1 = Number<0>;
857  using PongP1 = Number<1>;
858  // MFMA: Ping Reg
859  // DS_WRITE: To Ping LDS
860  // DS_READ: Pong LDS to Pong Reg
861  block_sync_lds();
862 
863  static_for<0, KRepeat, 1>{}([&](auto k) {
864  static_for<0, MRepeat, 1>{}([&](auto m0) {
865  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
866  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
867  a_block_buf.At(PongP1{}),
868  a_thread_desc_,
869  make_tuple(m0, I0, k, I0),
870  a_thread_bufs(PongP1{}));
871  static_for<0, NRepeat, 1>{}([&](auto n0) {
872  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
873  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
874  b_block_buf.At(PongP1{}),
875  b_thread_desc_,
876  make_tuple(n0, I0, k, I0),
877  b_thread_bufs(PongP1{}));
878  });
879  });
880  });
881 
882  static_for<0, KRepeat, 1>{}([&](auto k0) {
883  static_for<0, MRepeat, 1>{}([&](auto m0) {
884  static_for<0, NRepeat, 1>{}([&](auto n0) {
885  vector_type<FloatAB, KPack> a_thread_vec;
886  vector_type<FloatAB, KPack> b_thread_vec;
887 
888  static_for<0, KPack, 1>{}([&](auto ik) {
889  a_thread_vec.template AsType<FloatAB>()(ik) =
890  a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
891  make_tuple(m0, I0, k0, ik))>{}];
892  b_thread_vec.template AsType<FloatAB>()(ik) =
893  b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
894  make_tuple(n0, I0, k0, ik))>{}];
895  });
896 
897  using mfma_input_type =
898  typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
899 
900  constexpr index_t c_offset =
901  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
902 
903  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
904  b_thread_vec.template AsType<mfma_input_type>(),
905  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
906  });
907  });
908  });
909 
910  TailScheduler<2>();
911  __builtin_amdgcn_sched_barrier(0);
912 
913  // -------------------------------------------------------------------------------------------
914  using PingP2 = Number<1>;
915  // MFMA: Pong Reg
916  // DS_WRITE: To Pong LDS
917  // DS_READ: Ping LDS to Ping Reg
918 
919  static_for<0, KRepeat, 1>{}([&](auto k0) {
920  static_for<0, MRepeat, 1>{}([&](auto m0) {
921  static_for<0, NRepeat, 1>{}([&](auto n0) {
922  vector_type<FloatAB, KPack> a_thread_vec;
923  vector_type<FloatAB, KPack> b_thread_vec;
924 
925  static_for<0, KPack, 1>{}([&](auto ik) {
926  a_thread_vec.template AsType<FloatAB>()(ik) =
927  a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
928  make_tuple(m0, I0, k0, ik))>{}];
929  b_thread_vec.template AsType<FloatAB>()(ik) =
930  b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
931  make_tuple(n0, I0, k0, ik))>{}];
932  });
933 
934  using mfma_input_type =
935  typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
936 
937  constexpr index_t c_offset =
938  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
939 
940  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
941  b_thread_vec.template AsType<mfma_input_type>(),
942  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
943  });
944  });
945  });
946 
947  // 64 v_mfma
948  __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
949  __builtin_amdgcn_sched_barrier(0);
950  }
951  }
952 
953  protected:
954  // M1, N1 as double buffer index
955  // Read buffer + Compute buffer
956  // A[M0, M1, M2, KPack]
957  static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
958  make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
959  make_tuple(
960  Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
961 
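// Hedged offset example for a_thread_desc_ above (hypothetical MRepeat = 2, KRepeat = 2,
// KPack = 8, so the strides (KPack, KRepeat * MRepeat * KPack, MRepeat * KPack, 1) evaluate
// to (8, 32, 16, 1)): the element at (m0, 0, k, ik) = (1, 0, 1, 3) lives at offset
//   1 * 8 + 0 * 32 + 1 * 16 + 3 * 1 = 27,
// which is the kind of compile-time index used when packing a_thread_vec in Run().
static_assert(1 * 8 + 0 * 32 + 1 * 16 + 3 * 1 == 27, "a_thread_desc_ offset for the example above");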
962  // B[N0, N1, N2, KPack]
963  static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
964  make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
965  make_tuple(
966  Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));
967 
968  // C[M, N, NumRegXdlops]
969  static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
970  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
971 
973  FloatAB,
974  decltype(a_block_desc_m0_m1_m2_k),
975  decltype(a_thread_desc_),
978  3,
979  A_K1,
980  A_K1>;
981 
983  FloatAB,
984  decltype(b_block_desc_n0_n1_n2_k),
985  decltype(b_thread_desc_),
988  3,
989  B_K1,
990  B_K1>;
991 
992  AThreadCopy a_thread_copy_;
993  BThreadCopy b_thread_copy_;
994 };
995 
996 } // namespace ck