/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp Source File
gridwise_gemm_xdlops_v3r2.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
20 template <typename GridwiseGemm,
21  typename FloatAB,
22  typename FloatC,
23  typename AGridDesc_K0_M_K1,
24  typename BGridDesc_K0_N_K1,
25  typename CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
26  typename C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
27  typename AElementwiseOperation,
28  typename BElementwiseOperation,
29  typename CElementwiseOperation,
30  typename Block2CTileMap,
31  bool HasMainKBlockLoop>
32 __global__ void
33 #if CK_USE_LAUNCH_BOUNDS
35 #endif
37  const FloatAB* __restrict__ p_a_grid,
38  const FloatAB* __restrict__ p_b_grid,
39  FloatC* __restrict__ p_c_grid,
40  const FloatC* __restrict__ p_c0_grid,
41  const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
42  const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
43  const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
44  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
45  const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
46  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
47  const AElementwiseOperation a_element_op,
48  const BElementwiseOperation b_element_op,
49  const CElementwiseOperation c_element_op,
50  const Block2CTileMap block_2_ctile_map)
51 {
52 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
53  defined(__gfx12__)
54  if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
55  {
56  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
57 
58  GridwiseGemm::template Run<HasMainKBlockLoop>(
59  p_a_grid,
60  p_b_grid,
61  p_c_grid,
62  p_c0_grid,
63  p_shared,
64  a_grid_desc_k0_m_k1,
65  b_grid_desc_k0_n_k1,
66  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
67  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
68  a_element_op,
69  b_element_op,
70  c_element_op,
71  block_2_ctile_map);
72  }
73 #else
74  ignore = p_a_grid;
75  ignore = p_b_grid;
76  ignore = p_c_grid;
77  ignore = p_c0_grid;
78  ignore = a_grid_desc_k0_m_k1;
79  ignore = b_grid_desc_k0_n_k1;
80  ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
81  ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
82  ignore = a_element_op;
83  ignore = b_element_op;
84  ignore = c_element_op;
85  ignore = block_2_ctile_map;
86 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__))
87 }
88 
89 template <
90  index_t BlockSize,
91  typename FloatAB,
92  typename FloatAcc,
93  typename FloatC,
94  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
95  typename AGridDesc_K0_M_K1,
96  typename BGridDesc_K0_N_K1,
97  typename CGridDesc_M_N,
98  typename C0GridDesc_M_N,
99  typename AElementwiseOperation,
100  typename BElementwiseOperation,
101  typename CElementwiseOperation,
102  index_t MPerBlock,
103  index_t NPerBlock,
104  index_t K0PerBlock,
105  index_t MPerXdl,
106  index_t NPerXdl,
107  index_t K1Value,
108  index_t MXdlPerWave,
109  index_t NXdlPerWave,
110  typename ABlockTransferThreadClusterLengths_K0_M_K1,
111  typename ABlockTransferThreadClusterArrangeOrder,
112  typename ABlockTransferSrcAccessOrder,
113  index_t ABlockTransferSrcVectorDim,
114  index_t ABlockTransferSrcScalarPerVector,
115  index_t ABlockTransferDstScalarPerVector_K1,
116  bool AThreadTransferSrcResetCoordinateAfterRun,
117  bool ABlockLdsExtraM,
118  typename BBlockTransferThreadClusterLengths_K0_N_K1,
119  typename BBlockTransferThreadClusterArrangeOrder,
120  typename BBlockTransferSrcAccessOrder,
121  index_t BBlockTransferSrcVectorDim,
122  index_t BBlockTransferSrcScalarPerVector,
123  index_t BBlockTransferDstScalarPerVector_K1,
124  bool BThreadTransferSrcResetCoordinateAfterRun,
125  bool BBlockLdsExtraN,
126  index_t CShuffleMXdlPerWavePerShuffle,
127  index_t CShuffleNXdlPerWavePerShuffle,
128  typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
129  index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
130  index_t NumGemmKPrefetchStage = 1,
131  PipelineVersion PipelineVer = PipelineVersion::v1>
133 {
134  static constexpr auto I0 = Number<0>{};
135  static constexpr auto I1 = Number<1>{};
136  static constexpr auto I2 = Number<2>{};
137  static constexpr auto I3 = Number<3>{};
138  static constexpr auto I4 = Number<4>{};
139  static constexpr auto I5 = Number<5>{};
140  static constexpr auto I6 = Number<6>{};
141  static constexpr auto I7 = Number<7>{};
142 
143  // K1 should be Number<...>
144  static constexpr auto K1 = Number<K1Value>{};
145 
147 
149  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
150 
151  __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
152  {
153  constexpr auto max_lds_align = K1;
154 
155  // A matrix in LDS memory, dst of blockwise copy
156  constexpr auto a_block_desc_k0_m_k1 = [&]() {
157  if constexpr(ABlockLdsExtraM)
158  {
162  }
163  else
164  {
166  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
167  }
168  }();
169 
170  return a_block_desc_k0_m_k1;
171  }
172 
173  __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
174  {
175  constexpr auto max_lds_align = K1;
176 
177  // B matrix in LDS memory, dst of blockwise copy
178  constexpr auto b_block_desc_k0_n_k1 = [&]() {
179  if constexpr(BBlockLdsExtraN)
180  {
184  }
185  else
186  {
188  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
189  }
190  }();
191 
192  return b_block_desc_k0_n_k1;
193  }
194 
195  __host__ __device__ static constexpr auto
197  {
198  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
199  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
200 
201  constexpr auto
202  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
204  make_tuple(I1,
207  I1,
210 
211  return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
212  }
213 
214  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
215  {
216  // LDS allocation for A and B: be careful of alignment
217  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
218 
219  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
220 
221  constexpr auto max_lds_align = K1;
222 
223  constexpr auto a_block_space_size_aligned =
224  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
225 
226  constexpr auto b_block_space_size_aligned =
227  math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
228 
229  // LDS allocation for C shuffle in LDS
230  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
232 
233  constexpr auto c_block_size =
234  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
235  .GetElementSpaceSize();
236 
237  return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
238  sizeof(FloatAB),
239  c_block_size * sizeof(FloatC));
240  }
241 
242  template <
243  InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
244  __device__ static bool constexpr IsValidCompilationParameter()
245  {
246  return ck::tensor_operation::device::IsValidGemmCompilationParameter<
247  BlockSize,
248  MPerBlock,
249  NPerBlock,
250  MPerXdl,
251  NPerXdl,
252  MXdlPerWave,
253  NXdlPerWave,
254  FloatC,
255  CGlobalMemoryDataOperation>();
256  }
257 
258  // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
259  template <typename Block2CTileMap>
260  __host__ __device__ static constexpr bool
261  CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
262  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
263  const CGridDesc_M_N& c_grid_desc_m_n,
264  const Block2CTileMap& block_2_ctile_map)
265  {
266  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
267  "wrong! K1 need to be known at compile-time");
268 
269  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
270  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
271  "Invalid tuning param!");
272 
273  const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
274  const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
275  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
276 
277  if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
278  K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
279  K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
280  return false;
281 
282  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
283  return false;
284 
285  // check gridwise gemm pipeline
286  const auto num_k_loop = K0 / K0PerBlock;
287 
288  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
289  {
290  return false;
291  }
292 
293  if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
294  {
295  return false;
296  }
297 
298  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
299  return true;
300  }
301 
302  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
303  {
304  const index_t num_loop = K / (K0PerBlock * K1);
305 
306  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
307  }
308 
309  template <typename CGridDesc_M_N_>
310  __host__ __device__ static constexpr auto
312  const CGridDesc_M_N_& c_grid_desc_m_n)
313  {
314  const auto M = c_grid_desc_m_n.GetLength(I0);
315  const auto N = c_grid_desc_m_n.GetLength(I1);
316 
317  const auto MBlock = M / MPerBlock;
318  const auto NBlock = N / NPerBlock;
319 
320  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
321  constexpr index_t NWave =
322  NXdlPerWave * NPerXdl == 0 ? 1 : NPerBlock / (NXdlPerWave * NPerXdl);
323 
324  const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
326  c_grid_desc_m_n,
333 
334  return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
335  }
336 
337  // return block_id to C matrix tile idx (m0, n0) mapping
338  __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
339  const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
340  {
342  c_grid_desc_m_n);
343  }
344 
348  CGridDesc_M_N{}))>;
349 
353  C0GridDesc_M_N{}))>;
354 
356  remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
357 
358  template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
359  __device__ static void
360  Run(const FloatAB* __restrict__ p_a_grid,
361  const FloatAB* __restrict__ p_b_grid,
362  FloatC* __restrict__ p_c_grid,
363  const FloatC* __restrict__ p_c0_grid,
364  void* __restrict__ p_shared,
365  const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
366  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
368  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
370  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
371  const AElementwiseOperation& a_element_op,
372  const BElementwiseOperation& b_element_op,
373  const CElementwiseOperation& c_element_op,
374  const Block2CTileMap& block_2_ctile_map)
375  {
376  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
377  p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
378  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
379  p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
380  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
381  p_c_grid,
382  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
383  .GetElementSpaceSize());
384  auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
385  p_c0_grid,
386  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
387  .GetElementSpaceSize());
388 
389  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
390 
391  // divide block work by [M, N]
392  const auto block_work_idx =
393  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
394 
395  if(!block_2_ctile_map.ValidCTileIndex(
396  block_work_idx,
397  make_tuple(
398  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
399  .GetLength(I0),
400  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
401  .GetLength(I3))))
402  {
403  return;
404  }
405 
406  // HACK: this force m/n_block_data_idx_on_grid into SGPR
407  const index_t m_block_data_idx_on_grid =
408  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
409 
410  const index_t n_block_data_idx_on_grid =
411  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
412 
413  // lds max alignment
414  constexpr auto max_lds_align = K1;
415 
416  // A matrix in LDS memory, dst of blockwise copy
417  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
418 
419  // B matrix in LDS memory, dst of blockwise copy
420  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
421 
422  // A matrix blockwise copy
423  auto a_blockwise_copy =
425  AElementwiseOperation,
429  ABlockTransferThreadClusterLengths_K0_M_K1,
430  ABlockTransferThreadClusterArrangeOrder,
431  FloatAB,
432  FloatAB,
433  decltype(a_grid_desc_k0_m_k1),
434  decltype(a_block_desc_k0_m_k1),
435  ABlockTransferSrcAccessOrder,
437  ABlockTransferSrcVectorDim,
438  2,
439  ABlockTransferSrcScalarPerVector,
440  ABlockTransferDstScalarPerVector_K1,
441  1,
442  1,
443  AThreadTransferSrcResetCoordinateAfterRun,
444  true,
445  NumGemmKPrefetchStage>(
446  a_grid_desc_k0_m_k1,
447  make_multi_index(0, m_block_data_idx_on_grid, 0),
448  a_element_op,
449  a_block_desc_k0_m_k1,
450  make_multi_index(0, 0, 0),
452 
453  // B matrix blockwise copy
454  auto b_blockwise_copy =
456  BElementwiseOperation,
460  BBlockTransferThreadClusterLengths_K0_N_K1,
461  BBlockTransferThreadClusterArrangeOrder,
462  FloatAB,
463  FloatAB,
464  decltype(b_grid_desc_k0_n_k1),
465  decltype(b_block_desc_k0_n_k1),
466  BBlockTransferSrcAccessOrder,
468  BBlockTransferSrcVectorDim,
469  2,
470  BBlockTransferSrcScalarPerVector,
471  BBlockTransferDstScalarPerVector_K1,
472  1,
473  1,
474  BThreadTransferSrcResetCoordinateAfterRun,
475  true,
476  NumGemmKPrefetchStage>(
477  b_grid_desc_k0_n_k1,
478  make_multi_index(0, n_block_data_idx_on_grid, 0),
479  b_element_op,
480  b_block_desc_k0_n_k1,
481  make_multi_index(0, 0, 0),
483 
484  // GEMM definition
485  // c_mtx += transpose(a_mtx) * b_mtx
486  // a_mtx[K0PerBlock, MPerBlock] is in LDS
487  // b_mtx[K0PerBlock, NPerBlock] is in LDS
488  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
489  // register
490  // sanity check
491 
492  auto blockwise_gemm =
494  FloatAB,
495  FloatAB,
496  FloatAcc,
497  decltype(a_block_desc_k0_m_k1),
498  decltype(b_block_desc_k0_n_k1),
499  MPerXdl,
500  NPerXdl,
501  MXdlPerWave,
502  NXdlPerWave,
503  K1>{};
504 
505  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
506 
507  // LDS allocation for A and B: be careful of alignment
508  constexpr auto a_block_space_size_aligned =
509  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
510 
511  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
512  static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
513 
514  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
515  static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
516  b_block_desc_k0_n_k1.GetElementSpaceSize());
517 
518  constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
519  constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
520 
521  // gridwise GEMM pipeline
522  const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
523 
524  GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
525  a_block_desc_k0_m_k1,
526  a_blockwise_copy,
527  a_grid_buf,
528  a_block_buf,
529  a_block_slice_copy_step,
530  b_grid_desc_k0_n_k1,
531  b_block_desc_k0_n_k1,
532  b_blockwise_copy,
533  b_grid_buf,
534  b_block_buf,
535  b_block_slice_copy_step,
536  blockwise_gemm,
537  c_thread_buf,
538  K0BlockMainLoop);
539 
540  // shuffle C and write out
541  {
542  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
543  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
544  "wrong!");
545 
546  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
547  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
548 
549  // TODO: hacky, fix it!
550  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
551  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
552 
553  // TODO: hacky, fix it!
554  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
555  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
556  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
557 
558  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
559  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
560  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
561  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
562  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
563  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
564  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
565  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
566 
567  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
569 
570  auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
571  static_cast<FloatC*>(p_shared),
572  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
573  .GetElementSpaceSize());
574 
575  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
576  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
577  make_tuple(
578  make_freeze_transform(I0), // freeze mblock
580  Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per shuffle
582  make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
583  make_freeze_transform(I0), // freeze nblock
585  Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per shuffle
587  make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl
589  Sequence<1>{},
590  Sequence<2>{},
591  Sequence<3>{},
592  Sequence<4>{},
593  Sequence<5>{}),
595  Sequence<0>{},
597  Sequence<>{},
598  Sequence<1>{},
599  Sequence<3, 7>{})
600 
601  );
602 
603  // calculate origin of thread output tensor on global memory
604  // blockwise GEMM c matrix starting index
605  const auto c_thread_mtx_on_block =
606  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
607 
608  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
609  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
610 
611  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
613  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
616 
617  const auto m_thread_data_on_block_idx =
618  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
619  make_multi_index(m_thread_data_on_block));
620 
621  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
626 
627  const auto n_thread_data_on_block_idx =
628  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
629  make_multi_index(n_thread_data_on_block));
630 
631  // VGPR to LDS
632  auto c_thread_copy_vgpr_to_lds =
634  FloatC,
635  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
636  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
638  Sequence<CShuffleMXdlPerWavePerShuffle,
639  CShuffleNXdlPerWavePerShuffle,
640  I1,
641  I1,
642  M2,
643  I1,
644  M4,
645  I1>,
647  7,
648  1,
650  1,
651  true>{
652  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
654  0,
655  m_thread_data_on_block_idx[I1],
656  n_thread_data_on_block_idx[I1],
657  m_thread_data_on_block_idx[I2],
658  m_thread_data_on_block_idx[I3],
659  m_thread_data_on_block_idx[I4],
660  n_thread_data_on_block_idx[I2]),
662 
663  auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r2<
664  ThisThreadBlock, // index_t BlockSize,
665  CElementwiseOperation, // ElementwiseOperation,
666  CGlobalMemoryDataOperation, // DstInMemOp,
667  Sequence<1,
668  CShuffleMXdlPerWavePerShuffle,
669  MWave * MPerXdl,
670  1,
671  CShuffleNXdlPerWavePerShuffle,
672  NWave * NPerXdl>, // BlockSliceLengths,
673  CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
674  Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
675  FloatC, // typename Src0Data,
676  FloatC, // typename Src1Data,
677  FloatC, // typename DstData,
678  decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
679  decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
680  decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
681  Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
682  5, // index_t VectorDim,
683  CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
684  true, // bool ThreadTransferSrc0ResetCoordinateAfterRun,
685  false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
686  false> // bool ThreadTransferDstResetCoordinateAfterRun>
687  {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
688  make_multi_index(0, 0, 0, 0, 0, 0),
689  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
690  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
691  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
692  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
693  c_element_op};
694 
695  constexpr auto mxdlperwave_forward_step =
696  make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0);
697  constexpr auto nxdlperwave_forward_step =
698  make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0);
699  constexpr auto nxdlperwave_backward_step =
700  make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0);
701 
702  static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) {
703  constexpr auto mxdlperwave = mxdlperwave_iter;
704 
705  static_for<0,
706  NXdlPerWave,
707  CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) {
708  constexpr bool nxdlperwave_forward_sweep =
709  (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0);
710 
711  constexpr index_t nxdlperwave_value =
712  nxdlperwave_forward_sweep
713  ? nxdlperwave_iter
714  : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle);
715 
716  constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
717 
718  // make sure it's safe to do ds_write
719  block_sync_lds();
720 
721  // VGPR to LDS
722  c_thread_copy_vgpr_to_lds.Run(
723  c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
724  make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
725  c_thread_buf,
726  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
727  c_block_buf);
728 
729  // make sure it's safe to do ds_read
730  block_sync_lds();
731 
732  // LDS to global
733  c_block_copy_lds_to_global.Run(
734  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
735  c_block_buf,
736  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
737  c0_grid_buf,
738  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
739  c_grid_buf);
740 
741  // move on nxdlperwave dimension
742  if constexpr(nxdlperwave_forward_sweep &&
743  (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle))
744  {
745  c_block_copy_lds_to_global.MoveSrc1SliceWindow(
746  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
747  nxdlperwave_forward_step);
748 
749  c_block_copy_lds_to_global.MoveDstSliceWindow(
750  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
751  nxdlperwave_forward_step);
752  }
753  else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
754  {
755  c_block_copy_lds_to_global.MoveSrc1SliceWindow(
756  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
757  nxdlperwave_backward_step);
758 
759  c_block_copy_lds_to_global.MoveDstSliceWindow(
760  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
761  nxdlperwave_backward_step);
762  }
763  });
764 
765  // move on mxdlperwave dimension
766  if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle)
767  {
768  c_block_copy_lds_to_global.MoveSrc1SliceWindow(
769  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
770  mxdlperwave_forward_step);
771 
772  c_block_copy_lds_to_global.MoveDstSliceWindow(
773  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
774  mxdlperwave_forward_step);
775  }
776  });
777  }
778  }
779 };
780 
781 } // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:30
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: ck.hpp:268
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition: ck.hpp:277
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition: tensor_descriptor_helper.hpp:132
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:299
__global__ void kernel_gemm_xdlops_v3r2(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC *__restrict__ p_c0_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r2.hpp:36
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:18
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:295
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: block_to_ctile_map.hpp:261
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
Definition: gridwise_gemm_xdlops_v3r2.hpp:133
static constexpr auto I1
Definition: gridwise_gemm_xdlops_v3r2.hpp:135
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdlops_v3r2.hpp:302
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(C0GridDesc_M_N{}))> C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
Definition: gridwise_gemm_xdlops_v3r2.hpp:353
static constexpr auto K1
Definition: gridwise_gemm_xdlops_v3r2.hpp:144
static __device__ constexpr bool IsValidCompilationParameter()
Definition: gridwise_gemm_xdlops_v3r2.hpp:244
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_v3r2.hpp:214
__host__ static constexpr __device__ auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
Definition: gridwise_gemm_xdlops_v3r2.hpp:151
__host__ static constexpr __device__ auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
Definition: gridwise_gemm_xdlops_v3r2.hpp:173
static constexpr auto I3
Definition: gridwise_gemm_xdlops_v3r2.hpp:137
static constexpr auto I4
Definition: gridwise_gemm_xdlops_v3r2.hpp:138
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(const CGridDesc_M_N_ &c_grid_desc_m_n)
Definition: gridwise_gemm_xdlops_v3r2.hpp:311
__host__ static constexpr __device__ auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition: gridwise_gemm_xdlops_v3r2.hpp:338
static constexpr auto I0
Definition: gridwise_gemm_xdlops_v3r2.hpp:134
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVer, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition: gridwise_gemm_xdlops_v3r2.hpp:149
static __device__ void Run(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC *__restrict__ p_c0_grid, void *__restrict__ p_shared, const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r2.hpp:360
static constexpr auto I2
Definition: gridwise_gemm_xdlops_v3r2.hpp:136
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdlops_v3r2.hpp:146
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r2.hpp:261
static constexpr auto I5
Definition: gridwise_gemm_xdlops_v3r2.hpp:139
static constexpr auto I6
Definition: gridwise_gemm_xdlops_v3r2.hpp:140
__host__ static constexpr __device__ auto GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl()
Definition: gridwise_gemm_xdlops_v3r2.hpp:196
static constexpr auto I7
Definition: gridwise_gemm_xdlops_v3r2.hpp:141
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
Definition: gridwise_gemm_xdlops_v3r2.hpp:348
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition: gridwise_gemm_xdlops_v3r2.hpp:356
Definition: sequence.hpp:43
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v6r2.hpp:37
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:20
Definition: is_known_at_compile_time.hpp:14
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:334