/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp Source File
gridwise_gemm_xdlops_v3r3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
// GPU kernel entry point for this gridwise GEMM. The kernel's name line (original line 37)
// and the launch-bounds line (original line 35) are absent from this listing; this page's
// symbol index gives the name as kernel_gemm_xdlops_v3r3.
//
// On the targets selected by the preprocessor guard in the body, and only when
// GridwiseGemm::IsValidCompilationParameter<>() is true at compile time, the kernel
// declares a statically sized __shared__ byte buffer of
// GridwiseGemm::GetSharedMemoryNumberOfByte() bytes and forwards every argument, plus that
// buffer, to GridwiseGemm::Run<HasMainKBlockLoop>. On every other target each parameter is
// assigned to `ignore` and nothing else is done.
//
// Descriptors, element-wise operations and the block-to-tile map are taken by value.
20 template <typename GridwiseGemm,
21  typename FloatAB,
22  typename FloatC,
23  typename AGridDesc_K0_M_K1,
24  typename BGridDesc_K0_N_K1,
25  typename CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
26  typename C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
27  typename C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
28  typename AElementwiseOperation,
29  typename BElementwiseOperation,
30  typename CElementwiseOperation,
31  typename Block2CTileMap,
32  bool HasMainKBlockLoop>
33 __global__ void
34 #if CK_USE_LAUNCH_BOUNDS
36 #endif
38  const FloatAB* __restrict__ p_a_grid,
39  const FloatAB* __restrict__ p_b_grid,
40  FloatC* __restrict__ p_c_grid,
41  const FloatC* __restrict__ p_c0_grid,
42  const FloatC* __restrict__ p_c1_grid,
43  const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
44  const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
45  const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
46  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
47  const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
48  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
49  const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
50  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
51  const AElementwiseOperation a_element_op,
52  const BElementwiseOperation b_element_op,
53  const CElementwiseOperation c_element_op,
54  const Block2CTileMap block_2_ctile_map)
55 {
// Real work is compiled only for the architectures listed in this guard.
56 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
57  defined(__gfx12__)
// Instantiations whose tuning parameters are rejected at compile time get an empty body.
58  if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
59  {
// Shared-memory scratch buffer, sized at compile time and handed to Run as p_shared.
60  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
61 
62  GridwiseGemm::template Run<HasMainKBlockLoop>(
63  p_a_grid,
64  p_b_grid,
65  p_c_grid,
66  p_c0_grid,
67  p_c1_grid,
68  p_shared,
69  a_grid_desc_k0_m_k1,
70  b_grid_desc_k0_n_k1,
71  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
72  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
73  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
74  a_element_op,
75  b_element_op,
76  c_element_op,
77  block_2_ctile_map);
78  }
79 #else
// Other targets: consume every parameter through `ignore`; no computation happens here.
80  ignore = p_a_grid;
81  ignore = p_b_grid;
82  ignore = p_c_grid;
83  ignore = p_c0_grid;
84  ignore = p_c1_grid;
85  ignore = a_grid_desc_k0_m_k1;
86  ignore = b_grid_desc_k0_n_k1;
87  ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
88  ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
89  ignore = c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
90  ignore = a_element_op;
91  ignore = b_element_op;
92  ignore = c_element_op;
93  ignore = block_2_ctile_map;
94 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__))
95 }
96 
97 template <
98  index_t BlockSize,
99  typename FloatAB,
100  typename FloatAcc,
101  typename FloatC,
102  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
103  typename AGridDesc_K0_M_K1,
104  typename BGridDesc_K0_N_K1,
105  typename CGridDesc_M_N,
106  typename C0GridDesc_M_N,
107  typename C1GridDesc_M_N,
108  typename AElementwiseOperation,
109  typename BElementwiseOperation,
110  typename CElementwiseOperation,
111  index_t MPerBlock,
112  index_t NPerBlock,
113  index_t K0PerBlock,
114  index_t MPerXdl,
115  index_t NPerXdl,
116  index_t K1Value,
117  index_t MXdlPerWave,
118  index_t NXdlPerWave,
119  typename ABlockTransferThreadClusterLengths_K0_M_K1,
120  typename ABlockTransferThreadClusterArrangeOrder,
121  typename ABlockTransferSrcAccessOrder,
122  index_t ABlockTransferSrcVectorDim,
123  index_t ABlockTransferSrcScalarPerVector,
124  index_t ABlockTransferDstScalarPerVector_K1,
125  bool AThreadTransferSrcResetCoordinateAfterRun,
126  bool ABlockLdsExtraM,
127  typename BBlockTransferThreadClusterLengths_K0_N_K1,
128  typename BBlockTransferThreadClusterArrangeOrder,
129  typename BBlockTransferSrcAccessOrder,
130  index_t BBlockTransferSrcVectorDim,
131  index_t BBlockTransferSrcScalarPerVector,
132  index_t BBlockTransferDstScalarPerVector_K1,
133  bool BThreadTransferSrcResetCoordinateAfterRun,
134  bool BBlockLdsExtraN,
135  index_t CShuffleMXdlPerWavePerShuffle,
136  index_t CShuffleNXdlPerWavePerShuffle,
137  typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
138  index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
139  index_t NumGemmKPrefetchStage = 1,
140  PipelineVersion PipelineVer = PipelineVersion::v1>
142 {
143  static constexpr auto I0 = Number<0>{};
144  static constexpr auto I1 = Number<1>{};
145  static constexpr auto I2 = Number<2>{};
146  static constexpr auto I3 = Number<3>{};
147  static constexpr auto I4 = Number<4>{};
148  static constexpr auto I5 = Number<5>{};
149  static constexpr auto I6 = Number<6>{};
150  static constexpr auto I7 = Number<7>{};
151 
152  // K1 should be Number<...>
153  static constexpr auto K1 = Number<K1Value>{};
154 
156 
158  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
159 
// Returns the compile-time descriptor of the A tile that the blockwise copy writes to
// (see the "dst of blockwise copy" comment below). ABlockLdsExtraM selects between two
// constructions at compile time; K1 is used as the alignment value.
// NOTE(review): original lines 168-170 and 174 are absent from this listing, so the exact
// descriptor-building call in each branch cannot be confirmed here. Only the else branch's
// arguments are visible: lengths (K0PerBlock, MPerBlock, K1) and max_lds_align.
160  __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
161  {
162  constexpr auto max_lds_align = K1;
163 
164  // A matrix in LDS memory, dst of blockwise copy
165  constexpr auto a_block_desc_k0_m_k1 = [&]() {
166  if constexpr(ABlockLdsExtraM)
167  {
171  }
172  else
173  {
175  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
176  }
177  }();
178 
179  return a_block_desc_k0_m_k1;
180  }
181 
// Returns the compile-time descriptor of the B tile that the blockwise copy writes to
// (see the "dst of blockwise copy" comment below). BBlockLdsExtraN selects between two
// constructions at compile time; K1 is used as the alignment value.
// NOTE(review): original lines 190-192 and 196 are absent from this listing, so the exact
// descriptor-building call in each branch cannot be confirmed here. Only the else branch's
// arguments are visible: lengths (K0PerBlock, NPerBlock, K1) and max_lds_align.
182  __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
183  {
184  constexpr auto max_lds_align = K1;
185 
186  // B matrix in LDS memory, dst of blockwise copy
187  constexpr auto b_block_desc_k0_n_k1 = [&]() {
188  if constexpr(BBlockLdsExtraN)
189  {
193  }
194  else
195  {
197  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
198  }
199  }();
200 
201  return b_block_desc_k0_n_k1;
202  }
203 
// Returns a compile-time descriptor named
// c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl.
// NOTE(review): the function's name line (original line 205) and most of the descriptor
// construction (original lines 212, 214-215, 217-218) are absent from this listing. Only
// the two I1 entries of the lengths tuple are visible; how MWave and NWave enter the
// remaining lengths cannot be confirmed here.
204  __host__ __device__ static constexpr auto
206  {
// Both quotients are integer divisions of the block tile by the per-wave tile extent.
207  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
208  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
209 
210  constexpr auto
211  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
213  make_tuple(I1,
216  I1,
219 
220  return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
221  }
222 
// Returns the number of bytes of shared memory the kernel must declare for this
// instantiation. The value is the larger of two byte counts, not their sum:
//   * (aligned A tile elements + aligned B tile elements) * sizeof(FloatAB)
//   * C block elements * sizeof(FloatC)
// Run() places both the A buffer and the C shuffle buffer at the start of p_shared, so the
// two uses share the same storage.
223  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
224  {
225  // LDS allocation for A and B: be careful of alignment
226  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
227 
228  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
229 
230  constexpr auto max_lds_align = K1;
231 
// Element counts (not bytes), each rounded up to a multiple of K1.
232  constexpr auto a_block_space_size_aligned =
233  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
234 
235  constexpr auto b_block_space_size_aligned =
236  math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
237 
238  // LDS allocation for C shuffle in LDS
// NOTE(review): the initializer of this descriptor (original line 240) is absent from
// this listing.
239  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
241 
242  constexpr auto c_block_size =
243  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
244  .GetElementSpaceSize();
245 
// Convert both element counts to bytes and take the maximum.
246  return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
247  sizeof(FloatAB),
248  c_block_size * sizeof(FloatC));
249  }
250 
// Compile-time check of this instantiation's tuning parameters. Returns the result of
// ck::tensor_operation::device::IsValidGemmCompilationParameter evaluated with BlockSize,
// the M/N block and XDL tile sizes, FloatC and a memory data operation.
// NOTE(review): this function's own template parameter CGlobalMemoryDataOperation_ is never
// referenced in the body. The value forwarded is CGlobalMemoryDataOperation from the
// enclosing template parameter list, so a caller-supplied argument has no effect. Confirm
// that this is intended.
251  template <
252  InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
253  __device__ static bool constexpr IsValidCompilationParameter()
254  {
255  return ck::tensor_operation::device::IsValidGemmCompilationParameter<
256  BlockSize,
257  MPerBlock,
258  NPerBlock,
259  MPerXdl,
260  NPerXdl,
261  MXdlPerWave,
262  NXdlPerWave,
263  FloatC,
264  CGlobalMemoryDataOperation>();
265  }
266 
// Checks whether a problem described by the given descriptors can be run by this
// instantiation. Returns true only if every check below passes:
//   1. M, N and K0 agree between the A, B and C descriptors, and K1 equals the third
//      length of both the A and B descriptors;
//   2. M, N and K0 are exact multiples of MPerBlock, NPerBlock and K0PerBlock;
//   3. GridwiseGemmPipe::IsSupported accepts the K-loop count K0 / K0PerBlock;
//   4. block_2_ctile_map.CheckValidity accepts the C descriptor.
// Two static_asserts additionally reject, at compile time, a K1 that is not a
// compile-time value and block tiles that are not multiples of the per-wave tile.
267  // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
268  template <typename Block2CTileMap>
269  __host__ __device__ static constexpr bool
270  CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
271  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
272  const CGridDesc_M_N& c_grid_desc_m_n,
273  const Block2CTileMap& block_2_ctile_map)
274  {
275  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
276  "wrong! K1 need to be known at compile-time");
277 
278  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
279  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
280  "Invalid tuning param!");
281 
// Problem sizes are read from the A and B descriptors: A is (K0, M, K1), B is (K0, N, K1).
282  const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
283  const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
284  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
285 
// Cross-check the three descriptors against each other and against the compile-time K1.
286  if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
287  K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
288  K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
289  return false;
290 
// Partial tiles are rejected: every dimension must divide evenly by its block tile.
291  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
292  return false;
293 
294  // check gridwise gemm pipeline
295  const auto num_k_loop = K0 / K0PerBlock;
296 
297  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
298  {
299  return false;
300  }
301 
302  if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
303  {
304  return false;
305  }
306 
307  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
308  return true;
309  }
310 
// Tells the caller whether the K loop needs a main-loop body for a given K.
// K is divided by (K0PerBlock * K1) with integer division to get the loop count, and the
// answer is delegated to GridwiseGemmPipe::CalculateHasMainLoop.
311  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
312  {
313  const index_t num_loop = K / (K0PerBlock * K1);
314 
315  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
316  }
317 
// Builds a descriptor named
// c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl from a 2-D
// (M, N) descriptor and returns it.
// NOTE(review): the function's name line (original line 320) and the transformation
// arguments (original lines 334, 336-341) are absent from this listing. This page's symbol
// index gives the name as
// MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl.
318  template <typename CGridDesc_M_N_>
319  __host__ __device__ static constexpr auto
321  const CGridDesc_M_N_& c_grid_desc_m_n)
322  {
323  const auto M = c_grid_desc_m_n.GetLength(I0);
324  const auto N = c_grid_desc_m_n.GetLength(I1);
325 
// Number of block tiles along each dimension (integer division).
326  const auto MBlock = M / MPerBlock;
327  const auto NBlock = N / NPerBlock;
328 
// NOTE(review): NWave falls back to 1 when NXdlPerWave * NPerXdl is zero, but MWave has no
// matching guard for MXdlPerWave * MPerXdl. Confirm whether the asymmetry is intended.
329  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
330  constexpr index_t NWave =
331  NXdlPerWave * NPerXdl == 0 ? 1 : NPerBlock / (NXdlPerWave * NPerXdl);
332 
333  const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
335  c_grid_desc_m_n,
342 
343  return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
344  }
345 
346  // return block_id to C matrix tile idx (m0, n0) mapping
// The two index_t parameters are unnamed (M01 and N01 appear only in comments), so the
// body cannot use them; only c_grid_desc_m_n is passed on.
// NOTE(review): the line that names the returned map type (original line 350) is absent
// from this listing.
347  __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
348  const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
349  {
351  c_grid_desc_m_n);
352  }
356  CGridDesc_M_N{}))>;
357 
361  C0GridDesc_M_N{}))>;
362 
366  C1GridDesc_M_N{}))>;
367 
369  remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
370 
371  template <bool HasMainKBlockLoop, typename Block2CTileMap>
372  __device__ static void
373  Run(const FloatAB* __restrict__ p_a_grid,
374  const FloatAB* __restrict__ p_b_grid,
375  FloatC* __restrict__ p_c_grid,
376  const FloatC* __restrict__ p_c0_grid,
377  const FloatC* __restrict__ p_c1_grid,
378  void* __restrict__ p_shared,
379  const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
380  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
382  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
384  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
386  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
387  const AElementwiseOperation& a_element_op,
388  const BElementwiseOperation& b_element_op,
389  const CElementwiseOperation& c_element_op,
390  const Block2CTileMap& block_2_ctile_map)
391  {
392  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
393  p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
394  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
395  p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
396  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
397  p_c_grid,
398  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
399  .GetElementSpaceSize());
400  auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
401  p_c0_grid,
402  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
403  .GetElementSpaceSize());
404  auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
405  p_c1_grid,
406  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
407  .GetElementSpaceSize());
408 
409  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
410 
411  // divide block work by [M, N]
412  const auto block_work_idx =
413  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
414 
415  if(!block_2_ctile_map.ValidCTileIndex(
416  block_work_idx,
417  make_tuple(
418  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
419  .GetLength(I0),
420  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
421  .GetLength(I3))))
422  {
423  return;
424  }
425 
426  // HACK: this force m/n_block_data_idx_on_grid into SGPR
427  const index_t m_block_data_idx_on_grid =
428  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
429 
430  const index_t n_block_data_idx_on_grid =
431  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
432 
433  // lds max alignment
434  constexpr auto max_lds_align = K1;
435 
436  // A matrix in LDS memory, dst of blockwise copy
437  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
438 
439  // B matrix in LDS memory, dst of blockwise copy
440  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
441 
442  // A matrix blockwise copy
443  auto a_blockwise_copy =
445  AElementwiseOperation,
449  ABlockTransferThreadClusterLengths_K0_M_K1,
450  ABlockTransferThreadClusterArrangeOrder,
451  FloatAB,
452  FloatAB,
453  decltype(a_grid_desc_k0_m_k1),
454  decltype(a_block_desc_k0_m_k1),
455  ABlockTransferSrcAccessOrder,
457  ABlockTransferSrcVectorDim,
458  2,
459  ABlockTransferSrcScalarPerVector,
460  ABlockTransferDstScalarPerVector_K1,
461  1,
462  1,
463  AThreadTransferSrcResetCoordinateAfterRun,
464  true>(
465  a_grid_desc_k0_m_k1,
466  make_multi_index(0, m_block_data_idx_on_grid, 0),
467  a_element_op,
468  a_block_desc_k0_m_k1,
469  make_multi_index(0, 0, 0),
471 
472  // B matrix blockwise copy
473  auto b_blockwise_copy =
475  BElementwiseOperation,
479  BBlockTransferThreadClusterLengths_K0_N_K1,
480  BBlockTransferThreadClusterArrangeOrder,
481  FloatAB,
482  FloatAB,
483  decltype(b_grid_desc_k0_n_k1),
484  decltype(b_block_desc_k0_n_k1),
485  BBlockTransferSrcAccessOrder,
487  BBlockTransferSrcVectorDim,
488  2,
489  BBlockTransferSrcScalarPerVector,
490  BBlockTransferDstScalarPerVector_K1,
491  1,
492  1,
493  BThreadTransferSrcResetCoordinateAfterRun,
494  true>(
495  b_grid_desc_k0_n_k1,
496  make_multi_index(0, n_block_data_idx_on_grid, 0),
497  b_element_op,
498  b_block_desc_k0_n_k1,
499  make_multi_index(0, 0, 0),
501 
502  // GEMM definition
503  // c_mtx += transpose(a_mtx) * b_mtx
504  // a_mtx[K0PerBlock, MPerBlock] is in LDS
505  // b_mtx[K0PerBlock, NPerBlock] is in LDS
506  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
507  // register
508  // sanity check
509 
510  auto blockwise_gemm =
512  FloatAB,
513  FloatAB,
514  FloatAcc,
515  decltype(a_block_desc_k0_m_k1),
516  decltype(b_block_desc_k0_n_k1),
517  MPerXdl,
518  NPerXdl,
519  MXdlPerWave,
520  NXdlPerWave,
521  K1>{};
522 
523  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
524 
525  // LDS allocation for A and B: be careful of alignment
526  constexpr auto a_block_space_size_aligned =
527  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
528 
529  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
530  static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
531 
532  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
533  static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
534  b_block_desc_k0_n_k1.GetElementSpaceSize());
535 
536  constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
537  constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
538 
539  // gridwise GEMM pipeline
540  const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
541 
542  GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
543  a_block_desc_k0_m_k1,
544  a_blockwise_copy,
545  a_grid_buf,
546  a_block_buf,
547  a_block_slice_copy_step,
548  b_grid_desc_k0_n_k1,
549  b_block_desc_k0_n_k1,
550  b_blockwise_copy,
551  b_grid_buf,
552  b_block_buf,
553  b_block_slice_copy_step,
554  blockwise_gemm,
555  c_thread_buf,
556  K0BlockMainLoop);
557 
558  // shuffle C and write out
559  {
560  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
561  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
562  "wrong!");
563 
564  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
565  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
566 
567  // TODO: hacky, fix it!
568  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
569  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
570 
571  // TODO: hacky, fix it!
572  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
573  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
574  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
575 
576  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
577  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
578  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
579  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
580  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
581  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
582  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
583  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
584 
585  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
587 
588  auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
589  static_cast<FloatC*>(p_shared),
590  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
591  .GetElementSpaceSize());
592 
593  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
594  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
595  make_tuple(make_freeze_transform(I0), // freeze mblock
597  Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per
598  // shuffle
600  make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
601  make_freeze_transform(I0), // freeze nblock
603  Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per
604  // shuffle
606  make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl
608  Sequence<1>{},
609  Sequence<2>{},
610  Sequence<3>{},
611  Sequence<4>{},
612  Sequence<5>{}),
614  Sequence<0>{},
616  Sequence<>{},
617  Sequence<1>{},
618  Sequence<3, 7>{})
619 
620  );
621 
622  // calculate origin of thread output tensor on global memory
623  // blockwise GEMM c matrix starting index
624  const auto c_thread_mtx_on_block =
625  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
626 
627  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
628  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
629 
630  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
632  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
635 
636  const auto m_thread_data_on_block_idx =
637  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
638  make_multi_index(m_thread_data_on_block));
639 
640  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
645 
646  const auto n_thread_data_on_block_idx =
647  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
648  make_multi_index(n_thread_data_on_block));
649 
650  // VGPR to LDS
651  auto c_thread_copy_vgpr_to_lds =
653  FloatC,
654  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
655  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
657  Sequence<CShuffleMXdlPerWavePerShuffle,
658  CShuffleNXdlPerWavePerShuffle,
659  I1,
660  I1,
661  M2,
662  I1,
663  M4,
664  I1>,
666  7,
667  1,
669  1,
670  true>{
671  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
673  0,
674  m_thread_data_on_block_idx[I1],
675  n_thread_data_on_block_idx[I1],
676  m_thread_data_on_block_idx[I2],
677  m_thread_data_on_block_idx[I3],
678  m_thread_data_on_block_idx[I4],
679  n_thread_data_on_block_idx[I2]),
681 
682  auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3<
683  ThisThreadBlock, // ThreadGroup
684  CElementwiseOperation, // ElementwiseOperation,
685  CGlobalMemoryDataOperation, // DstInMemOp,
686  Sequence<1,
687  CShuffleMXdlPerWavePerShuffle,
688  MWave * MPerXdl,
689  1,
690  CShuffleNXdlPerWavePerShuffle,
691  NWave * NPerXdl>, // BlockSliceLengths,
692  CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
693  Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
694  FloatC, // typename Src0Data,
695  FloatC, // typename Src1Data,
696  FloatC, // typename Src2Data,
697  FloatC, // typename DstData,
698  decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
699  decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
700  decltype(c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
701  decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
702  Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
703  5, // index_t VectorDim,
704  CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
705  true, // bool ThreadTransferSrc0ResetCoordinateAfterRun,
706  false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
707  false, // bool ThreadTransferSrc2ResetCoordinateAfterRun,
708  false> // bool ThreadTransferDstResetCoordinateAfterRun>
709  {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
710  make_multi_index(0, 0, 0, 0, 0, 0),
711  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
712  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
713  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
714  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
715  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
716  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
717  c_element_op};
718 
719  constexpr auto mxdlperwave_forward_step =
720  make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0);
721  constexpr auto nxdlperwave_forward_step =
722  make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0);
723  constexpr auto nxdlperwave_backward_step =
724  make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0);
725 
726  static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) {
727  constexpr auto mxdlperwave = mxdlperwave_iter;
728 
729  static_for<0,
730  NXdlPerWave,
731  CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) {
732  constexpr bool nxdlperwave_forward_sweep =
733  (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0);
734 
735  constexpr index_t nxdlperwave_value =
736  nxdlperwave_forward_sweep
737  ? nxdlperwave_iter
738  : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle);
739 
740  constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
741 
742  // make sure it's safe to do ds_write
743  block_sync_lds();
744 
745  // VGPR to LDS
746  c_thread_copy_vgpr_to_lds.Run(
747  c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
748  make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
749  c_thread_buf,
750  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
751  c_block_buf);
752 
753  // make sure it's safe to do ds_read
754  block_sync_lds();
755 
756  // LDS to global
757  c_block_copy_lds_to_global.Run(
758  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
759  c_block_buf,
760  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
761  c0_grid_buf,
762  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
763  c1_grid_buf,
764  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
765  c_grid_buf);
766 
767  // move on nxdlperwave dimension
768  if constexpr(nxdlperwave_forward_sweep &&
769  (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle))
770  {
771  c_block_copy_lds_to_global.MoveSrc1SliceWindow(
772  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
773  nxdlperwave_forward_step);
774 
775  c_block_copy_lds_to_global.MoveSrc2SliceWindow(
776  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
777  nxdlperwave_forward_step);
778 
779  c_block_copy_lds_to_global.MoveDstSliceWindow(
780  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
781  nxdlperwave_forward_step);
782  }
783  else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
784  {
785  c_block_copy_lds_to_global.MoveSrc1SliceWindow(
786  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
787  nxdlperwave_backward_step);
788 
789  c_block_copy_lds_to_global.MoveSrc2SliceWindow(
790  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
791  nxdlperwave_backward_step);
792 
793  c_block_copy_lds_to_global.MoveDstSliceWindow(
794  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
795  nxdlperwave_backward_step);
796  }
797  });
798 
799  // move on mxdlperwave dimension
800  if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle)
801  {
802  c_block_copy_lds_to_global.MoveSrc1SliceWindow(
803  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
804  mxdlperwave_forward_step);
805 
806  c_block_copy_lds_to_global.MoveSrc2SliceWindow(
807  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
808  mxdlperwave_forward_step);
809 
810  c_block_copy_lds_to_global.MoveDstSliceWindow(
811  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
812  mxdlperwave_forward_step);
813  }
814  });
815  }
816  }
817 };
818 
819 } // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:30
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: ck.hpp:268
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition: ck.hpp:277
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition: tensor_descriptor_helper.hpp:132
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
__global__ void kernel_gemm_xdlops_v3r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC *__restrict__ p_c0_grid, const FloatC *__restrict__ p_c1_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r3.hpp:37
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:299
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:18
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:295
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: block_to_ctile_map.hpp:261
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
Definition: gridwise_gemm_xdlops_v3r3.hpp:142
static constexpr auto I3
Definition: gridwise_gemm_xdlops_v3r3.hpp:146
static __device__ void Run(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC *__restrict__ p_c0_grid, const FloatC *__restrict__ p_c1_grid, void *__restrict__ p_shared, const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r3.hpp:373
static constexpr auto I0
Definition: gridwise_gemm_xdlops_v3r3.hpp:143
__host__ static constexpr __device__ auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition: gridwise_gemm_xdlops_v3r3.hpp:347
static __device__ constexpr bool IsValidCompilationParameter()
Definition: gridwise_gemm_xdlops_v3r3.hpp:253
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(const CGridDesc_M_N_ &c_grid_desc_m_n)
Definition: gridwise_gemm_xdlops_v3r3.hpp:320
static constexpr auto I2
Definition: gridwise_gemm_xdlops_v3r3.hpp:145
static constexpr auto I5
Definition: gridwise_gemm_xdlops_v3r3.hpp:148
__host__ static constexpr __device__ auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
Definition: gridwise_gemm_xdlops_v3r3.hpp:182
static constexpr auto I7
Definition: gridwise_gemm_xdlops_v3r3.hpp:150
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition: gridwise_gemm_xdlops_v3r3.hpp:369
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdlops_v3r3.hpp:311
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_v3r3.hpp:223
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
Definition: gridwise_gemm_xdlops_v3r3.hpp:356
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(C1GridDesc_M_N{}))> C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
Definition: gridwise_gemm_xdlops_v3r3.hpp:366
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(C0GridDesc_M_N{}))> C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
Definition: gridwise_gemm_xdlops_v3r3.hpp:361
__host__ static constexpr __device__ auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
Definition: gridwise_gemm_xdlops_v3r3.hpp:160
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVer, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition: gridwise_gemm_xdlops_v3r3.hpp:158
__host__ static constexpr __device__ auto GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl()
Definition: gridwise_gemm_xdlops_v3r3.hpp:205
static constexpr auto I4
Definition: gridwise_gemm_xdlops_v3r3.hpp:147
static constexpr auto I6
Definition: gridwise_gemm_xdlops_v3r3.hpp:149
static constexpr auto K1
Definition: gridwise_gemm_xdlops_v3r3.hpp:153
static constexpr auto I1
Definition: gridwise_gemm_xdlops_v3r3.hpp:144
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r3.hpp:270
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdlops_v3r3.hpp:155
Definition: sequence.hpp:43
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v6r3.hpp:40
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:20
Definition: is_known_at_compile_time.hpp:14
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:334