/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp Source File
blockwise_gemm_pipeline_xdlops_base.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
11 
12 namespace ck {
13 
14 template <index_t BlockSize,
15  typename ADataType,
16  typename BDataType,
17  typename ComputeDataType,
18  typename AccDataType,
19  typename ATileDesc,
20  typename BTileDesc,
21  typename AMmaTileDesc,
22  typename BMmaTileDesc,
23  index_t ABlockTransferSrcScalarPerVector,
24  index_t BBlockTransferSrcScalarPerVector,
25  index_t MPerBlock,
26  index_t NPerBlock,
27  index_t KPerBlock,
28  index_t MPerXDL,
29  index_t NPerXDL,
30  index_t MRepeat,
31  index_t NRepeat,
32  index_t KPack,
33  bool TransposeC = false>
35 {
// NOTE(review): this is a Doxygen source-listing dump; the struct declarator
// (original line 34, `BlockwiseGemmXdlops_pipeline_base` per the index at the
// bottom of the page) was lost in extraction. Gaps below are flagged in place.
// Compile-time index constants used to address tuple/descriptor dimensions.
36  static constexpr auto I0 = Number<0>{};
37  static constexpr auto I1 = Number<1>{};
38  static constexpr auto I2 = Number<2>{};
39  static constexpr auto I3 = Number<3>{};
40 
// NOTE(review): original line 41 is missing here; per the index below it was
// the `ThisThreadBlock<BlockSize> ThisThreadBlock` alias -- confirm against repo.
42 
43  // Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs.
// NOTE(review): the comment above says "hardcode to 64", but WaveSize below is
// derived from BlockSize / MWaves / NWaves -- the comment looks stale; verify.
// Wave decomposition of the thread block across the M and N tile dimensions.
44  static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
45  static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
46  static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
47 
// K0/K1 extents read from the A and B tile descriptors.
48  static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
49  static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
50  static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
// B_K1 is read from dim 3 when the B tile descriptor is 4-D, otherwise dim 2.
51  static constexpr index_t B_K1 =
52  BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {});
53 
// NOTE(review): original line 55 -- the initializer of `xdlops_gemm` (an
// XdlopsGemm<...>{} instantiation, per the index entries below) -- is missing
// from this dump.
54  static constexpr auto xdlops_gemm =
56 
// Stride (in elements) between consecutive K-steps fed to the MFMA.
57  static constexpr index_t AMmaKStride = KPack;
58  static constexpr index_t BMmaKStride = KPack;
59 
// Per-thread K extent and how many KPack-sized chunks it splits into.
60  static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
61  static constexpr index_t KRepeat = KPerThread / KPack;
62  static constexpr index_t KPerInnerLoop = KPack;
64  static constexpr index_t KGroup = []() {
66  // On gfx950, we have mfma that required 32 f8 elements as input,
67  // splited into 2 groups of 16 f8 elements.
68  // the 2 groups is not contiguous in the B preshuffed layout.
69  // and we do not want it to be contiguous in the B preshuffled layout
70  // because a memory instruction can only read 16 f8 elements at a time.
71  return ((MPerXDL == 16 && MPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) ||
72  (MPerXDL == 32 && MPerXDL == 32 && xdlops_gemm.KPerXdlops == 64))
73  ? 2
74  : 1;
75  else
76  return 1;
77  }();
78 
// NOTE(review): original lines 79-80 are missing from this dump; they opened a
// template alias/instantiation to which the arguments below (MPerBlock ...
// xdlops_gemm.KPerXdlops) are passed -- confirm against the repository.
81  MPerBlock,
82  NPerBlock,
83  KPerBlock,
84  ABlockTransferSrcScalarPerVector,
85  BBlockTransferSrcScalarPerVector,
86  A_K1,
87  B_K1,
88  A_K1,
89  B_K1,
90  MRepeat,
91  NRepeat,
92  MPerXDL,
93  NPerXDL,
94  xdlops_gemm.KPerXdlops>;
95 
// Device-compile-only sanity check: each thread's K extent must split evenly
// into KPack-sized chunks.
96 #if defined(__HIP_DEVICE_COMPILE__)
97  static_assert(KPerThread % KPack == 0,
98  "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
99 #endif
100 
// NOTE(review): original line 101 is missing; per the index below it was
// `StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,` -- the head of the
// c_thread_buf_ accumulator declaration whose remaining arguments follow.
102  AccDataType,
103  MRepeat * NRepeat,
104  xdlops_gemm.GetRegSizePerXdlops(),
105  true>
// NOTE(review): original line 106 (the member name `c_thread_buf_;` that
// terminates the declaration above, per the index below) is also missing.
107 
108  __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
109 
// Decomposes the flat thread id into wave coordinates; callers read index I0 as
// the M-wave id and I1 as the N-wave id.
110  __device__ static auto GetWaveIdx()
111  {
112  const index_t thread_id = ThisThreadBlock::GetThreadId();
113 
114  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
// NOTE(review): original lines 115-117 -- the adaptor's transform over
// (MWaves, NWaves, WaveSize) and its dimension Sequence arguments -- are
// missing from this dump; confirm against the repository.
118 
119  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
120  }
121 
122  __device__ static auto CalculateAThreadOriginDataIndex()
123  {
124  const auto wave_idx = GetWaveIdx();
125 
126  const auto waveId_m = wave_idx[I0];
127 
128  const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
129 
130  return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
131  }
132 
133  __device__ static auto CalculateAThreadOriginDataIndex6D()
134  {
135  const auto wave_idx = GetWaveIdx();
136 
137  const auto waveId_m = wave_idx[I0];
138 
139  const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
140 
141  return make_tuple(0, waveId_m, xdlops_a_idx[I1], 0, xdlops_a_idx[I0], 0);
142  }
143 
144  __device__ static auto CalculateBThreadOriginDataIndex()
145  {
146  const auto wave_idx = GetWaveIdx();
147 
148  const auto waveId_n = wave_idx[I1];
149 
150  const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
151 
152  return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
153  }
154 
// Maps this thread's (m0, n0, xdlops_i, blk_i) C-fragment coordinates to the
// (m, n) position inside the block tile.
155  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
156  __device__ static auto
// NOTE(review): original line 157 -- the declarator
// `CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>,
// Number<blk_i>)` per the index below -- is missing from this dump.
158  {
159  const auto wave_idx = GetWaveIdx();
160 
161  const auto waveId_m = wave_idx[I0];
162  const auto waveId_n = wave_idx[I1];
163 
164  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
165 
// Flatten (repeat, wave, lane) coordinates into a single M index ...
166  constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
167  make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
// NOTE(review): original lines 168-169 (the adaptor's lower/upper dimension
// Sequence arguments) are missing from this dump.
170 
// ... and likewise into a single N index.
171  constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
172  make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
// NOTE(review): original lines 173-174 (the adaptor's dimension Sequence
// arguments) are likewise missing.
175 
176  const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
177  make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
178  const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
179  make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
180 
181  return make_tuple(c_thread_m, c_thread_n);
182  }
183 
184  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
185  __device__ static auto
187  {
188  const auto wave_idx = GetWaveIdx();
189 
190  const auto waveId_m = wave_idx[I0];
191  const auto waveId_n = wave_idx[I1];
192 
193  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
194 
195  return make_tuple(
196  m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
197  }
198 
200 
218  __host__ __device__
221  : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
222  {
223 #if defined(__HIP_DEVICE_COMPILE__)
224  static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
225  "wrong! Desc should be known at compile-time");
226 
227  static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
228  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
229 
230  static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
231  "wrong!");
232 #endif
233  }
234 
235  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
236  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
237  {
238  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
239 
240  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
241  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
242  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
243  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
244 
246  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
247  }
248 
249  // XDL output supporting C_xdl = A_xdl * B_xdl
250  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
251  {
252  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
253 
254  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
255  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
256  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
257  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
258 
260  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
261  }
262 
263  __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
264  {
265  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
266 
267  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
268  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
269  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
270  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
271 
273  make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
274  }
275 
276  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
277  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
278  {
279  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
281  Number<NRepeat>{},
282  Number<MWaves>{},
283  Number<NWaves>{},
284  Number<MPerXDL>{},
285  Number<NPerXDL>{}));
286 
287  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
288  }
289 
290  // XDL output supporting C_xdl = A_xdl * B_xdl
291  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
292  {
293  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
295  Number<NRepeat>{},
296  Number<MWaves>{},
297  Number<NWaves>{},
298  Number<MPerXDL>{},
299  Number<NPerXDL>{}));
300 
301  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
302  }
303 
304  __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
305  {
306  constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
308  Number<MRepeat>{},
309  Number<NRepeat>{},
310  Number<MWaves>{},
311  Number<NWaves>{},
312  Number<MPerXDL>{},
313  Number<NPerXDL>{}));
314 
315  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
316  c_block_desc_g_m0_n0_m1_n1_m2_n2);
317  }
318 
// Transforms a (M, N) C grid descriptor into the blocked/xdlops
// (M0, N0, M1, N1, M2, M3, M4, N2) layout.
319  template <typename CGridDesc_M_N>
320  __host__ __device__ static constexpr auto
321  MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
322  {
323  const auto M = c_grid_desc_m_n.GetLength(I0);
324  const auto N = c_grid_desc_m_n.GetLength(I1);
325 
// Unmerge M into (M/(MWaves*MPerXDL), MWaves, MPerXDL) and N likewise.
// NOTE(review): divisibility of M and N is not checked here -- presumably
// guaranteed by the caller; confirm.
326  const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
327  c_grid_desc_m_n,
328  make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
329  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
// NOTE(review): original lines 330-331 (the lower/upper dimension Sequence
// arguments of transform_tensor_descriptor) are missing from this dump.
332 
333  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
334  }
335 
// Batched (G-leading) variant: transforms a (G, M, N) C grid descriptor into
// the blocked/xdlops layout, passing the G dimension through.
336  template <typename CGridDesc_G_M_N>
337  __host__ __device__ static constexpr auto
338  MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
339  {
340  const auto G = c_grid_desc_g_m_n.GetLength(I0);
341  const auto M = c_grid_desc_g_m_n.GetLength(I1);
342  const auto N = c_grid_desc_g_m_n.GetLength(I2);
343 
344  const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
345  c_grid_desc_g_m_n,
// NOTE(review): original line 346 is missing from this dump; given G is read
// above, it presumably opened `make_tuple(make_pass_through_transform(G),` --
// confirm against the repository.
347  make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
348  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
// NOTE(review): original lines 349-350 (the dimension Sequence arguments) are
// missing from this dump.
351 
352  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
353  c_grid_desc_g_m0_n0_m1_n1_m2_n2);
354  }
// Accessor for the static per-thread C descriptor declared in the protected section.
355  __host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; }
// Compile-time block-tile descriptors for A and B, of the types supplied as the
// AMmaTileDesc / BMmaTileDesc template parameters.
356  static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
357  static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
358 
359  protected:
360  // M1, N1 as double buffer index
361  // Read buffer + Compute buffer
362  // A[M0, M1, M2, KPack]
// NOTE(review): original line 363 (`static constexpr auto a_thread_desc_ =
// make_naive_tensor_descriptor(` per the index below) is missing from this dump.
364  make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
365  make_tuple(
366  Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
367 
368  // B[N0, N1, N2, KPack]
// NOTE(review): original line 369 (the matching `static constexpr auto
// b_thread_desc_ = make_naive_tensor_descriptor(` head) is missing.
370  make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
371  make_tuple(
372  Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));
373 
374  // C[M, N, NumRegXdlops]
// NOTE(review): original line 375 (`static constexpr auto c_thread_desc_ =
// make_naive_tensor_descriptor_packed(` per the index below) is missing.
376  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
377 
// NOTE(review): original line 378 -- the head of the AThreadCopy alias (a
// threadwise slice-transfer template; presumably takes ADataType before the
// arguments below) -- is missing from this dump; confirm.
379  ComputeDataType,
380  decltype(a_block_desc_m0_m1_m2_k),
381  decltype(a_thread_desc_),
// NOTE(review): original lines 382-383 (slice lengths / dimension-order
// arguments) are missing from this dump.
384  3,
385  A_K1,
386  A_K1>;
387 
// NOTE(review): original line 388 -- the head of the BThreadCopy alias
// (presumably takes BDataType before the arguments below) -- is missing.
389  ComputeDataType,
390  decltype(b_block_desc_n0_n1_n2_k),
391  decltype(b_thread_desc_),
// NOTE(review): original lines 392-393 are missing from this dump.
394  3,
395  B_K1,
396  B_K1>;
397 
// NOTE(review): original lines 398-399 (`AThreadCopy a_thread_copy_;` and
// `BThreadCopy b_thread_copy_;` per the index below) are missing.
400 };
401 
402 } // namespace ck
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1737
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
constexpr bool is_same_v
Definition: type.hpp:283
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, AccDataType, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:106
static constexpr index_t NWaves
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:45
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:219
static constexpr index_t MWaves
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:44
static constexpr index_t A_K0
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:48
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:277
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:291
static constexpr auto c_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:375
static constexpr auto xdlops_gemm
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:54
static __device__ auto CalculateBThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:144
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:357
static __device__ auto GetWaveIdx()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:110
static constexpr index_t KGroup
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr auto I1
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:236
static constexpr index_t AMmaKStride
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:57
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:399
static __device__ auto CalculateAThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:122
static __device__ auto CalculateAThreadOriginDataIndex6D()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:133
static constexpr index_t WaveSize
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:46
static constexpr index_t B_K1
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:51
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:41
__host__ static constexpr __device__ auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:338
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:321
static constexpr index_t KPerInnerLoop
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:62
static constexpr auto I0
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:36
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:250
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:157
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:186
__host__ static constexpr __device__ auto GetCThreadDesc()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:355
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:363
static constexpr index_t KRepeat
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:61
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:356
__host__ static constexpr __device__ auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:304
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:369
static constexpr auto I2
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:38
static constexpr auto I3
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:39
static constexpr index_t A_K1
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:50
static constexpr index_t BMmaKStride
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:58
decltype(CalculateAThreadOriginDataIndex()) Tuple4
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:199
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:398
static constexpr index_t KPerThread
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:60
__host__ static constexpr __device__ auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:263
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:108
static constexpr index_t B_K0
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:49
Definition: blockwise_gemm_pipeline_xdlops.hpp:34
Definition: sequence.hpp:43
Definition: static_buffer.hpp:75
static __device__ index_t GetThreadId()
Definition: thread_group.hpp:19
static constexpr __device__ index_t GetNumOfThread()
Definition: thread_group.hpp:15
Definition: xdlops_gemm.hpp:1711
Definition: integral_constant.hpp:20