Composable Kernel: blockwise_gemm_smfmac_xdlops.hpp (source listing)
include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
{
    constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
    constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});

    // Merge (K0, K1) into a single K dimension and unmerge MN into
    // (MNXdlPerWave, MNWaves, MNPerXdl).
    return transform_tensor_descriptor(
        TileDesc_K0_MN_K1{},
        make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
                   make_unmerge_transform(
                       make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
        make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
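
// Illustrative example (not from the original file): with K0 = 4, MN = 128, K1 = 8 and
// MNXdlPerWave = 2, MNWaves = 2, MNPerXdl = 32, the transform above maps the
// (K0, MN, K1) = (4, 128, 8) descriptor to (MN0, MN1, MN2, K) = (2, 2, 32, 32),
// i.e. K = K0 * K1 and MN = MN0 * MN1 * MN2:
//
//   constexpr auto tile_desc = make_naive_tensor_descriptor_packed(
//       make_tuple(Number<4>{}, Number<128>{}, Number<8>{}));
//   constexpr auto mma_desc  = MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<2, 2, 32>(tile_desc);
//   static_assert(mma_desc.GetLength(Number<3>{}) == 32, "merged K = K0 * K1");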

template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatAcc,
          typename AK0MK1BlockDesc,
          typename BK0NK1BlockDesc,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          typename ComputeTypeA = FloatA,
          typename ComputeTypeB = FloatB>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    // Block-tile sizes come from the LDS block descriptors. They must be declared
    // before MWaves/NWaves below: in-class static member initializers may only
    // reference members already declared.
    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
    static constexpr index_t KPerBlock =
        BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);

    static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
    static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
    static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
    static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
    static_assert(MWaves > 0);
    static_assert(NWaves > 0);
    static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;

    // Sparse-MFMA warp-level GEMM. The template-argument order is assumed here
    // (following the dense XdlopsGemm helper); the listing dropped this line.
    static constexpr auto xdlops_gemm =
        SmfmacXdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB>{};

    static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;

    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                              FloatAcc,
                              MRepeat * NRepeat,
                              xdlops_gemm.GetRegSizePerXdlops(),
                              true>
        c_thread_buf_;
    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = ThisThreadBlock::GetThreadId();

        // Decompose the linear thread id into (wave_m, wave_n, lane).
        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }
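
    // Illustrative example (not from the original file): with MWaves = 2, NWaves = 2
    // and WaveSize = 64, thread_id = 137 decomposes as 137 = 1 * (2 * 64) + 0 * 64 + 9,
    // i.e. (wave_m, wave_n, lane) = (1, 0, 9).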

    __device__ static auto CalculateAThreadOriginDataIndex()
    {
        const auto wave_idx     = GetWaveIdx();
        const auto waveId_m     = wave_idx[I0];
        const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();

        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
    {
        const auto wave_idx     = GetWaveIdx();
        const auto waveId_n     = wave_idx[I1];
        const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();

        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
    }
    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);

        constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }
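
    // Illustrative example (not from the original file): the unmerge above computes
    // c_thread_m = (m0 * MWaves + waveId_m) * MPerXDL + blk_m. With MWaves = 2,
    // MPerXDL = 32, m0 = 1, waveId_m = 1 and blk_m = 5 this gives
    // 1 * 64 + 1 * 32 + 5 = 101.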

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);

        return make_tuple(Number<m0>{},
                          Number<n0>{},
                          waveId_m,
                          waveId_n,
                          blk_idx[I0],
                          blk_idx[I1],
                          blk_idx[I2],
                          blk_idx[I3]);
    }

    __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
    {
#if defined(__HIP_DEVICE_COMPILE__)
        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
                          BK0NK1BlockDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

        static_assert(MPerBlock % (MPerXDL * MRepeat) == 0,
                      "MPerBlock must be divisible by MPerXDL * MRepeat");
        static_assert(NPerBlock % (NPerXDL * NRepeat) == 0,
                      "NPerBlock must be divisible by NPerXDL * NRepeat");

        static_assert(
            KPack % (16 * sizeof(ComputeTypeA)) == 0,
            "KPack must be divisible by the number of elements processed in a single smfmac "
            "instruction");
#endif
    }
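
    // Illustrative arithmetic (not from the original file): one 32-bit index register
    // covers 16 * sizeof(ComputeTypeA) A-elements, e.g. 16 * 2 = 32 elements for fp16.
    // Hence the KPack constraint above: for fp16, KPack must be a multiple of 32.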

    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(I1,
                                                           Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_block_desc_g_m0_n0_m1_n1_m2_n2);
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
    }
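
    // Illustrative arithmetic (not from the original file): for M = 256, MWaves = 2,
    // MPerXDL = 32, the M dimension unmerges into (M0, M1, M2) = (256 / 64, 2, 32) =
    // (4, 2, 32); N is split the same way.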

    template <typename CGridDesc_G_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
    {
        const auto G = c_grid_desc_g_m_n.GetLength(I0);
        const auto M = c_grid_desc_g_m_n.GetLength(I1);
        const auto N = c_grid_desc_g_m_n.GetLength(I2);

        const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_g_m_n,
            make_tuple(make_pass_through_transform(G),
                       make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4, 5, 6>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_grid_desc_g_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
    {
        return transform_tensor_descriptor(
            AK0MK1BlockDesc{},
            make_tuple(
                make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
                make_unmerge_transform(
                    make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }

    __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
    {
        return transform_tensor_descriptor(
            BK0NK1BlockDesc{},
            make_tuple(
                make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
                make_unmerge_transform(
                    make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }

    static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
    static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();

    // Prepares data in a_thread_buf for 2:4 structural sparsity: each group of four
    // values is compressed to its (at most) two non-zero values by omitting zeros.
    // The positions of the non-zero elements are packed into idx_buf and consumed
    // later by the smfmac instruction.
    // num_elems is listed first so it can be passed explicitly while the buffer
    // types are deduced from the arguments.
    template <int32_t num_elems, typename AThreadBuf, typename IdxBuf>
    __device__ static void SetIdxSqueezeA(AThreadBuf& a_thread_buf, IdxBuf& idx_buf)
    {
        static constexpr int32_t bit_clear_masks[4] = {0b11, 0b1100, 0b110000, 0b11000000};
        static constexpr int32_t processed_elems    = 16 / sizeof(ComputeTypeA);

        static_for<0, num_elems, processed_elems>{}([&](auto i) {
            constexpr int idx_reg_num  = i / (16 * sizeof(ComputeTypeA));
            constexpr int idx_reg_part = (i % 32) / processed_elems;

            vector_type<ComputeTypeA, processed_elems> a_thread_vec;
            static_for<0, processed_elems, 1>{}([&](auto j) {
                a_thread_vec.template AsType<ComputeTypeA>()(j) = a_thread_buf
                    [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i + j))>{}];
            });

            uint8_t idx = 0b11101110; // default: last 2 elems of both 4-elem subgroups
            for(int j = 0; j < processed_elems; j += 4)
            {
                int32_t nonzero_pos           = 0;
                ComputeTypeA nonzero_elems[2] = {a_thread_vec[j + 2], a_thread_vec[j + 3]};
                for(int k = 0; k < 3; k += 1)
                {
                    // 2:4 sparsity allows at most two non-zeros per group of four; the
                    // bound check guards against malformed (denser) input.
                    if(a_thread_vec[j + k] != 0.0f && nonzero_pos < 2)
                    {
                        nonzero_elems[nonzero_pos] = a_thread_vec[j + k];
                        idx &= ~bit_clear_masks[j / 2 + nonzero_pos];
                        idx |= k << 2 * (j / 2 + nonzero_pos);
                        ++nonzero_pos;
                    }
                }
                a_thread_vec[j / 2]     = nonzero_elems[0];
                a_thread_vec[j / 2 + 1] = nonzero_elems[1];
            }
            // Insert the 8 index bits into byte `idx_reg_part` of the 32-bit index
            // register (the listing garbled this statement; a byte-insert via mask is
            // an assumed-equivalent reconstruction).
            idx_buf(Number<idx_reg_num>{}) =
                (idx_buf[Number<idx_reg_num>{}] & ~(0xFF << (8 * idx_reg_part))) |
                (static_cast<int32_t>(idx) << (8 * idx_reg_part));

            static_for<0, processed_elems / 2, 1>{}([&](auto j) {
                a_thread_buf(Number<a_thread_desc_.CalculateOffset(
                    make_tuple(0, 0, 0, i / 2 + j))>{}) = a_thread_vec[j];
            });
        });
    }
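
    // Worked example (illustrative, not from the original file), assuming fp16 compute
    // (processed_elems == 8): for the 4-element group {0, 3, 0, 7}, the non-zeros 3 and
    // 7 sit at positions 1 and 3. The group compresses to {3, 7}; the first index field
    // becomes position 1 (0b01) while the second keeps the default position 3 (0b11),
    // so the group's nibble reads 0b1101.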

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());
        static constexpr int32_t elems_per_idx = 16 * sizeof(ComputeTypeA);
        auto idx_buf = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
            (a_thread_desc_.GetElementSpaceSize() + elems_per_idx - 1) / elems_per_idx);

        static_for<0, MRepeat, 1>{}([&](auto m0) {
            // read A
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               make_tuple(m0, I0, I0, I0),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(I0, I0, I0, I0),
                               a_thread_buf);

            // compress A to 2:4 sparsity and fill the index registers
            SetIdxSqueezeA<a_thread_desc_.GetElementSpaceSize()>(a_thread_buf, idx_buf);

            static_for<0, NRepeat, 1>{}([&](auto n0) {
                // read B
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, I0),
                                   b_block_buf,
                                   b_thread_desc_,
                                   make_tuple(I0, I0, I0, I0),
                                   b_thread_buf);

                static_for<0, KPerThread, KPack>{}([&](auto k) {
                    // a_thread_vec is half-size because A is structurally sparse 2:4
                    vector_type<ComputeTypeA, KPack / 2> a_thread_vec;
                    vector_type<ComputeTypeB, KPack> b_thread_vec;
                    vector_type<int32_t, KPack / elems_per_idx> idx_vec;

                    static_for<0, KPack / 2, 1>{}([&](auto i) {
                        a_thread_vec.template AsType<ComputeTypeA>()(i) =
                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                make_tuple(0, 0, 0, k / 2 + i))>{}];
                    });

                    static_for<0, KPack, 1>{}([&](auto i) {
                        b_thread_vec.template AsType<ComputeTypeB>()(i) = b_thread_buf
                            [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
                    });

                    static_for<0, KPack / elems_per_idx, 1>{}([&](auto i) {
                        idx_vec.template AsType<int32_t>()(i) = idx_buf[k / elems_per_idx + i];
                    });

                    // A is half-width because it's structurally sparse 2:4
                    using mfma_input_type_a =
                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops / 2>::type;
                    using mfma_input_type_b =
                        typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
                    using mfma_input_type_idx = typename vector_type<int32_t, 1>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
                                    b_thread_vec.template AsType<mfma_input_type_b>(),
                                    idx_vec.template AsType<mfma_input_type_idx>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });
            });
        });
    }
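
    // Mainloop sketch (illustrative, not from the original file): per (m0, n0) tile the
    // loop above issues KPerThread / KPack sparse MFMA steps, each consuming KPack / 2
    // compressed A values, KPack dense B values and one packed index register, e.g.
    //
    //   block_gemm.Run(a_block_buf, b_block_buf, block_gemm.GetCThreadBuffer());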

    protected:
    // A[M0, M1, M2, KPerThread]
    static constexpr auto a_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));

    // B[N0, N1, N2, KPerThread]
    static constexpr auto b_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));

    // C[M, N, NumRegXdlops]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));

    // Slice shape and dimension-access order are assumed from the dense blockwise-GEMM
    // variant; the listing dropped these lines.
    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
                                                         ComputeTypeA,
                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, KPerThread>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         A_K1,
                                                         A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
                                                         ComputeTypeB,
                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, KPerThread>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         B_K1,
                                                         B_K1>;

    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};

} // namespace ck
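
// Illustrative instantiation sketch (not from the original file; the LDS block
// descriptors `a_k0_m_k1_desc` / `b_k0_n_k1_desc` and the surrounding kernel are
// assumptions):
//
//   using BlockGemm = ck::BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
//       256,                      // BlockSize
//       ck::half_t,               // FloatA
//       ck::half_t,               // FloatB
//       float,                    // FloatAcc
//       decltype(a_k0_m_k1_desc), // AK0MK1BlockDesc
//       decltype(b_k0_n_k1_desc), // BK0NK1BlockDesc
//       32,                       // MPerXDL
//       32,                       // NPerXDL
//       2,                        // MRepeat
//       2,                        // NRepeat
//       32>;                      // KPack: multiple of 16 * sizeof(ComputeTypeA) = 32 for fp16
//
//   BlockGemm block_gemm;
//   block_gemm.Run(a_block_buf, b_block_buf, block_gemm.GetCThreadBuffer());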