Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp Source File

blockwise_gemm_xdlops.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/utility/common_header.hpp"
7 #include "ck/tensor_description/multi_index_transform_helper.hpp"
8 #include "ck/tensor_description/tensor_adaptor.hpp"
9 #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
10 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
11 
12 namespace ck {
13 
14 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
15 __host__ __device__ static constexpr auto
16 MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
17 {
18  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
19  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
20 
21  return transform_tensor_descriptor(
22  TileDesc_K0_MN_K1{},
23  make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
24  make_unmerge_transform(
25  make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
26  make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
27  make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
28 }
29 
30 template <index_t BlockSize,
31  typename FloatA,
32  typename FloatB,
33  typename FloatAcc,
34  typename AK0MK1BlockDesc,
35  typename BK0NK1BlockDesc,
36  index_t MPerXDL,
37  index_t NPerXDL,
38  index_t MRepeat,
39  index_t NRepeat,
40  index_t KPack,
41  typename ComputeTypeA = FloatA,
42  typename ComputeTypeB = FloatB>
43 struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
44 {
45  static constexpr auto I0 = Number<0>{};
46  static constexpr auto I1 = Number<1>{};
47  static constexpr auto I2 = Number<2>{};
48  static constexpr auto I3 = Number<3>{};
49 
50  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
51 
52  using ElementDataTypeA =
53  conditional_t<is_same_v<ComputeTypeA, ck::tf32_t>, float, ComputeTypeA>;
54  using ElementDataTypeB =
55  conditional_t<is_same_v<ComputeTypeB, ck::tf32_t>, float, ComputeTypeB>;
56 
57  static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
58  static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
59  static constexpr index_t KPerBlock =
60  BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
61 
62  static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
63  static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
64  static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
65  static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
66 
67  static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
68  static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
69  static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
70 
71  static constexpr auto xdlops_gemm =
72  XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB, false, false>{};
73 
74  static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
75 
76  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
77  FloatAcc,
78  MRepeat * NRepeat,
79  xdlops_gemm.GetRegSizePerXdlops(),
80  true>
81  c_thread_buf_;
82 
83  __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
84 
85  __device__ static auto GetWaveIdx()
86  {
87  const index_t thread_id = ThisThreadBlock::GetThreadId();
88 
89  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
90  make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
91  make_tuple(Sequence<0, 1, 2>{}),
92  make_tuple(Sequence<0>{}));
93 
94  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
95  }
96 
97  __device__ static auto CalculateAThreadOriginDataIndex()
98  {
99  const auto wave_idx = GetWaveIdx();
100 
101  const auto waveId_m = wave_idx[I0];
102 
103  const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
104 
105  return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
106  }
107 
108  __device__ static auto CalculateBThreadOriginDataIndex()
109  {
110  const auto wave_idx = GetWaveIdx();
111 
112  const auto waveId_n = wave_idx[I1];
113 
114  const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
115 
116  return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
117  }
118 
119  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
120  __device__ static auto
121  CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
122  {
123  const auto wave_idx = GetWaveIdx();
124 
125  const auto waveId_m = wave_idx[I0];
126  const auto waveId_n = wave_idx[I1];
127 
128  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
129 
130  constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
131  make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
132  make_tuple(Sequence<0>{}),
133  make_tuple(Sequence<0, 1, 2>{}));
134 
135  constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
136  make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
137  make_tuple(Sequence<0>{}),
138  make_tuple(Sequence<0, 1, 2>{}));
139 
140  const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
141  make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
142  const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
143  make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
144 
145  return make_tuple(c_thread_m, c_thread_n);
146  }
147 
148  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
149  __device__ static auto
150  CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
151  {
152  const auto wave_idx = GetWaveIdx();
153 
154  const auto waveId_m = wave_idx[I0];
155  const auto waveId_n = wave_idx[I1];
156 
157  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
158 
159  return make_tuple(Number<m0>{},
160  Number<n0>{},
161  waveId_m,
162  waveId_n,
163  blk_idx[I0],
164  blk_idx[I1],
165  blk_idx[I2],
166  blk_idx[I3]);
167  }
168 
169  __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
170  {
171  static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
172  BK0NK1BlockDesc::IsKnownAtCompileTime(),
173  "wrong! Desc should be known at compile-time");
174 
175  static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
176  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
177 
178  static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
179  "wrong!");
180  if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
181  {
182  static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
183  "ComputeTypeA and ComputeTypeB must be same when one of them is tf32");
184  }
185  }
186 
187  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
188  {
189  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
190 
191  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
192  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
193  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
194  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
195 
196  return make_naive_tensor_descriptor_packed(
197  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
198  }
199 
200  __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
201  {
202  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
203 
204  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
205  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
206  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
207  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
208 
209  return make_naive_tensor_descriptor_packed(
210  make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
211  }
212 
213  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
214  {
215  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
216  make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
217  Number<NRepeat>{},
218  Number<MWaves>{},
219  Number<NWaves>{},
220  Number<MPerXDL>{},
221  Number<NPerXDL>{}));
222 
223  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
224  }
225 
226  __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
227  {
228  constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
229  make_naive_tensor_descriptor_packed(make_tuple(I1,
230  Number<MRepeat>{},
231  Number<NRepeat>{},
232  Number<MWaves>{},
233  Number<NWaves>{},
234  Number<MPerXDL>{},
235  Number<NPerXDL>{}));
236 
237  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
238  c_block_desc_g_m0_n0_m1_n1_m2_n2);
239  }
240 
241  template <typename CGridDesc_M_N>
242  __host__ __device__ static constexpr auto
243  MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
244  {
245  const auto M = c_grid_desc_m_n.GetLength(I0);
246  const auto N = c_grid_desc_m_n.GetLength(I1);
247 
248  const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
249  c_grid_desc_m_n,
250  make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
251  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
252  make_tuple(Sequence<0>{}, Sequence<1>{}),
253  make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
254 
255  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
256  }
257 
258  template <typename CGridDesc_G_M_N>
259  __host__ __device__ static constexpr auto
260  MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
261  {
262  const auto G = c_grid_desc_g_m_n.GetLength(I0);
263  const auto M = c_grid_desc_g_m_n.GetLength(I1);
264  const auto N = c_grid_desc_g_m_n.GetLength(I2);
265 
266  const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
267  c_grid_desc_g_m_n,
268  make_tuple(make_pass_through_transform(G),
269  make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
270  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
271  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
272  make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4, 5, 6>{}));
273 
274  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
275  c_grid_desc_g_m0_n0_m1_n1_m2_n2);
276  }
277 
278  __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
279  {
280  return transform_tensor_descriptor(
281  AK0MK1BlockDesc{},
282  make_tuple(
283  make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
284  make_unmerge_transform(
285  make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
286  make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
287  make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
288  }
289 
290  __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
291  {
292  return transform_tensor_descriptor(
293  BK0NK1BlockDesc{},
294  make_tuple(
295  make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
296  make_unmerge_transform(
297  make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}))),
298  make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
299  make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
300  }
301 
302  static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
303  static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
304 
305  template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
306  __device__ void Run(const ABlockBuffer& a_block_buf,
307  const BBlockBuffer& b_block_buf,
308  CThreadBuffer& c_thread_buf) const
309  {
310  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ElementDataTypeA>(
311  a_thread_desc_.GetElementSpaceSize());
312  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ElementDataTypeB>(
313  b_thread_desc_.GetElementSpaceSize());
314 
315  static_for<0, MRepeat, 1>{}([&](auto m0) {
316  // read A
317  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
318  make_tuple(m0, I0, I0, I0),
319  a_block_buf,
320  a_thread_desc_,
321  make_tuple(I0, I0, I0, I0),
322  a_thread_buf);
323 
324  static_for<0, NRepeat, 1>{}([&](auto n0) {
325  // read B
326  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
327  make_tuple(n0, I0, I0, I0),
328  b_block_buf,
329  b_thread_desc_,
330  make_tuple(I0, I0, I0, I0),
331  b_thread_buf);
332 
333  static_for<0, KPerThread, KPack>{}([&](auto k) {
334  vector_type<ElementDataTypeA, KPack> a_thread_vec;
335  vector_type<ElementDataTypeB, KPack> b_thread_vec;
336 
337  static_for<0, KPack, 1>{}([&](auto i) {
338  a_thread_vec.template AsType<ElementDataTypeA>()(i) = a_thread_buf
339  [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
340  b_thread_vec.template AsType<ElementDataTypeB>()(i) = b_thread_buf
341  [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
342  });
343 
344  using mfma_input_type_a =
345  typename vector_type<ElementDataTypeA, xdlops_gemm.K1PerXdlops>::type;
346  using mfma_input_type_b =
347  typename vector_type<ElementDataTypeB, xdlops_gemm.K1PerXdlops>::type;
348 
349  constexpr index_t c_offset =
350  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
351 
352  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
353  b_thread_vec.template AsType<mfma_input_type_b>(),
354  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
355  });
356  });
357  });
358  }
359 
360  protected:
361  // A[M0, M1, M2, KPerThread]
362  static constexpr auto a_thread_desc_ =
363  make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
364 
365  // B[N0, N1, N2, KPerThread]
366  static constexpr auto b_thread_desc_ =
367  make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
368 
369  // C[M, N, NumRegXdlops]
370  static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
371  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
372 
373  using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
374  ElementDataTypeA,
375  decltype(a_block_desc_m0_m1_m2_k),
376  decltype(a_thread_desc_),
377  Sequence<1, 1, 1, KPerThread>,
378  Sequence<0, 1, 2, 3>,
379  3,
380  A_K1,
381  A_K1>;
382 
383  using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
384  ElementDataTypeB,
385  decltype(b_block_desc_n0_n1_n2_k),
386  decltype(b_thread_desc_),
387  Sequence<1, 1, 1, KPerThread>,
388  Sequence<0, 1, 2, 3>,
389  3,
390  B_K1,
391  B_K1>;
392 
393  AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
394  BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
395 };
396 
397 // Note: to use the inter-wave loop scheduler, the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1
398 // must be set explicitly, because a few of the required intrinsics are not yet available in the
399 // latest ROCm release. On unsupported compilers the inter-wave loop scheduler falls back to the
400 // default loop scheduler, selected by CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0.
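// A hedged illustration (the build integration shown is an assumption, not part of this
// header): the macro is normally made visible to every translation unit that includes the
// CK headers, e.g. by passing
//
//   -DCK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1
//
// on the hipcc/clang command line. Without it, the inter-wave specialization below is
// compiled out and the struct simply inherits the default-scheduler behaviour of its base.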
401 template <index_t BlockSize,
402  typename FloatA,
403  typename FloatB,
404  typename FloatAcc,
405  typename AK0MK1BlockDesc,
406  typename BK0NK1BlockDesc,
407  index_t MPerXDL,
408  index_t NPerXDL,
409  index_t MRepeat,
410  index_t NRepeat,
411  index_t KPack,
412  typename ComputeTypeA = FloatA,
413  typename ComputeTypeB = FloatB,
414  index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
415 struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
416  : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
417  FloatA,
418  FloatB,
419  FloatAcc,
420  AK0MK1BlockDesc,
421  BK0NK1BlockDesc,
422  MPerXDL,
423  NPerXDL,
424  MRepeat,
425  NRepeat,
426  KPack,
427  ComputeTypeA,
428  ComputeTypeB>
429 {
430  using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
431  FloatA,
432  FloatB,
433  FloatAcc,
434  AK0MK1BlockDesc,
435  BK0NK1BlockDesc,
436  MPerXDL,
437  NPerXDL,
438  MRepeat,
439  NRepeat,
440  KPack,
441  ComputeTypeA,
442  ComputeTypeB>;
443 
444 #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
445  using Base::a_block_desc_m0_m1_m2_k;
446  using Base::A_K1;
447  using Base::b_block_desc_n0_n1_n2_k;
448  using Base::B_K1;
449  using Base::c_thread_buf_;
450  using Base::c_thread_desc_;
451  using Base::CalculateAThreadOriginDataIndex;
452  using Base::CalculateBThreadOriginDataIndex;
453  using Base::I0;
454  using Base::I1;
455  using Base::KPerThread;
456  using Base::xdlops_gemm;
457 
458  using ElementDataTypeA =
459  conditional_t<is_same_v<ComputeTypeA, ck::tf32_t>, float, ComputeTypeA>;
460  using ElementDataTypeB =
461  conditional_t<is_same_v<ComputeTypeB, ck::tf32_t>, float, ComputeTypeB>;
462 
463  static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
464 
465  // 2-wave optimized blockwise gemm
466  template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
467  __device__ void Run(const ABlockBuffer& a_block_buf,
468  const BBlockBuffer& b_block_buf,
469  CThreadBuffer& c_thread_buf) const
470  {
471  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ElementDataTypeA>(
472  a_thread_desc_.GetElementSpaceSize());
473  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ElementDataTypeB>(
474  b_thread_desc_.GetElementSpaceSize());
475 
476  static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
477  static_for<0, MRepeat, 1>{}([&](auto m0) {
478  // read A
479  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
480  make_tuple(m0, I0, I0, k),
481  a_block_buf,
482  a_thread_desc_,
483  make_tuple(m0, I0, I0, I0),
484  a_thread_buf);
485  });
486  static_for<0, NRepeat, 1>{}([&](auto n0) {
487  // read B
488  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
489  make_tuple(n0, I0, I0, k),
490  b_block_buf,
491  b_thread_desc_,
492  make_tuple(n0, I0, I0, I0),
493  b_thread_buf);
494  });
495  __builtin_amdgcn_sched_barrier(0);
496  // NOTE: Synchronize the threads in a workgroup at the start of each MAC cluster except
497  // the first; this shortens the non-MAC cluster slightly with no observable negative
498  // impact. The desired effect is that the waves in a workgroup execute their MAC work in
499  // sync, which keeps out-of-sync waves from hijacking MAC resources from other workgroups
500  // and from losing the chance of latency hiding while waiting for the rest of the workgroup
501  // at the eventual sync point. (A condensed sketch of the schedule follows this function.)
502  if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
503  {
504 #ifdef __gfx12__
505  asm volatile("\
506  s_barrier_signal -1 \n \
507  s_barrier_wait -1 \
508  " ::);
509 #else
510  asm volatile("s_barrier" ::);
511 #endif
512  __builtin_amdgcn_sched_barrier(0);
513  }
514  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
515  static_for<0, MRepeat, 1>{}([&](auto m0) {
516  static_for<0, NRepeat, 1>{}([&](auto n0) {
517  vector_type<ElementDataTypeA, KPack> a_thread_vec;
518  vector_type<ElementDataTypeB, KPack> b_thread_vec;
519 
520  static_for<0, KPack, 1>{}([&](auto i) {
521  a_thread_vec.template AsType<ElementDataTypeA>()(i) =
522  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
523  make_tuple(m0, 0, 0, k_ + i))>{}];
524  b_thread_vec.template AsType<ElementDataTypeB>()(i) =
525  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
526  make_tuple(n0, 0, 0, k_ + i))>{}];
527  });
528 
529  using mfma_input_type_a =
530  typename vector_type<ElementDataTypeA, xdlops_gemm.K1PerXdlops>::type;
531  using mfma_input_type_b =
532  typename vector_type<ElementDataTypeB, xdlops_gemm.K1PerXdlops>::type;
533 
534  constexpr index_t c_offset =
535  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
536 
537  // The block_sync_lds() here performs double duty:
538  // A) it safeguards against data hazards, because the barrier from blockwise_gemm
539  //    has been moved here;
540  // B) it reduces VMEM FIFO congestion by applying small delays to different wavefronts.
541  // It is performed near the end of the MAC cluster to minimize the lgkmcnt penalty.
542  if constexpr(k.value == KPerThread - KPerInnerLoop &&
543  k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 &&
544  n0.value == NRepeat - 1)
545  {
546  __builtin_amdgcn_sched_barrier(0);
547  block_sync_lds();
548  __builtin_amdgcn_sched_barrier(0);
549  }
550 
551  // TODO: insert setprio in more precise manner since we
552  // could have more than >1 MFMA instructions in single call
553  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
554  b_thread_vec.template AsType<mfma_input_type_b>(),
555  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
556  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
557  {
558  __builtin_amdgcn_sched_barrier(0);
559  __builtin_amdgcn_s_setprio(1);
560  __builtin_amdgcn_sched_barrier(0);
561  }
562  });
563  });
564  });
565  __builtin_amdgcn_sched_barrier(0);
566  __builtin_amdgcn_s_setprio(0);
567  __builtin_amdgcn_sched_barrier(0);
568  });
569  }
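// A condensed, hedged summary (not code from this header) of the schedule the loop above
// realizes for each KPerInnerLoop cluster, per wave:
//
//   A/B LDS->VGPR reads    stage the whole cluster's operands up front
//   s_barrier              align the waves (skipped for the first of several clusters)
//   MFMA x (MRepeat * NRepeat * KPerInnerLoop / KPack)
//     s_setprio 1          raised right after the first MFMA of the cluster is issued
//     block_sync_lds       issued near the end of the last cluster only
//   s_setprio 0            dropped again before the next cluster's loads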
570 
571  protected:
572  // A[M0, M1, M2, KPerInnerLoop]
573  static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
574  make_tuple(Number<MRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
575 
576  // B[N0, N1, N2, KPerInnerLoop]
577  static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
578  make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
579 
580  using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
581  ElementDataTypeA,
582  decltype(a_block_desc_m0_m1_m2_k),
583  decltype(a_thread_desc_),
584  Sequence<1, 1, 1, KPerInnerLoop>,
585  Sequence<0, 1, 2, 3>,
586  3,
587  A_K1,
588  A_K1>;
589 
590  using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
591  ElementDataTypeB,
592  decltype(b_block_desc_n0_n1_n2_k),
593  decltype(b_thread_desc_),
594  Sequence<1, 1, 1, KPerInnerLoop>,
595  Sequence<0, 1, 2, 3>,
596  3,
597  B_K1,
598  B_K1>;
599 
600  AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
601  BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
602 
603 #endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
604 };
605 
606 template <index_t BlockSize,
607  typename FloatA,
608  typename FloatB,
609  typename FloatAcc,
610  typename AK0MK1BlockDesc,
611  typename BK0NK1BlockDesc,
612  index_t MPerXDL,
613  index_t NPerXDL,
614  index_t MRepeat,
615  index_t NRepeat,
616  index_t KPack,
617  LoopScheduler LoopSched,
618  typename ComputeTypeA = FloatA,
619  typename ComputeTypeB = FloatB>
620 constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
621 {
622  if constexpr(LoopSched == LoopScheduler::Default)
623  {
624  return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
625  FloatA,
626  FloatB,
627  FloatAcc,
628  AK0MK1BlockDesc,
629  BK0NK1BlockDesc,
630  MPerXDL,
631  NPerXDL,
632  MRepeat,
633  NRepeat,
634  KPack,
635  ComputeTypeA,
636  ComputeTypeB>{};
637  }
638  else if constexpr(LoopSched == LoopScheduler::Interwave)
639  {
640  return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
641  BlockSize,
642  FloatA,
643  FloatB,
644  FloatAcc,
645  AK0MK1BlockDesc,
646  BK0NK1BlockDesc,
647  MPerXDL,
648  NPerXDL,
649  MRepeat,
650  NRepeat,
651  KPack,
652  ComputeTypeA,
653  ComputeTypeB,
654  CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>{};
655  }
656 };
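// A hedged usage sketch, not taken from this header: the sizes and the two block
// descriptors below are hypothetical placeholders chosen only to be self-consistent
// (MPerBlock = NPerBlock = 128, KPerBlock = 32, 2x2 waves of 64 threads).
//
//   constexpr auto a_block_desc_ak0_m_ak1 = make_naive_tensor_descriptor_packed(
//       make_tuple(Number<4>{}, Number<128>{}, Number<8>{}));
//   constexpr auto b_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor_packed(
//       make_tuple(Number<4>{}, Number<128>{}, Number<8>{}));
//
//   auto blockwise_gemm =
//       BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<256,    // BlockSize
//                                                                 half_t, // FloatA
//                                                                 half_t, // FloatB
//                                                                 float,  // FloatAcc
//                                                                 decltype(a_block_desc_ak0_m_ak1),
//                                                                 decltype(b_block_desc_bk0_n_bk1),
//                                                                 32,     // MPerXDL
//                                                                 32,     // NPerXDL
//                                                                 2,      // MRepeat
//                                                                 2,      // NRepeat
//                                                                 8,      // KPack
//                                                                 LoopScheduler::Default>();
//   auto& c_thread_buf = blockwise_gemm.GetCThreadBuffer();
//   // blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); // A/B buffers live in LDS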
657 
668 template <
669  index_t BlockSize,
670  typename FloatAB,
671  typename FloatAcc,
672  typename ATileDesc,
673  typename BTileDesc,
674  typename AMmaTileDesc,
675  typename BMmaTileDesc,
676  index_t MPerBlock,
677  index_t NPerBlock,
678  index_t KPerBlock,
679  index_t MPerXDL,
680  index_t NPerXDL,
681  index_t MRepeat,
682  index_t NRepeat,
683  index_t KPack,
684  bool TransposeC = false,
685  index_t AMmaKStride =
686  KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
687  index_t BMmaKStride =
688  KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
689 struct BlockwiseGemmXdlops_v2
690 {
691  static constexpr auto I0 = Number<0>{};
692  static constexpr auto I1 = Number<1>{};
693  static constexpr auto I2 = Number<2>{};
694  static constexpr auto I3 = Number<3>{};
695 
696  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
697 
698  static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
699  static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
700  static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
701 
702  static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
703  static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
704  static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
705  static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
706 
707  static constexpr auto xdlops_gemm =
708  XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{};
709 
710  static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
711 
712  static_assert(KPerThread % KPack == 0,
713  "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
714 
715  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
716  FloatAcc,
717  MRepeat * NRepeat,
718  xdlops_gemm.GetRegSizePerXdlops(),
719  true>
720  c_thread_buf_;
721 
722  __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
723 
724  __device__ static auto GetWaveIdx()
725  {
726  const index_t thread_id = ThisThreadBlock::GetThreadId();
727 
728  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
729  make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
730  make_tuple(Sequence<0, 1, 2>{}),
731  make_tuple(Sequence<0>{}));
732 
733  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
734  }
735 
736  __device__ static auto CalculateAThreadOriginDataIndex()
737  {
738  const auto wave_idx = GetWaveIdx();
739 
740  const auto waveId_m = wave_idx[I0];
741 
742  const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
743 
744  return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
745  }
746 
747  __device__ static auto CalculateBThreadOriginDataIndex()
748  {
749  const auto wave_idx = GetWaveIdx();
750 
751  const auto waveId_n = wave_idx[I1];
752 
753  const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
754 
755  return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
756  }
757 
758  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
759  __device__ static auto
760  CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
761  {
762  const auto wave_idx = GetWaveIdx();
763 
764  const auto waveId_m = wave_idx[I0];
765  const auto waveId_n = wave_idx[I1];
766 
767  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
768 
769  constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
770  make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
771  make_tuple(Sequence<0>{}),
772  make_tuple(Sequence<0, 1, 2>{}));
773 
774  constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
775  make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
776  make_tuple(Sequence<0>{}),
777  make_tuple(Sequence<0, 1, 2>{}));
778 
779  const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
780  make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
781  const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
782  make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
783 
784  return make_tuple(c_thread_m, c_thread_n);
785  }
786 
787  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
788  __device__ static auto
789  CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
790  {
791  const auto wave_idx = GetWaveIdx();
792 
793  const auto waveId_m = wave_idx[I0];
794  const auto waveId_n = wave_idx[I1];
795 
796  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
797 
798  return make_tuple(
799  m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
800  }
801 
802  using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
803 
804  __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
805  Tuple4 b_origin = CalculateBThreadOriginDataIndex())
806  : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
807  {
808 #if defined(__HIP_DEVICE_COMPILE__)
809  static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
810  "wrong! Desc should be known at compile-time");
811 
812  static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
813  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
814 
815  static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
816  "wrong!");
817 #endif
818  }
819 
820  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
821  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
822  {
823  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
824 
825  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
826  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
827  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
828  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
829 
830  return make_naive_tensor_descriptor_packed(
831  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
832  }
833 
834  // XDL output supporting C_xdl = A_xdl * B_xdl
835  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
836  {
837  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
838 
839  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
840  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
841  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
842  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
843 
844  return make_naive_tensor_descriptor_packed(
845  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
846  }
847 
848  __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
849  {
850  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
851 
852  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
853  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
854  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
855  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
856 
857  return make_naive_tensor_descriptor_packed(
858  make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
859  }
860 
861  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
862  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
863  {
864  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
865  make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
866  Number<NRepeat>{},
867  Number<MWaves>{},
868  Number<NWaves>{},
869  Number<MPerXDL>{},
870  Number<NPerXDL>{}));
871 
872  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
873  }
874 
875  // XDL output supporting C_xdl = A_xdl * B_xdl
876  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
877  {
878  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
879  make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
880  Number<NRepeat>{},
881  Number<MWaves>{},
882  Number<NWaves>{},
883  Number<MPerXDL>{},
884  Number<NPerXDL>{}));
885 
886  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
887  }
888 
889  __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
890  {
891  constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
892  make_naive_tensor_descriptor_packed(make_tuple(I1,
893  Number<MRepeat>{},
894  Number<NRepeat>{},
895  Number<MWaves>{},
896  Number<NWaves>{},
897  Number<MPerXDL>{},
898  Number<NPerXDL>{}));
899 
900  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
901  c_block_desc_g_m0_n0_m1_n1_m2_n2);
902  }
903 
904  template <typename CGridDesc_M_N>
905  __host__ __device__ static constexpr auto
906  MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
907  {
908  const auto M = c_grid_desc_m_n.GetLength(I0);
909  const auto N = c_grid_desc_m_n.GetLength(I1);
910 
911  const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
912  c_grid_desc_m_n,
913  make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
914  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
915  make_tuple(Sequence<0>{}, Sequence<1>{}),
916  make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
917 
918  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
919  }
920 
921  template <typename CGridDesc_G_M_N>
922  __host__ __device__ static constexpr auto
923  MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
924  {
925  const auto G = c_grid_desc_g_m_n.GetLength(I0);
926  const auto M = c_grid_desc_g_m_n.GetLength(I1);
927  const auto N = c_grid_desc_g_m_n.GetLength(I2);
928 
929  const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
930  c_grid_desc_g_m_n,
931  make_tuple(make_pass_through_transform(G),
932  make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
933  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
934  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
935  make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4, 5, 6>{}));
936 
937  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
938  c_grid_desc_g_m0_n0_m1_n1_m2_n2);
939  }
940 
941  static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
942  static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
943 
944  template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
945  __device__ void Run(const ABlockBuffer& a_block_buf,
946  const BBlockBuffer& b_block_buf,
947  CThreadBuffer& c_thread_buf) const
948  {
949  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
950  a_thread_desc_.GetElementSpaceSize());
951  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
952  b_thread_desc_.GetElementSpaceSize());
953 
954  static_for<0, KPerThread / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
955  static_for<0, MRepeat, 1>{}([&](auto m0) {
956  // read A
957  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
958  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
959  a_block_buf,
960  a_thread_desc_,
961  make_tuple(I0, I0, I0, I0),
962  a_thread_buf);
963 
964  static_for<0, NRepeat, 1>{}([&](auto n0) {
965  // read B
966  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
967  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
968  b_block_buf,
969  b_thread_desc_,
970  make_tuple(I0, I0, I0, I0),
971  b_thread_buf);
972  vector_type<FloatAB, KPack> a_thread_vec;
973  vector_type<FloatAB, KPack> b_thread_vec;
974 
975  static_for<0, KPack, 1>{}([&](auto i) {
976  a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
977  [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
978  b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
979  [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
980  });
981 
982  using mfma_input_type =
983  typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
984 
985  constexpr index_t c_offset =
986  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
987 
988  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
989  b_thread_vec.template AsType<mfma_input_type>(),
990  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
991  });
992  });
993  });
994  }
995 
996  protected:
997  // A[M0, M1, M2, KPack]
998  static constexpr auto a_thread_desc_ =
999  make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
1000 
1001  // B[N0, N1, N2, KPack]
1002  static constexpr auto b_thread_desc_ =
1003  make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
1004 
1005  // C[M, N, NumRegXdlops]
1006  static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
1007  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
1008 
1009  using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
1010  FloatAB,
1011  decltype(a_block_desc_m0_m1_m2_k),
1012  decltype(a_thread_desc_),
1013  Sequence<1, 1, 1, KPack>,
1014  Sequence<0, 1, 2, 3>,
1015  3,
1016  A_K1,
1017  A_K1>;
1018 
1019  using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
1020  FloatAB,
1021  decltype(b_block_desc_n0_n1_n2_k),
1022  decltype(b_thread_desc_),
1023  Sequence<1, 1, 1, KPack>,
1024  Sequence<0, 1, 2, 3>,
1025  3,
1026  B_K1,
1027  B_K1>;
1028 
1029  AThreadCopy a_thread_copy_;
1030  BThreadCopy b_thread_copy_;
1031 };
1032 
1033 } // namespace ck
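For the BlockwiseGemmXdlops_v2 path, a hedged usage sketch follows; the descriptor types, sizes and buffer names are hypothetical placeholders, not values taken from this header.

//   using BlockGemm = ck::BlockwiseGemmXdlops_v2<256,          // BlockSize
//                                                ck::half_t,   // FloatAB
//                                                float,        // FloatAcc
//                                                ATileDesc, BTileDesc,
//                                                AMmaTileDesc, BMmaTileDesc,
//                                                128, 128, 32, // MPerBlock, NPerBlock, KPerBlock
//                                                32, 32,       // MPerXDL, NPerXDL
//                                                2, 2,         // MRepeat, NRepeat
//                                                8>;           // KPack
//
//   BlockGemm block_gemm{};  // A/B thread origins default to the calculated per-thread origins
//   auto& c_thread_buf = block_gemm.GetCThreadBuffer();
//   block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);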