/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp Source File
blockwise_gemm_pipeline_xdlops_base.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
11 
12 namespace ck {
13 
14 template <index_t BlockSize,
15  typename ADataType,
16  typename BDataType,
17  typename ComputeDataType,
18  typename AccDataType,
19  typename ATileDesc,
20  typename BTileDesc,
21  typename AMmaTileDesc,
22  typename BMmaTileDesc,
23  index_t ABlockTransferSrcScalarPerVector,
24  index_t BBlockTransferSrcScalarPerVector,
25  index_t MPerBlock,
26  index_t NPerBlock,
27  index_t KPerBlock,
28  index_t MPerXDL,
29  index_t NPerXDL,
30  index_t MRepeat,
31  index_t NRepeat,
32  index_t KPack,
33  bool TransposeC = false>
35 {
36  static constexpr auto I0 = Number<0>{};
37  static constexpr auto I1 = Number<1>{};
38  static constexpr auto I2 = Number<2>{};
39  static constexpr auto I3 = Number<3>{};
40 
42 
43  // Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs.
// NOTE(review): the comment above says "hardcode to 64", but the visible code
// derives WaveSize from BlockSize / MWaves / NWaves instead of using a literal
// 64 — confirm which is intended (a hardcoded line may be elided in this listing).
// Number of waves tiling the M (resp. N) extent of the block tile; each wave
// covers MRepeat * MPerXDL rows (resp. NRepeat * NPerXDL columns).
44  static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
45  static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
46  static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
47 
48  static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
49  static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
50  static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
51  static constexpr index_t B_K1 =
52  BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {});
53 
54  static constexpr auto xdlops_gemm =
56 
59 
60  static constexpr index_t AMmaKStride = KPack;
61  static constexpr index_t BMmaKStride = KPack;
62 
63  static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
64  static constexpr index_t KRepeat = KPerThread / KPack;
65  static constexpr index_t KPerInnerLoop = KPack;
66 
// Number of non-contiguous K groups an MFMA input is split into (1 or 2).
// NOTE(review): this listing is missing the conditional line that should
// precede the comment block below (the dangling `else` implies an
// `if constexpr(...)` guard was elided here) — consult the original header.
67  static constexpr index_t KGroup = []() {
69  // On gfx950, there are mfma instructions that require 32 f8 elements as input,
70  // split into 2 groups of 16 f8 elements.
71  // The 2 groups are not contiguous in the B preshuffled layout,
72  // and we do not want them to be contiguous in the B preshuffled layout,
73  // because a memory instruction can only read 16 f8 elements at a time.
// NOTE(review): `MPerXDL == 16 && MPerXDL == 16` and `MPerXDL == 32 &&
// MPerXDL == 32` repeat the same operand; the second comparison presumably
// should test NPerXDL — TODO confirm against the upstream source.
74  return ((MPerXDL == 16 && MPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) ||
75  (MPerXDL == 32 && MPerXDL == 32 && xdlops_gemm.KPerXdlops == 64))
76  ? 2
77  : 1;
78  else
79  return 1;
80  }();
81 
84  MPerBlock,
85  NPerBlock,
86  KPerBlock,
87  ABlockTransferSrcScalarPerVector,
88  BBlockTransferSrcScalarPerVector,
89  A_K1,
90  B_K1,
91  A_K1,
92  B_K1,
93  MRepeat,
94  NRepeat,
95  MPerXDL,
96  NPerXDL,
97  xdlops_gemm.KPerXdlops>;
98 
// Validate that the per-thread K extent splits evenly into KPack-sized MFMA
// inputs. Guarded by __HIP_DEVICE_COMPILE__ — presumably so the check fires
// only during the device compilation pass; confirm against build setup.
99 #if defined(__HIP_DEVICE_COMPILE__)
100  static_assert(KPerThread % KPack == 0,
101  "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
102 #endif
103 
105  AccDataType,
106  MRepeat * NRepeat,
107  xdlops_gemm.GetRegSizePerXdlops(),
108  true>
110 
// Accessor for the per-thread C accumulator buffer (mutable reference so
// callers can both read partial results and write accumulated values).
111  __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
112 
// Maps this thread's flat block-local thread id to a wave coordinate via a
// single-stage tensor adaptor. Judging by the consumers below, the result is
// indexed with I0 (m-wave) and I1 (n-wave) — TODO confirm: the adaptor's
// transform arguments (original lines 118-120) are elided in this listing.
113  __device__ static auto GetWaveIdx()
114  {
115  const index_t thread_id = ThisThreadBlock::GetThreadId();
116 
117  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
121 
122  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
123  }
124 
// Computes this thread's 4D read origin into the A tile as
// (0, waveId_m, xdlops lane index, KPerThread * xdlops k index).
125  __device__ static auto CalculateAThreadOriginDataIndex()
126  {
127  const auto wave_idx = GetWaveIdx();
128 
// Wave position along M within the thread block.
129  const auto waveId_m = wave_idx[I0];
130 
// Per-lane (k, m) origin supplied by the XDLOPS/MFMA tile mapping;
// component I0 is scaled by KPerThread to get the absolute K offset below.
131  const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
132 
133  return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
134  }
135 
// 6D variant of CalculateAThreadOriginDataIndex: same wave/lane origin, but
// returned as (0, waveId_m, lane index, 0, k index, 0) — the K offset is left
// unscaled here, unlike the 4D variant which multiplies by KPerThread.
136  __device__ static auto CalculateAThreadOriginDataIndex6D()
137  {
138  const auto wave_idx = GetWaveIdx();
139 
// Wave position along M within the thread block.
140  const auto waveId_m = wave_idx[I0];
141 
// Per-lane (k, m) origin supplied by the XDLOPS/MFMA tile mapping.
142  const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
143 
144  return make_tuple(0, waveId_m, xdlops_a_idx[I1], 0, xdlops_a_idx[I0], 0);
145  }
146 
// Computes this thread's 4D read origin into the B tile as
// (0, waveId_n, xdlops lane index, KPerThread * xdlops k index) — the
// N-side mirror of CalculateAThreadOriginDataIndex.
147  __device__ static auto CalculateBThreadOriginDataIndex()
148  {
149  const auto wave_idx = GetWaveIdx();
150 
// Wave position along N within the thread block.
151  const auto waveId_n = wave_idx[I1];
152 
// Per-lane (k, n) origin supplied by the XDLOPS/MFMA tile mapping.
153  const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
154 
155  return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
156  }
157 
158  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
159  __device__ static auto
161  {
162  const auto wave_idx = GetWaveIdx();
163 
164  const auto waveId_m = wave_idx[I0];
165  const auto waveId_n = wave_idx[I1];
166 
167  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
168 
169  constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
170  make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
173 
174  constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
175  make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
178 
179  const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
180  make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
181  const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
182  make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
183 
184  return make_tuple(c_thread_m, c_thread_n);
185  }
186 
187  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
188  __device__ static auto
190  {
191  const auto wave_idx = GetWaveIdx();
192 
193  const auto waveId_m = wave_idx[I0];
194  const auto waveId_n = wave_idx[I1];
195 
196  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
197 
198  return make_tuple(
199  m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
200  }
201 
203 
221  __host__ __device__
224  : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
225  {
226 #if defined(__HIP_DEVICE_COMPILE__)
227  static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
228  "wrong! Desc should be known at compile-time");
229 
230  static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
231  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
232 
233  static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
234  "wrong!");
235 #endif
236  }
237 
238  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
239  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
240  {
241  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
242 
243  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
244  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
245  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
246  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
247 
249  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
250  }
251 
252  // XDL output supporting C_xdl = A_xdl * B_xdl
253  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
254  {
255  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
256 
257  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
258  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
259  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
260  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
261 
263  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
264  }
265 
266  __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
267  {
268  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
269 
270  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
271  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
272  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
273  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
274 
276  make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
277  }
278 
279  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
280  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
281  {
282  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
284  Number<NRepeat>{},
285  Number<MWaves>{},
286  Number<NWaves>{},
287  Number<MPerXDL>{},
288  Number<NPerXDL>{}));
289 
290  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
291  }
292 
293  // XDL output supporting C_xdl = A_xdl * B_xdl
294  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
295  {
296  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
298  Number<NRepeat>{},
299  Number<MWaves>{},
300  Number<NWaves>{},
301  Number<MPerXDL>{},
302  Number<NPerXDL>{}));
303 
304  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
305  }
306 
307  __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
308  {
309  constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
311  Number<MRepeat>{},
312  Number<NRepeat>{},
313  Number<MWaves>{},
314  Number<NWaves>{},
315  Number<MPerXDL>{},
316  Number<NPerXDL>{}));
317 
318  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
319  c_block_desc_g_m0_n0_m1_n1_m2_n2);
320  }
321 
322  template <typename CGridDesc_M_N>
323  __host__ __device__ static constexpr auto
324  MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
325  {
326  const auto M = c_grid_desc_m_n.GetLength(I0);
327  const auto N = c_grid_desc_m_n.GetLength(I1);
328 
329  const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
330  c_grid_desc_m_n,
331  make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
332  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
335 
336  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
337  }
338 
339  template <typename CGridDesc_G_M_N>
340  __host__ __device__ static constexpr auto
341  MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
342  {
343  const auto G = c_grid_desc_g_m_n.GetLength(I0);
344  const auto M = c_grid_desc_g_m_n.GetLength(I1);
345  const auto N = c_grid_desc_g_m_n.GetLength(I2);
346 
347  const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
348  c_grid_desc_g_m_n,
350  make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
351  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
354 
355  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
356  c_grid_desc_g_m0_n0_m1_n1_m2_n2);
357  }
// Accessor for the compile-time per-thread C descriptor (M0, N0, reg-per-xdlops).
358  __host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; }
359  static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
360  static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
361 
362  protected:
363  // M1, N1 as double buffer index
364  // Read buffer + Compute buffer
365  // A[M0, M1, M2, KPack]
367  make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
368  make_tuple(
369  Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
370 
371  // B[N0, N1, N2, KPack]
373  make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
374  make_tuple(
375  Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));
376 
377  // C[M, N, NumRegXdlops]
379  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
380 
383  decltype(a_block_desc_m0_m1_m2_k),
384  decltype(a_thread_desc_),
387  3,
388  A_K1,
389  A_K1>;
390 
393  decltype(b_block_desc_n0_n1_n2_k),
394  decltype(b_thread_desc_),
397  3,
398  B_K1,
399  B_K1>;
400 
403 };
404 
405 } // namespace ck
Definition: ck.hpp:268
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
constexpr bool is_same_v
Definition: type.hpp:283
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:299
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, AccDataType, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:109
static constexpr index_t NWaves
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:45
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:222
static constexpr index_t MWaves
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:44
static constexpr index_t A_K0
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:48
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:280
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr auto c_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:378
static constexpr auto xdlops_gemm
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:54
static __device__ auto CalculateBThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:147
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:360
static __device__ auto GetWaveIdx()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:113
static constexpr index_t KGroup
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:67
static constexpr auto I1
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr index_t AMmaKStride
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:60
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:402
static __device__ auto CalculateAThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:125
static __device__ auto CalculateAThreadOriginDataIndex6D()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:136
static constexpr index_t WaveSize
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:46
static constexpr index_t B_K1
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:51
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:41
__host__ static constexpr __device__ auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:324
static constexpr index_t KPerInnerLoop
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:65
static constexpr auto I0
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:36
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:253
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:189
__host__ static constexpr __device__ auto GetCThreadDesc()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:358
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:366
static constexpr index_t KRepeat
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:359
__host__ static constexpr __device__ auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:307
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:372
static constexpr auto I2
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:38
static constexpr auto I3
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:39
static constexpr index_t A_K1
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:50
static constexpr index_t BMmaKStride
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:61
decltype(CalculateAThreadOriginDataIndex()) Tuple4
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:202
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:401
static constexpr index_t KPerThread
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:63
__host__ static constexpr __device__ auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:111
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:58
static constexpr index_t B_K0
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:49
Definition: blockwise_gemm_pipeline_xdlops.hpp:34
Definition: sequence.hpp:43
Definition: static_buffer.hpp:75
static __device__ index_t GetThreadId()
Definition: thread_group.hpp:19
static constexpr __device__ index_t GetNumOfThread()
Definition: thread_group.hpp:15
Definition: xdlops_gemm.hpp:1821
Definition: amd_ck_fp8.hpp:36
Definition: integral_constant.hpp:20