Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp Source File
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/utility/common_header.hpp"
7 #include "ck/tensor_description/tensor_adaptor.hpp"
8 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
9 #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
10 
11 
12 namespace ck {
13 
14 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
15 __host__ __device__ static constexpr auto
16 MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
17 {
18  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
19  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
20 
21  return transform_tensor_descriptor(
22  TileDesc_K0_MN_K1{},
23  make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
24  make_unmerge_transform(
25  make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
26  make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
27  make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
28 }
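As a rough sketch of what this helper produces (the tile sizes below are illustrative and not taken from this file): a (K0, MN, K1) block tile is re-viewed as (MN0, MN1, MN2, K) = (MNXdlPerWave, MNWaves, MNPerXdl, K0 * K1), so a K0 = 4, MN = 128, K1 = 8 tile distributed over 2 XDL repeats, 2 waves and 32 lanes per XDL ends up with lengths (2, 2, 32, 32):

// Hedged example: descriptor lengths are made up for illustration.
constexpr auto tile_k0_m_k1 = ck::make_naive_tensor_descriptor_packed(
    ck::make_tuple(ck::Number<4>{}, ck::Number<128>{}, ck::Number<8>{}));
constexpr auto mma_tile_m0_m1_m2_k =
    ck::MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<2, 2, 32>(tile_k0_m_k1);
static_assert(mma_tile_m0_m1_m2_k.GetLength(ck::Number<3>{}) == 4 * 8, "K = K0 * K1");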
29 
30 template <index_t BlockSize,
31  typename FloatA,
32  typename FloatB,
33  typename FloatAcc,
34  typename AK0MK1BlockDesc,
35  typename BK0NK1BlockDesc,
36  index_t MPerXDL,
37  index_t NPerXDL,
38  index_t MRepeat,
39  index_t NRepeat,
40  index_t KPack,
41  typename ComputeTypeA = FloatA,
42  typename ComputeTypeB = FloatB>
43 struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
44 {
45  static constexpr auto I0 = Number<0>{};
46  static constexpr auto I1 = Number<1>{};
47  static constexpr auto I2 = Number<2>{};
48  static constexpr auto I3 = Number<3>{};
49 
50  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
51 
52  static constexpr index_t WaveSize = get_warp_size();
53 
54  static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
55  static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
56  static constexpr index_t KPerBlock =
57  BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
58 
59  static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
60  static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
61  static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
62  static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
63 
64  static constexpr auto xdlops_gemm =
65  XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB>{};
66 
67  static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
68 
69  static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
70  static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
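 // For orientation (numbers are illustrative, not prescribed by this file): with MPerBlock = 256,
 // NPerBlock = 128, MPerXDL = NPerXDL = 32, MRepeat = 4 and NRepeat = 2, the expressions above give
 // MWaves = 256 / (4 * 32) = 2 and NWaves = 128 / (2 * 32) = 2, i.e. a 2 x 2 grid of waves per
 // workgroup; with WaveSize = 64 this corresponds to BlockSize = 2 * 2 * 64 = 256 threads, which is
 // exactly what the static_assert in the constructor below checks.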
71 
72  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
73  FloatAcc,
74  MRepeat * NRepeat,
75  xdlops_gemm.GetRegSizePerXdlops(),
76  true>
77  c_thread_buf_;
78 
79  __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
80 
81  __device__ static auto GetWaveIdx()
82  {
83  const index_t thread_id = ThisThreadBlock::GetThreadId();
84 
85  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
86  make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
87  make_tuple(Sequence<0, 1, 2>{}),
88  make_tuple(Sequence<0>{}));
89 
90  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
91  }
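 // Worked example (illustrative values): with MWaves = NWaves = 2 and WaveSize = 64, the adaptor
 // above decomposes thread_id = 200 as 200 = (1 * 2 + 1) * 64 + 8, i.e. wave index
 // (waveId_m, waveId_n) = (1, 1) and lane 8 within that wave.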
92 
93  __device__ static auto CalculateAThreadOriginDataIndex()
94  {
95  const auto wave_idx = GetWaveIdx();
96 
97  const auto waveId_m = wave_idx[I0];
98 
99  const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
100 
101  return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
102  }
103 
104  __device__ static auto CalculateBThreadOriginDataIndex()
105  {
106  const auto wave_idx = GetWaveIdx();
107 
108  const auto waveId_n = wave_idx[I1];
109 
110  const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
111 
112  return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
113  }
114 
115  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
116  __device__ static auto
117  CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
118  {
119  const auto wave_idx = GetWaveIdx();
120 
121  const auto waveId_m = wave_idx[I0];
122  const auto waveId_n = wave_idx[I1];
123 
124  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
125 
126  constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
127  make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
128  make_tuple(Sequence<0>{}),
129  make_tuple(Sequence<0, 1, 2>{}));
130 
131  constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
132  make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
133  make_tuple(Sequence<0>{}),
134  make_tuple(Sequence<0, 1, 2>{}));
135 
136  const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
137  make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
138  const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
139  make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
140 
141  return make_tuple(c_thread_m, c_thread_n);
142  }
143 
144  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
145  __device__ static auto
146  CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
147  {
148  const auto wave_idx = GetWaveIdx();
149 
150  const auto waveId_m = wave_idx[I0];
151  const auto waveId_n = wave_idx[I1];
152 
153  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
154 
155  return make_tuple(Number<m0>{},
156  Number<n0>{},
157  waveId_m,
158  waveId_n,
159  blk_idx[I0],
160  blk_idx[I1],
161  blk_idx[I2],
162  blk_idx[I3]);
163  }
164 
165  __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
166  {
167  static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
168  BK0NK1BlockDesc::IsKnownAtCompileTime(),
169  "wrong! Desc should be known at compile-time");
170 
171  static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
172  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
173 
174  static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
175  "wrong!");
176  }
177 
178  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
179  {
180  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
181 
182  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
183  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
184  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
185  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
186 
187  return make_naive_tensor_descriptor_packed(
188  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
189  }
190 
191  __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
192  {
193  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
194 
195  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
196  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
197  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
198  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
199 
200  return make_naive_tensor_descriptor_packed(
201  make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
202  }
203 
204  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
205  {
206  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
207  make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
208  Number<NRepeat>{},
209  Number<MWaves>{},
210  Number<NWaves>{},
211  Number<MPerXDL>{},
212  Number<NPerXDL>{}));
213 
214  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
215  }
216 
217  __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
218  {
219  constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
220  make_naive_tensor_descriptor_packed(make_tuple(I1,
221  Number<MRepeat>{},
222  Number<NRepeat>{},
223  Number<MWaves>{},
224  Number<NWaves>{},
225  Number<MPerXDL>{},
226  Number<NPerXDL>{}));
227 
228  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
229  c_block_desc_g_m0_n0_m1_n1_m2_n2);
230  }
231 
232  template <typename CGridDesc_M_N>
233  __host__ __device__ static constexpr auto
234  MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
235  {
236  const auto M = c_grid_desc_m_n.GetLength(I0);
237  const auto N = c_grid_desc_m_n.GetLength(I1);
238 
239  const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
240  c_grid_desc_m_n,
241  make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
242  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
243  make_tuple(Sequence<0>{}, Sequence<1>{}),
244  make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
245 
246  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
247  }
248 
249  template <typename CGridDesc_G_M_N>
250  __host__ __device__ static constexpr auto
251  MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
252  {
253  const auto G = c_grid_desc_g_m_n.GetLength(I0);
254  const auto M = c_grid_desc_g_m_n.GetLength(I1);
255  const auto N = c_grid_desc_g_m_n.GetLength(I2);
256 
257  const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
258  c_grid_desc_g_m_n,
259  make_tuple(make_pass_through_transform(G),
260  make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
261  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
262  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
263  make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4, 5, 6>{}));
264 
265  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
266  c_grid_desc_g_m0_n0_m1_n1_m2_n2);
267  }
268 
269  __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
270  {
271  return transform_tensor_descriptor(
272  AK0MK1BlockDesc{},
273  make_tuple(
274  make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
275  make_unmerge_transform(
276  make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
277  make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
278  make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
279  }
280 
281  __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
282  {
283  return transform_tensor_descriptor(
284  BK0NK1BlockDesc{},
285  make_tuple(
286  make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
287  make_unmerge_transform(
288  make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}))),
289  make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
290  make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
291  }
292 
293  static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
294  static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
295 
296  template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
297  __device__ void Run(const ABlockBuffer& a_block_buf,
298  const BBlockBuffer& b_block_buf,
299  CThreadBuffer& c_thread_buf) const
300  {
301  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
302  a_thread_desc_.GetElementSpaceSize());
303  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
304  b_thread_desc_.GetElementSpaceSize());
305 
306  static_for<0, MRepeat, 1>{}([&](auto m0) {
307  // read A
308  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
309  make_tuple(m0, I0, I0, I0),
310  a_block_buf,
311  a_thread_desc_,
312  make_tuple(I0, I0, I0, I0),
313  a_thread_buf);
314 
315  static_for<0, NRepeat, 1>{}([&](auto n0) {
316  // read B
317  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
318  make_tuple(n0, I0, I0, I0),
319  b_block_buf,
320  b_thread_desc_,
321  make_tuple(I0, I0, I0, I0),
322  b_thread_buf);
323 
324  static_for<0, KPerThread, KPack>{}([&](auto k) {
325  vector_type<ComputeTypeA, KPack> a_thread_vec;
326  vector_type<ComputeTypeB, KPack> b_thread_vec;
327 
328  static_for<0, KPack, 1>{}([&](auto i) {
329  a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf
330  [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
331  b_thread_vec.template AsType<ComputeTypeB>()(i) = b_thread_buf
332  [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
333  });
334 
335  using mfma_input_type_a =
336  typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
337  using mfma_input_type_b =
338  typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
339 
340  constexpr index_t c_offset =
341  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
342 
343  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
344  b_thread_vec.template AsType<mfma_input_type_b>(),
345  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
346  });
347  });
348  });
349  }
350 
351  protected:
352  // A[M0, M1, M2, KPerThread]
353  static constexpr auto a_thread_desc_ =
354  make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
355 
356  // B[N0, N1, N2, KPerThread]
357  static constexpr auto b_thread_desc_ =
358  make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
359 
360  // C[M, N, NumRegXdlops]
361  static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
362  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
363 
364  using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
365  ComputeTypeA,
366  decltype(a_block_desc_m0_m1_m2_k),
367  decltype(a_thread_desc_),
368  Sequence<1, 1, 1, KPerThread>,
369  Sequence<0, 1, 2, 3>,
370  3,
371  A_K1,
372  A_K1>;
373 
374  using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
375  ComputeTypeB,
376  decltype(b_block_desc_n0_n1_n2_k),
377  decltype(b_thread_desc_),
378  Sequence<1, 1, 1, KPerThread>,
379  Sequence<0, 1, 2, 3>,
380  3,
381  B_K1,
382  B_K1>;
383 
384  AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
385  BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
386 };
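A minimal sketch of how a kernel typically drives this struct; the descriptor and buffer names (a_k0_m_k1_block_desc, b_k0_n_k1_block_desc, a_block_buf, b_block_buf) and all tile sizes are placeholders supplied by the surrounding gridwise GEMM, not values defined in this header:

// Hypothetical device-side call site.
using BlockwiseGemm =
    ck::BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<256,        // BlockSize
                                                            ck::half_t, // FloatA
                                                            ck::half_t, // FloatB
                                                            float,      // FloatAcc
                                                            decltype(a_k0_m_k1_block_desc),
                                                            decltype(b_k0_n_k1_block_desc),
                                                            32, 32,     // MPerXDL, NPerXDL
                                                            4, 2,       // MRepeat, NRepeat
                                                            8>;         // KPack

auto blockwise_gemm = BlockwiseGemm{};
auto& c_thread_buf  = blockwise_gemm.GetCThreadBuffer(); // accumulator registers (cleared elsewhere)

// Inside the K-loop of the gridwise kernel, after the A/B tiles have been staged in LDS:
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);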
387 
388 // Note: To facilitate the inter-wave loop scheduler, the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1
389 // must be set explicitly, as a few required intrinsics are not yet available in the latest ROCm
390 // release. For unsupported compilers, the inter-wave loop scheduler falls back to the default loop
391 // scheduler, which is selected by CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0.
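// For example (a sketch only; the exact mechanism depends on the build system, and the cluster
// count shown here is illustrative):
//   hipcc -DCK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 -DCK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS=2 ...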
392 template <index_t BlockSize,
393  typename FloatA,
394  typename FloatB,
395  typename FloatAcc,
396  typename AK0MK1BlockDesc,
397  typename BK0NK1BlockDesc,
398  index_t MPerXDL,
399  index_t NPerXDL,
400  index_t MRepeat,
401  index_t NRepeat,
402  index_t KPack,
403  typename ComputeTypeA = FloatA,
404  typename ComputeTypeB = FloatB,
405  index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
406 struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
407  : BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
408  FloatA,
409  FloatB,
410  FloatAcc,
411  AK0MK1BlockDesc,
412  BK0NK1BlockDesc,
413  MPerXDL,
414  NPerXDL,
415  MRepeat,
416  NRepeat,
417  KPack,
418  ComputeTypeA,
419  ComputeTypeB>
420 {
421  using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
422  FloatA,
423  FloatB,
424  FloatAcc,
425  AK0MK1BlockDesc,
426  BK0NK1BlockDesc,
427  MPerXDL,
428  NPerXDL,
429  MRepeat,
430  NRepeat,
431  KPack,
432  ComputeTypeA,
433  ComputeTypeB>;
434 
435 #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
436  using Base::a_block_desc_m0_m1_m2_k;
437  using Base::A_K1;
438  using Base::b_block_desc_n0_n1_n2_k;
439  using Base::B_K1;
440  using Base::c_thread_buf_;
441  using Base::c_thread_desc_;
442  using Base::CalculateAThreadOriginDataIndex;
443  using Base::CalculateBThreadOriginDataIndex;
444  using Base::I0;
445  using Base::I1;
446  using Base::KPerThread;
447  using Base::xdlops_gemm;
448 
449  static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
450 
451  // 2-wave optimized blockwise gemm
452  template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
453  __device__ void Run(const ABlockBuffer& a_block_buf,
454  const BBlockBuffer& b_block_buf,
455  CThreadBuffer& c_thread_buf) const
456  {
457  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
458  a_thread_desc_.GetElementSpaceSize());
459  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
460  b_thread_desc_.GetElementSpaceSize());
461 
462  static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
463  static_for<0, MRepeat, 1>{}([&](auto m0) {
464  // read A
465  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
466  make_tuple(m0, I0, I0, k),
467  a_block_buf,
468  a_thread_desc_,
469  make_tuple(m0, I0, I0, I0),
470  a_thread_buf);
471  });
472  static_for<0, NRepeat, 1>{}([&](auto n0) {
473  // read B
474  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
475  make_tuple(n0, I0, I0, k),
476  b_block_buf,
477  b_thread_desc_,
478  make_tuple(n0, I0, I0, I0),
479  b_thread_buf);
480  });
481  __builtin_amdgcn_sched_barrier(0);
482  // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, except the
483  // first, as the non-MAC cluster can be shortened a bit with no observable negative
484  // impact. The desired effect is waves in a workgroup executing MAC in sync. This avoids
485  // out-of-sync waves hijacking MAC resources from other workgroups and reducing the
486  // chance of latency hiding by waiting for the rest of the workgroup at the eventual
487  // sync point.
488  if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
489  {
490 #ifdef __gfx12__
491  asm volatile("\
492  s_barrier_signal -1 \n \
493  s_barrier_wait -1 \
494  " ::);
495 #else
496  asm volatile("s_barrier" ::);
497 #endif
498  __builtin_amdgcn_sched_barrier(0);
499  }
500  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
501  static_for<0, MRepeat, 1>{}([&](auto m0) {
502  static_for<0, NRepeat, 1>{}([&](auto n0) {
503  vector_type<ComputeTypeA, KPack> a_thread_vec;
504  vector_type<ComputeTypeB, KPack> b_thread_vec;
505 
506  static_for<0, KPack, 1>{}([&](auto i) {
507  a_thread_vec.template AsType<ComputeTypeA>()(i) =
508  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
509  make_tuple(m0, 0, 0, k_ + i))>{}];
510  b_thread_vec.template AsType<ComputeTypeB>()(i) =
511  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
512  make_tuple(n0, 0, 0, k_ + i))>{}];
513  });
514 
515  using mfma_input_type_a =
516  typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
517  using mfma_input_type_b =
518  typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
519 
520  constexpr index_t c_offset =
521  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
522 
523  // The block_sync_lds() here performs double duty:
524  // A) safeguard against data hazards, because the barrier from blockwise_gemm is
525  //    moved here;
526  // B) reduce VMEM FIFO congestion by applying small delays to different wavefronts.
527  // It is performed near the end of the MAC cluster to minimize the lgkmcnt penalty.
528  if constexpr(k.value == KPerThread - KPerInnerLoop &&
529  k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 &&
530  n0.value == NRepeat - 1)
531  {
532  __builtin_amdgcn_sched_barrier(0);
533  block_sync_lds();
534  __builtin_amdgcn_sched_barrier(0);
535  }
536 
537  // TODO: insert setprio in more precise manner since we
538  // could have more than >1 MFMA instructions in single call
539  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
540  b_thread_vec.template AsType<mfma_input_type_b>(),
541  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
542  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
543  {
544  __builtin_amdgcn_sched_barrier(0);
545  __builtin_amdgcn_s_setprio(1);
546  __builtin_amdgcn_sched_barrier(0);
547  }
548  });
549  });
550  });
551  __builtin_amdgcn_sched_barrier(0);
552  __builtin_amdgcn_s_setprio(0);
553  __builtin_amdgcn_sched_barrier(0);
554  });
555  }
556 
557  protected:
558  // A[M0, M1, M2, KPerInnerLoop]
559  static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
560  make_tuple(Number<MRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
561 
562  // B[N0, N1, N2, KPerInnerLoop]
563  static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
564  make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
565 
566  using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
567  ComputeTypeA,
568  decltype(a_block_desc_m0_m1_m2_k),
569  decltype(a_thread_desc_),
570  Sequence<1, 1, 1, KPerInnerLoop>,
571  Sequence<0, 1, 2, 3>,
572  3,
573  A_K1,
574  A_K1>;
575 
576  using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
577  ComputeTypeB,
578  decltype(b_block_desc_n0_n1_n2_k),
579  decltype(b_thread_desc_),
580  Sequence<1, 1, 1, KPerInnerLoop>,
581  Sequence<0, 1, 2, 3>,
582  3,
583  B_K1,
584  B_K1>;
585 
586  AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
587  BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
588 
589 #endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
590 };
591 
592 template <index_t BlockSize,
593  typename FloatA,
594  typename FloatB,
595  typename FloatAcc,
596  typename AK0MK1BlockDesc,
597  typename BK0NK1BlockDesc,
598  index_t MPerXDL,
599  index_t NPerXDL,
600  index_t MRepeat,
601  index_t NRepeat,
602  index_t KPack,
603  LoopScheduler LoopSched,
604  typename ComputeTypeA = FloatA,
605  typename ComputeTypeB = FloatB>
606 constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
607 {
608  if constexpr(LoopSched == LoopScheduler::Default)
609  {
610  return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
611  FloatA,
612  FloatB,
613  FloatAcc,
614  AK0MK1BlockDesc,
615  BK0NK1BlockDesc,
616  MPerXDL,
617  NPerXDL,
618  MRepeat,
619  NRepeat,
620  KPack,
621  ComputeTypeA,
622  ComputeTypeB>{};
623  }
624  else if constexpr(LoopSched == LoopScheduler::Interwave)
625  {
626  return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
627  FloatA,
628  FloatB,
629  FloatAcc,
630  AK0MK1BlockDesc,
631  BK0NK1BlockDesc,
632  MPerXDL,
633  NPerXDL,
634  MRepeat,
635  NRepeat,
636  KPack,
637  ComputeTypeA,
638  ComputeTypeB>{};
639  }
640 };
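A sketch of picking a pipeline through this helper; the block descriptors and tile parameters below are placeholders, not values this file prescribes:

// Hypothetical device-side call site; a_k0_m_k1_block_desc / b_k0_n_k1_block_desc come from the kernel.
auto blockwise_gemm =
    ck::BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<256,        // BlockSize
                                                                  ck::half_t, // FloatA
                                                                  ck::half_t, // FloatB
                                                                  float,      // FloatAcc
                                                                  decltype(a_k0_m_k1_block_desc),
                                                                  decltype(b_k0_n_k1_block_desc),
                                                                  32, 32,     // MPerXDL, NPerXDL
                                                                  4, 2,       // MRepeat, NRepeat
                                                                  8,          // KPack
                                                                  ck::LoopScheduler::Interwave>();
// LoopScheduler::Default selects the plain pipeline; LoopScheduler::Interwave selects the
// 2-wave optimized variant, which is only effective when CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1.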
641 
642 /// @brief Blockwise gemm
643 ///
644 /// Supports both the regular XDL output layout (C = A * B) and the transposed XDL output
645 /// layout (C' = B' * A'), and decouples the input tile descriptors from the MMA tile
646 /// descriptors so that the source operands can live in either VGPR or LDS.
652 template <
653  index_t BlockSize,
654  typename FloatAB,
655  typename FloatAcc,
656  typename ATileDesc,
657  typename BTileDesc,
658  typename AMmaTileDesc,
659  typename BMmaTileDesc,
660  index_t MPerBlock,
661  index_t NPerBlock,
662  index_t KPerBlock,
663  index_t MPerXDL,
664  index_t NPerXDL,
665  index_t MRepeat,
666  index_t NRepeat,
667  index_t KPack,
668  bool TransposeC = false,
669  index_t AMmaKStride =
670  KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
671  index_t BMmaKStride =
672  KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
673 struct BlockwiseGemmXdlops_v2
674 {
675  static constexpr auto I0 = Number<0>{};
676  static constexpr auto I1 = Number<1>{};
677  static constexpr auto I2 = Number<2>{};
678  static constexpr auto I3 = Number<3>{};
679 
680  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
681 
682  static constexpr index_t WaveSize = get_warp_size();
683 
684  static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
685  static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
686  static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
687  static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
688 
689  static constexpr auto xdlops_gemm =
690  XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{};
691 
692  static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
693 
694  static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
695  static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
696 
697  static_assert(KPerThread % KPack == 0,
698  "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
699 
700  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
701  FloatAcc,
702  MRepeat * NRepeat,
703  xdlops_gemm.GetRegSizePerXdlops(),
704  true>
705  c_thread_buf_;
706 
707  __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
708 
709  __device__ static auto GetWaveIdx()
710  {
711  const index_t thread_id = ThisThreadBlock::GetThreadId();
712 
713  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
714  make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
715  make_tuple(Sequence<0, 1, 2>{}),
716  make_tuple(Sequence<0>{}));
717 
718  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
719  }
720 
721  __device__ static auto CalculateAThreadOriginDataIndex()
722  {
723  const auto wave_idx = GetWaveIdx();
724 
725  const auto waveId_m = wave_idx[I0];
726 
727  const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
728 
729  return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
730  }
731 
732  __device__ static auto CalculateBThreadOriginDataIndex()
733  {
734  const auto wave_idx = GetWaveIdx();
735 
736  const auto waveId_n = wave_idx[I1];
737 
738  const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
739 
740  return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
741  }
742 
743  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
744  __device__ static auto
745  CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
746  {
747  const auto wave_idx = GetWaveIdx();
748 
749  const auto waveId_m = wave_idx[I0];
750  const auto waveId_n = wave_idx[I1];
751 
752  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
753 
754  constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
755  make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
756  make_tuple(Sequence<0>{}),
757  make_tuple(Sequence<0, 1, 2>{}));
758 
759  constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
760  make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
761  make_tuple(Sequence<0>{}),
762  make_tuple(Sequence<0, 1, 2>{}));
763 
764  const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
765  make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
766  const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
767  make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
768 
769  return make_tuple(c_thread_m, c_thread_n);
770  }
771 
772  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
773  __device__ static auto
774  CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
775  {
776  const auto wave_idx = GetWaveIdx();
777 
778  const auto waveId_m = wave_idx[I0];
779  const auto waveId_n = wave_idx[I1];
780 
781  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
782 
783  return make_tuple(
784  m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
785  }
786 
787  using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
788 
789  __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
790  Tuple4 b_origin = CalculateBThreadOriginDataIndex())
791  : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
792  {
793  static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
794  "wrong! Desc should be known at compile-time");
795 
796  static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
797  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
798 
799  static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
800  "wrong!");
801  }
802 
803  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
804  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
805  {
806  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
807 
808  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
809  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
810  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
811  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
812 
813  return make_naive_tensor_descriptor_packed(
814  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
815  }
816 
817  // XDL output supporting C_xdl = A_xdl * B_xdl
818  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
819  {
820  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
821 
822  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
823  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
824  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
825  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
826 
827  return make_naive_tensor_descriptor_packed(
828  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
829  }
830 
831  __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
832  {
833  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
834 
835  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
836  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
837  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
838  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
839 
840  return make_naive_tensor_descriptor_packed(
841  make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
842  }
843 
844  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
845  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
846  {
847  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
848  make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
849  Number<NRepeat>{},
850  Number<MWaves>{},
851  Number<NWaves>{},
852  Number<MPerXDL>{},
853  Number<NPerXDL>{}));
854 
855  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
856  }
857 
858  // XDL output supporting C_xdl = A_xdl * B_xdl
859  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
860  {
861  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
862  make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
863  Number<NRepeat>{},
864  Number<MWaves>{},
865  Number<NWaves>{},
866  Number<MPerXDL>{},
867  Number<NPerXDL>{}));
868 
869  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
870  }
871 
872  __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
873  {
874  constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
875  make_naive_tensor_descriptor_packed(make_tuple(I1,
876  Number<MRepeat>{},
877  Number<NRepeat>{},
878  Number<MWaves>{},
879  Number<NWaves>{},
880  Number<MPerXDL>{},
881  Number<NPerXDL>{}));
882 
883  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
884  c_block_desc_g_m0_n0_m1_n1_m2_n2);
885  }
886 
887  template <typename CGridDesc_M_N>
888  __host__ __device__ static constexpr auto
889  MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
890  {
891  const auto M = c_grid_desc_m_n.GetLength(I0);
892  const auto N = c_grid_desc_m_n.GetLength(I1);
893 
894  const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
895  c_grid_desc_m_n,
896  make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
897  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
898  make_tuple(Sequence<0>{}, Sequence<1>{}),
899  make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
900 
901  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
902  }
903 
904  template <typename CGridDesc_G_M_N>
905  __host__ __device__ static constexpr auto
906  MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
907  {
908  const auto G = c_grid_desc_g_m_n.GetLength(I0);
909  const auto M = c_grid_desc_g_m_n.GetLength(I1);
910  const auto N = c_grid_desc_g_m_n.GetLength(I2);
911 
912  const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
913  c_grid_desc_g_m_n,
914  make_tuple(make_pass_through_transform(G),
915  make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
916  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
917  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
918  make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4, 5, 6>{}));
919 
920  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
921  c_grid_desc_g_m0_n0_m1_n1_m2_n2);
922  }
923 
924  static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
925  static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
926 
927  template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
928  __device__ void Run(const ABlockBuffer& a_block_buf,
929  const BBlockBuffer& b_block_buf,
930  CThreadBuffer& c_thread_buf) const
931  {
932  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
933  a_thread_desc_.GetElementSpaceSize());
934  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
935  b_thread_desc_.GetElementSpaceSize());
936 
937  static_for<0, KPerThread / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
938  static_for<0, MRepeat, 1>{}([&](auto m0) {
939  // read A
940  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
941  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
942  a_block_buf,
943  a_thread_desc_,
944  make_tuple(I0, I0, I0, I0),
945  a_thread_buf);
946 
947  static_for<0, NRepeat, 1>{}([&](auto n0) {
948  // read B
949  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
950  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
951  b_block_buf,
952  b_thread_desc_,
953  make_tuple(I0, I0, I0, I0),
954  b_thread_buf);
955  vector_type<FloatAB, KPack> a_thread_vec;
956  vector_type<FloatAB, KPack> b_thread_vec;
957 
958  static_for<0, KPack, 1>{}([&](auto i) {
959  a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
960  [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
961  b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
962  [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
963  });
964 
965  using mfma_input_type =
966  typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
967 
968  constexpr index_t c_offset =
969  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
970 
971  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
972  b_thread_vec.template AsType<mfma_input_type>(),
973  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
974  });
975  });
976  });
977  }
978 
979  protected:
980  // A[M0, M1, M2, KPack]
981  static constexpr auto a_thread_desc_ =
982  make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
983 
984  // B[N0, N1, N2, KPack]
985  static constexpr auto b_thread_desc_ =
986  make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
987 
988  // C[M, N, NumRegXdlops]
989  static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
990  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
991 
992  using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
993  FloatAB,
994  decltype(a_block_desc_m0_m1_m2_k),
995  decltype(a_thread_desc_),
996  Sequence<1, 1, 1, KPack>,
997  Sequence<0, 1, 2, 3>,
998  3,
999  A_K1,
1000  A_K1>;
1001 
1002  using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
1003  FloatAB,
1004  decltype(b_block_desc_n0_n1_n2_k),
1005  decltype(b_thread_desc_),
1006  Sequence<1, 1, 1, KPack>,
1007  Sequence<0, 1, 2, 3>,
1008  3,
1009  B_K1,
1010  B_K1>;
1011 
1012  AThreadCopy a_thread_copy_;
1013  BThreadCopy b_thread_copy_;
1014 };
1015 
1016 } // namespace ck
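After the MFMA loops, each thread holds its piece of the C tile in registers; a typical epilogue locates that piece in the output tensor using the helpers above. A minimal sketch, assuming BlockwiseGemm is one of the structs from this header and omitting the actual 8-D thread-to-grid copy:

// Hypothetical epilogue fragment (device code).
constexpr auto c_thread_desc =
    BlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

// Where does this thread's (m0 = 0, n0 = 0) accumulator block start within the block tile?
const auto c_origin = BlockwiseGemm::CalculateCThreadOriginDataIndex(
    ck::Number<0>{}, ck::Number<0>{}, ck::Number<0>{}, ck::Number<0>{});
const ck::index_t c_thread_m = c_origin[ck::Number<0>{}];
const ck::index_t c_thread_n = c_origin[ck::Number<1>{}];

// c_thread_m / c_thread_n are then combined with the block's M/N offset and the grid descriptor
// returned by MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() to address the C tensor.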