/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp Source File
gridwise_gemm_xdlops_v3r3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
// Kernel entry point: declares a static shared-memory buffer of
// GridwiseGemm::GetSharedMemoryNumberOfByte() bytes and forwards it, together with every
// kernel argument, to GridwiseGemm::Run<HasMainKBlockLoop>.
//
// The forwarding body is compiled only when __gfx908__, __gfx90a__ or __gfx94__ is defined;
// on any other target each argument is assigned to `ignore` and the kernel does nothing.
//
// NOTE(review): the line carrying the kernel name and the line inside
// `#if CK_USE_LAUNCH_BOUNDS` are absent from this listing; the cross-reference index at the
// end of the page names the kernel kernel_gemm_xdlops_v3r3.
20 template <typename GridwiseGemm,
21  typename FloatAB,
22  typename FloatC,
23  typename AGridDesc_K0_M_K1,
24  typename BGridDesc_K0_N_K1,
25  typename CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
26  typename C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
27  typename C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
28  typename AElementwiseOperation,
29  typename BElementwiseOperation,
30  typename CElementwiseOperation,
31  typename Block2CTileMap,
32  bool HasMainKBlockLoop>
33 __global__ void
34 #if CK_USE_LAUNCH_BOUNDS
36 #endif
38  const FloatAB* __restrict__ p_a_grid,
39  const FloatAB* __restrict__ p_b_grid,
40  FloatC* __restrict__ p_c_grid,
41  const FloatC* __restrict__ p_c0_grid,
42  const FloatC* __restrict__ p_c1_grid,
43  const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
44  const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
45  const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
46  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
47  const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
48  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
49  const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
50  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
51  const AElementwiseOperation a_element_op,
52  const BElementwiseOperation b_element_op,
53  const CElementwiseOperation c_element_op,
54  const Block2CTileMap block_2_ctile_map)
55 {
56 #if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
    // Scratch space shared by the whole workgroup; its size is a compile-time constant.
57  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
58 
59  GridwiseGemm::template Run<HasMainKBlockLoop>(
60  p_a_grid,
61  p_b_grid,
62  p_c_grid,
63  p_c0_grid,
64  p_c1_grid,
65  p_shared,
66  a_grid_desc_k0_m_k1,
67  b_grid_desc_k0_n_k1,
68  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
69  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
70  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
71  a_element_op,
72  b_element_op,
73  c_element_op,
74  block_2_ctile_map);
75 #else
    // Unsupported target: consume every parameter so none is reported as unused.
76  ignore = p_a_grid;
77  ignore = p_b_grid;
78  ignore = p_c_grid;
79  ignore = p_c0_grid;
80  ignore = p_c1_grid;
81  ignore = a_grid_desc_k0_m_k1;
82  ignore = b_grid_desc_k0_n_k1;
83  ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
84  ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
85  ignore = c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
86  ignore = a_element_op;
87  ignore = b_element_op;
88  ignore = c_element_op;
89  ignore = block_2_ctile_map;
90 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
91 }
92 
93 template <
94  index_t BlockSize,
95  typename FloatAB,
96  typename FloatAcc,
97  typename FloatC,
98  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
99  typename AGridDesc_K0_M_K1,
100  typename BGridDesc_K0_N_K1,
101  typename CGridDesc_M_N,
102  typename C0GridDesc_M_N,
103  typename C1GridDesc_M_N,
104  typename AElementwiseOperation,
105  typename BElementwiseOperation,
106  typename CElementwiseOperation,
107  index_t MPerBlock,
108  index_t NPerBlock,
109  index_t K0PerBlock,
110  index_t MPerXdl,
111  index_t NPerXdl,
112  index_t K1Value,
113  index_t MXdlPerWave,
114  index_t NXdlPerWave,
115  typename ABlockTransferThreadClusterLengths_K0_M_K1,
116  typename ABlockTransferThreadClusterArrangeOrder,
117  typename ABlockTransferSrcAccessOrder,
118  index_t ABlockTransferSrcVectorDim,
119  index_t ABlockTransferSrcScalarPerVector,
120  index_t ABlockTransferDstScalarPerVector_K1,
121  bool AThreadTransferSrcResetCoordinateAfterRun,
122  bool ABlockLdsExtraM,
123  typename BBlockTransferThreadClusterLengths_K0_N_K1,
124  typename BBlockTransferThreadClusterArrangeOrder,
125  typename BBlockTransferSrcAccessOrder,
126  index_t BBlockTransferSrcVectorDim,
127  index_t BBlockTransferSrcScalarPerVector,
128  index_t BBlockTransferDstScalarPerVector_K1,
129  bool BThreadTransferSrcResetCoordinateAfterRun,
130  bool BBlockLdsExtraN,
131  index_t CShuffleMXdlPerWavePerShuffle,
132  index_t CShuffleNXdlPerWavePerShuffle,
133  typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
134  index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
135  index_t NumGemmKPrefetchStage = 1,
136  PipelineVersion PipelineVer = PipelineVersion::v1>
138 {
// Compile-time index constants Number<0> .. Number<7>. The member functions below pass
// them to GetLength(...) and operator[] to select a dimension or tuple element.
139  static constexpr auto I0 = Number<0>{};
140  static constexpr auto I1 = Number<1>{};
141  static constexpr auto I2 = Number<2>{};
142  static constexpr auto I3 = Number<3>{};
143  static constexpr auto I4 = Number<4>{};
144  static constexpr auto I5 = Number<5>{};
145  static constexpr auto I6 = Number<6>{};
146  static constexpr auto I7 = Number<7>{};
147 
148  // K1 should be Number<...>
    // K1 wraps the K1Value template parameter; it is also used as the LDS alignment
    // (max_lds_align) throughout this struct.
149  static constexpr auto K1 = Number<K1Value>{};
150 
152 
154  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
155 
// Returns the compile-time descriptor of the A tile as it is laid out in LDS, with
// lengths (K0PerBlock, MPerBlock, K1) and alignment max_lds_align = K1.
// ABlockLdsExtraM selects between two layouts.
// NOTE(review): the body of the ABlockLdsExtraM branch and the name of the descriptor
// factory called in the else branch are missing from this listing (presumably
// make_naive_tensor_descriptor_aligned, which the page index lists) - confirm against
// the real header.
156  __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
157  {
158  constexpr auto max_lds_align = K1;
159 
160  // A matrix in LDS memory, dst of blockwise copy
161  constexpr auto a_block_desc_k0_m_k1 = [&]() {
162  if constexpr(ABlockLdsExtraM)
163  {
167  }
168  else
169  {
171  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
172  }
173  }();
174 
175  return a_block_desc_k0_m_k1;
176  }
177 
// Returns the compile-time descriptor of the B tile as it is laid out in LDS, with
// lengths (K0PerBlock, NPerBlock, K1) and alignment max_lds_align = K1.
// BBlockLdsExtraN selects between two layouts.
// NOTE(review): the body of the BBlockLdsExtraN branch and the name of the descriptor
// factory called in the else branch are missing from this listing (presumably
// make_naive_tensor_descriptor_aligned, which the page index lists) - confirm against
// the real header.
178  __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
179  {
180  constexpr auto max_lds_align = K1;
181 
182  // B matrix in LDS memory, dst of blockwise copy
183  constexpr auto b_block_desc_k0_n_k1 = [&]() {
184  if constexpr(BBlockLdsExtraN)
185  {
189  }
190  else
191  {
193  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
194  }
195  }();
196 
197  return b_block_desc_k0_n_k1;
198  }
199 
// Returns the compile-time descriptor of the C tile staged in LDS during the output
// shuffle. MWave and NWave are derived from the block and per-wave tile sizes.
// NOTE(review): the function-name line and most of the make_tuple(...) arguments are
// missing from this listing. The page index gives the name as
// GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl;
// the first "NXdlPerWave" looks like a typo for "MXdlPerWave" (compare the local
// variable name below), but it is part of the public name.
200  __host__ __device__ static constexpr auto
202  {
203  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
204  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
205 
206  constexpr auto
207  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
209  make_tuple(I1,
212  I1,
215 
216  return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
217  }
218 
// Returns the number of LDS bytes one workgroup needs.
// The result is the larger of
//   (aligned A tile elements + aligned B tile elements) * sizeof(FloatAB)   and
//   C shuffle tile elements * sizeof(FloatC).
// A maximum rather than a sum is enough because Run() reinterprets the same p_shared
// pointer first as FloatAB (A and B tiles) and later as FloatC (C shuffle tile).
219  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
220  {
221  // LDS allocation for A and B: be careful of alignment
222  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
223 
224  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
225 
226  constexpr auto max_lds_align = K1;
227 
    // Round each tile's element count up to a multiple of max_lds_align.
228  constexpr auto a_block_space_size_aligned =
229  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
230 
231  constexpr auto b_block_space_size_aligned =
232  math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
233 
234  // LDS allocation for C shuffle in LDS
    // NOTE(review): the initializer line of this declaration is missing from the listing.
235  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
237 
238  constexpr auto c_block_size =
239  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
240  .GetElementSpaceSize();
241 
242  return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
243  sizeof(FloatAB),
244  c_block_size * sizeof(FloatC));
245  }
246 
247  // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// Checks whether this GridwiseGemm instantiation can handle the given problem.
// Returns true only when all of the following hold:
//   - the M/N lengths of the A and B descriptors match the C descriptor, K0 matches
//     between A and B, and the innermost length of both equals K1;
//   - M, N and K0 are exact multiples of MPerBlock, NPerBlock and K0PerBlock;
//   - GridwiseGemmPipe::IsSupported accepts the K-loop count K0 / K0PerBlock;
//   - block_2_ctile_map.CheckValidity accepts the C descriptor.
// Tile-shape template parameters are additionally checked at compile time.
248  template <typename Block2CTileMap>
249  __host__ __device__ static constexpr bool
250  CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
251  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
252  const CGridDesc_M_N& c_grid_desc_m_n,
253  const Block2CTileMap& block_2_ctile_map)
254  {
255  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
256  "wrong! K1 need to be known at compile-time");
257 
    // The block tile must be an exact multiple of the per-wave XDL tile in both M and N.
258  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
259  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
260  "Invalid tuning param!");
261 
262  const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
263  const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
264  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
265 
    // Descriptor shapes must agree with each other and with the compile-time K1.
266  if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
267  K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
268  K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
269  return false;
270 
    // No partial tiles: every problem dimension must divide evenly into block tiles.
271  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
272  return false;
273 
274  // check gridwise gemm pipeline
275  const auto num_k_loop = K0 / K0PerBlock;
276 
277  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
278  {
279  return false;
280  }
281 
282  if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
283  {
284  return false;
285  }
286 
287  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
288  return true;
289  }
290 
// Returns GridwiseGemmPipe::CalculateHasMainLoop(num_loop), where
// num_loop = K / (K0PerBlock * K1) using integer division.
// Note that K here is divided by K0PerBlock * K1, unlike the K0 / K0PerBlock
// count used in CheckValidity and Run.
291  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
292  {
293  const index_t num_loop = K / (K0PerBlock * K1);
294 
295  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
296  }
297 
// Builds the 6-d (MBlock, MXdlPerWave, MWaveMPerXdl, NBlock, NXdlPerWave, NWaveNPerXdl)
// view of an (M, N) grid descriptor (dimension names taken from the function name).
// MBlock and NBlock use integer division, so M and N are expected to be multiples of
// MPerBlock and NPerBlock (CheckValidity enforces this for the C descriptor).
// NOTE(review): the function-name line and the transform arguments are missing from
// this listing; the page index names the function
// MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl.
298  template <typename CGridDesc_M_N_>
299  __host__ __device__ static constexpr auto
301  const CGridDesc_M_N_& c_grid_desc_m_n)
302  {
303  const auto M = c_grid_desc_m_n.GetLength(I0);
304  const auto N = c_grid_desc_m_n.GetLength(I1);
305 
306  const auto MBlock = M / MPerBlock;
307  const auto NBlock = N / NPerBlock;
308 
309  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
310  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
311 
312  const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
314  c_grid_desc_m_n,
321 
322  return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
323  }
324 
325  // return block_id to C matrix tile idx (m0, n0) mapping
    // The two index_t parameters (M01, N01) are unnamed and therefore unused; only
    // c_grid_desc_m_n reaches the map that is returned.
    // NOTE(review): the line naming the map type being constructed is missing from this
    // listing.
326  __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
327  const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
328  {
330  c_grid_desc_m_n);
331  }
335  CGridDesc_M_N{}))>;
336 
340  C0GridDesc_M_N{}))>;
341 
345  C1GridDesc_M_N{}))>;
346 
348  remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
349 
// Device-side GEMM body executed by one workgroup. Phases visible in this listing:
//   1. Wrap the A, B, C, C0 and C1 global pointers in dynamic buffers sized by their
//      descriptors.
//   2. Map the 1-d block id to a C tile index through block_2_ctile_map and return
//      early when ValidCTileIndex rejects it.
//   3. Build the A and B blockwise copies (grid descriptor -> LDS block descriptor), the
//      blockwise GEMM object, and the A/B LDS buffers carved out of p_shared.
//   4. Run GridwiseGemmPipe::Run<HasMainKBlockLoop> for K0 / K0PerBlock iterations,
//      accumulating into c_thread_buf.
//   5. Shuffle the accumulators through LDS (p_shared reused as FloatC) in slices of
//      CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle and write each
//      slice out with c_block_copy_lds_to_global, which also reads the C0 and C1 grids
//      and is constructed with c_element_op.
//
// p_shared must provide at least GetSharedMemoryNumberOfByte() bytes.
// NOTE(review): several lines of this function (type names of the copy and GEMM
// objects, some parameter types, some transform arguments) are missing from this
// listing.
350  template <bool HasMainKBlockLoop, typename Block2CTileMap>
351  __device__ static void
352  Run(const FloatAB* __restrict__ p_a_grid,
353  const FloatAB* __restrict__ p_b_grid,
354  FloatC* __restrict__ p_c_grid,
355  const FloatC* __restrict__ p_c0_grid,
356  const FloatC* __restrict__ p_c1_grid,
357  void* __restrict__ p_shared,
358  const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
359  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
361  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
363  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
365  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
366  const AElementwiseOperation& a_element_op,
367  const BElementwiseOperation& b_element_op,
368  const CElementwiseOperation& c_element_op,
369  const Block2CTileMap& block_2_ctile_map)
370  {
    // Phase 1: global-memory views of the operands.
371  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
372  p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
373  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
374  p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
375  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
376  p_c_grid,
377  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
378  .GetElementSpaceSize());
379  auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
380  p_c0_grid,
381  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
382  .GetElementSpaceSize());
383  auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
384  p_c1_grid,
385  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
386  .GetElementSpaceSize());
387 
388  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
389 
390  // divide block work by [M, N]
391  const auto block_work_idx =
392  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
393 
    // Phase 2: blocks whose tile index falls outside (MBlock, NBlock) have nothing to do.
    // The index depends only on the block id, so the whole workgroup returns together.
394  if(!block_2_ctile_map.ValidCTileIndex(
395  block_work_idx,
396  make_tuple(
397  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
398  .GetLength(I0),
399  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
400  .GetLength(I3))))
401  {
402  return;
403  }
404 
405  // HACK: this force m/n_block_data_idx_on_grid into SGPR
406  const index_t m_block_data_idx_on_grid =
407  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
408 
409  const index_t n_block_data_idx_on_grid =
410  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
411 
412  // lds max alignment
413  constexpr auto max_lds_align = K1;
414 
415  // A matrix in LDS memory, dst of blockwise copy
416  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
417 
418  // B matrix in LDS memory, dst of blockwise copy
419  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
420 
421  // A matrix blockwise copy
    // Source window starts at (0, m_block_data_idx_on_grid, 0) in the A grid; the
    // destination window starts at the origin of the A LDS tile.
422  auto a_blockwise_copy =
424  AElementwiseOperation,
428  ABlockTransferThreadClusterLengths_K0_M_K1,
429  ABlockTransferThreadClusterArrangeOrder,
430  FloatAB,
431  FloatAB,
432  decltype(a_grid_desc_k0_m_k1),
433  decltype(a_block_desc_k0_m_k1),
434  ABlockTransferSrcAccessOrder,
436  ABlockTransferSrcVectorDim,
437  2,
438  ABlockTransferSrcScalarPerVector,
439  ABlockTransferDstScalarPerVector_K1,
440  1,
441  1,
442  AThreadTransferSrcResetCoordinateAfterRun,
443  true>(
444  a_grid_desc_k0_m_k1,
445  make_multi_index(0, m_block_data_idx_on_grid, 0),
446  a_element_op,
447  a_block_desc_k0_m_k1,
448  make_multi_index(0, 0, 0),
450 
451  // B matrix blockwise copy
    // Same arrangement as A, with the source window at (0, n_block_data_idx_on_grid, 0).
452  auto b_blockwise_copy =
454  BElementwiseOperation,
458  BBlockTransferThreadClusterLengths_K0_N_K1,
459  BBlockTransferThreadClusterArrangeOrder,
460  FloatAB,
461  FloatAB,
462  decltype(b_grid_desc_k0_n_k1),
463  decltype(b_block_desc_k0_n_k1),
464  BBlockTransferSrcAccessOrder,
466  BBlockTransferSrcVectorDim,
467  2,
468  BBlockTransferSrcScalarPerVector,
469  BBlockTransferDstScalarPerVector_K1,
470  1,
471  1,
472  BThreadTransferSrcResetCoordinateAfterRun,
473  true>(
474  b_grid_desc_k0_n_k1,
475  make_multi_index(0, n_block_data_idx_on_grid, 0),
476  b_element_op,
477  b_block_desc_k0_n_k1,
478  make_multi_index(0, 0, 0),
480 
481  // GEMM definition
482  // c_mtx += transpose(a_mtx) * b_mtx
483  // a_mtx[K0PerBlock, MPerBlock] is in LDS
484  // b_mtx[K0PerBlock, NPerBlock] is in LDS
485  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
486  // register
487  // sanity check
488 
489  auto blockwise_gemm =
491  FloatAB,
492  FloatAB,
493  FloatAcc,
494  decltype(a_block_desc_k0_m_k1),
495  decltype(b_block_desc_k0_n_k1),
496  MPerXdl,
497  NPerXdl,
498  MXdlPerWave,
499  NXdlPerWave,
500  K1>{};
501 
502  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
503 
504  // LDS allocation for A and B: be careful of alignment
505  constexpr auto a_block_space_size_aligned =
506  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
507 
    // A occupies the start of p_shared; B follows at the aligned end of A.
508  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
509  static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
510 
511  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
512  static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
513  b_block_desc_k0_n_k1.GetElementSpaceSize());
514 
    // Each K iteration advances both source windows by K0PerBlock along K0.
515  constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
516  constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
517 
518  // gridwise GEMM pipeline
519  const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
520 
521  GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
522  a_block_desc_k0_m_k1,
523  a_blockwise_copy,
524  a_grid_buf,
525  a_block_buf,
526  a_block_slice_copy_step,
527  b_grid_desc_k0_n_k1,
528  b_block_desc_k0_n_k1,
529  b_blockwise_copy,
530  b_grid_buf,
531  b_block_buf,
532  b_block_slice_copy_step,
533  blockwise_gemm,
534  c_thread_buf,
535  K0BlockMainLoop);
536 
537  // shuffle C and write out
538  {
539  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
540  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
541  "wrong!");
542 
543  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
544  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
545 
546  // TODO: hacky, fix it!
547  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
548  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
549 
550  // TODO: hacky, fix it!
551  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
552  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
553  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
554 
555  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
556  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
557  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
558  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
559  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
560  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
561  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
562  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
563 
564  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
566 
    // The C shuffle tile aliases the start of p_shared, i.e. the same LDS that held
    // the A and B tiles during the main loop.
567  auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
568  static_cast<FloatC*>(p_shared),
569  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
570  .GetElementSpaceSize());
571 
572  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
573  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
574  make_tuple(make_freeze_transform(I0), // freeze mblock
576  Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per
577  // shuffle
579  make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
580  make_freeze_transform(I0), // freeze nblock
582  Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per
583  // shuffle
585  make_tuple(N1, N2))), // presumably N1 = NWave, N2 = NPerXdl (the previous note here was a copy of the M-dimension one)
587  Sequence<1>{},
588  Sequence<2>{},
589  Sequence<3>{},
590  Sequence<4>{},
591  Sequence<5>{}),
593  Sequence<0>{},
595  Sequence<>{},
596  Sequence<1>{},
597  Sequence<3, 7>{})
598 
599  );
600 
601  // calculate origin of thread output tensor on global memory
602  // blockwise GEMM c matrix starting index
603  const auto c_thread_mtx_on_block =
604  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
605 
606  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
607  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
608 
    // Decompose the flat per-thread M and N offsets into (M0..M4) and (N0..N2)
    // coordinates by inverting a merge transform.
609  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
611  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
614 
615  const auto m_thread_data_on_block_idx =
616  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
617  make_multi_index(m_thread_data_on_block));
618 
619  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
624 
625  const auto n_thread_data_on_block_idx =
626  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
627  make_multi_index(n_thread_data_on_block));
628 
629  // VGPR to LDS
630  auto c_thread_copy_vgpr_to_lds =
632  FloatC,
633  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
634  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
636  Sequence<CShuffleMXdlPerWavePerShuffle,
637  CShuffleNXdlPerWavePerShuffle,
638  I1,
639  I1,
640  M2,
641  I1,
642  M4,
643  I1>,
645  7,
646  1,
648  1,
649  true>{
650  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
652  0,
653  m_thread_data_on_block_idx[I1],
654  n_thread_data_on_block_idx[I1],
655  m_thread_data_on_block_idx[I2],
656  m_thread_data_on_block_idx[I3],
657  m_thread_data_on_block_idx[I4],
658  n_thread_data_on_block_idx[I2]),
660 
    // Three-source transfer: Src0 is the LDS shuffle tile, Src1/Src2 are the C0/C1
    // grids, Dst is the C grid. All three grid windows start at this block's
    // (MBlock, NBlock) tile.
661  auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3<
662  ThisThreadBlock, // ThreadGroup
663  CElementwiseOperation, // ElementwiseOperation,
664  CGlobalMemoryDataOperation, // DstInMemOp,
665  Sequence<1,
666  CShuffleMXdlPerWavePerShuffle,
667  MWave * MPerXdl,
668  1,
669  CShuffleNXdlPerWavePerShuffle,
670  NWave * NPerXdl>, // BlockSliceLengths,
671  CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
672  Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
673  FloatC, // typename Src0Data,
674  FloatC, // typename Src1Data,
675  FloatC, // typename Src2Data,
676  FloatC, // typename DstData,
677  decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
678  decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
679  decltype(c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
680  decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
681  Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
682  5, // index_t VectorDim,
683  CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
684  true, // bool ThreadTransferSrc0ResetCoordinateAfterRun,
685  false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
686  false, // bool ThreadTransferSrc2ResetCoordinateAfterRun,
687  false> // bool ThreadTransferDstResetCoordinateAfterRun>
688  {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
689  make_multi_index(0, 0, 0, 0, 0, 0),
690  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
691  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
692  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
693  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
694  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
695  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
696  c_element_op};
697 
    // Window steps along the MXdlPerWave (dim 1) and NXdlPerWave (dim 4) dimensions.
698  constexpr auto mxdlperwave_forward_step =
699  make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0);
700  constexpr auto nxdlperwave_forward_step =
701  make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0);
702  constexpr auto nxdlperwave_backward_step =
703  make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0);
704 
    // The N index sweeps forward on every other M step and backward on the steps in
    // between, so the Src1/Src2/Dst windows only ever move by one step at a time and
    // never need to be rewound along N. Only Src1, Src2 and Dst windows are moved
    // explicitly; Src0 is declared with reset-after-run = true above.
705  static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) {
706  constexpr auto mxdlperwave = mxdlperwave_iter;
707 
708  static_for<0,
709  NXdlPerWave,
710  CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) {
711  constexpr bool nxdlperwave_forward_sweep =
712  (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0);
713 
    // On a backward sweep the loop counter is mirrored to get the actual N index.
714  constexpr index_t nxdlperwave_value =
715  nxdlperwave_forward_sweep
716  ? nxdlperwave_iter
717  : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle);
718 
719  constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
720 
721  // make sure it's safe to do ds_write
722  block_sync_lds();
723 
724  // VGPR to LDS
725  c_thread_copy_vgpr_to_lds.Run(
726  c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
727  make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
728  c_thread_buf,
729  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
730  c_block_buf);
731 
732  // make sure it's safe to do ds_read
733  block_sync_lds();
734 
735  // LDS to global
736  c_block_copy_lds_to_global.Run(
737  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
738  c_block_buf,
739  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
740  c0_grid_buf,
741  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
742  c1_grid_buf,
743  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
744  c_grid_buf);
745 
746  // move on nxdlperwave dimension
    // No move is issued after the last slice of a sweep (forward or backward).
747  if constexpr(nxdlperwave_forward_sweep &&
748  (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle))
749  {
750  c_block_copy_lds_to_global.MoveSrc1SliceWindow(
751  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
752  nxdlperwave_forward_step);
753 
754  c_block_copy_lds_to_global.MoveSrc2SliceWindow(
755  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
756  nxdlperwave_forward_step);
757 
758  c_block_copy_lds_to_global.MoveDstSliceWindow(
759  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
760  nxdlperwave_forward_step);
761  }
762  else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
763  {
764  c_block_copy_lds_to_global.MoveSrc1SliceWindow(
765  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
766  nxdlperwave_backward_step);
767 
768  c_block_copy_lds_to_global.MoveSrc2SliceWindow(
769  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
770  nxdlperwave_backward_step);
771 
772  c_block_copy_lds_to_global.MoveDstSliceWindow(
773  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
774  nxdlperwave_backward_step);
775  }
776  });
777 
778  // move on mxdlperwave dimension
    // Skipped after the last M slice; M only ever moves forward.
779  if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle)
780  {
781  c_block_copy_lds_to_global.MoveSrc1SliceWindow(
782  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
783  mxdlperwave_forward_step);
784 
785  c_block_copy_lds_to_global.MoveSrc2SliceWindow(
786  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
787  mxdlperwave_forward_step);
788 
789  c_block_copy_lds_to_global.MoveDstSliceWindow(
790  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
791  mxdlperwave_forward_step);
792  }
793  });
794  }
795  }
796 };
797 
798 } // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:30
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition: ck.hpp:276
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition: tensor_descriptor_helper.hpp:132
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:58
__global__ void kernel_gemm_xdlops_v3r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC *__restrict__ p_c0_grid, const FloatC *__restrict__ p_c1_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r3.hpp:37
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:18
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:295
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: block_to_ctile_map.hpp:260
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
Definition: gridwise_gemm_xdlops_v3r3.hpp:138
static constexpr auto I3
Definition: gridwise_gemm_xdlops_v3r3.hpp:142
static __device__ void Run(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC *__restrict__ p_c0_grid, const FloatC *__restrict__ p_c1_grid, void *__restrict__ p_shared, const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r3.hpp:352
static constexpr auto I0
Definition: gridwise_gemm_xdlops_v3r3.hpp:139
__host__ static constexpr __device__ auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition: gridwise_gemm_xdlops_v3r3.hpp:326
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(const CGridDesc_M_N_ &c_grid_desc_m_n)
Definition: gridwise_gemm_xdlops_v3r3.hpp:300
static constexpr auto I2
Definition: gridwise_gemm_xdlops_v3r3.hpp:141
static constexpr auto I5
Definition: gridwise_gemm_xdlops_v3r3.hpp:144
__host__ static constexpr __device__ auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
Definition: gridwise_gemm_xdlops_v3r3.hpp:178
static constexpr auto I7
Definition: gridwise_gemm_xdlops_v3r3.hpp:146
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition: gridwise_gemm_xdlops_v3r3.hpp:348
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdlops_v3r3.hpp:291
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_v3r3.hpp:219
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
Definition: gridwise_gemm_xdlops_v3r3.hpp:335
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(C1GridDesc_M_N{}))> C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
Definition: gridwise_gemm_xdlops_v3r3.hpp:345
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(C0GridDesc_M_N{}))> C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
Definition: gridwise_gemm_xdlops_v3r3.hpp:340
__host__ static constexpr __device__ auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
Definition: gridwise_gemm_xdlops_v3r3.hpp:156
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVer, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition: gridwise_gemm_xdlops_v3r3.hpp:154
__host__ static constexpr __device__ auto GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl()
Definition: gridwise_gemm_xdlops_v3r3.hpp:201
static constexpr auto I4
Definition: gridwise_gemm_xdlops_v3r3.hpp:143
static constexpr auto I6
Definition: gridwise_gemm_xdlops_v3r3.hpp:145
static constexpr auto K1
Definition: gridwise_gemm_xdlops_v3r3.hpp:149
static constexpr auto I1
Definition: gridwise_gemm_xdlops_v3r3.hpp:140
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r3.hpp:250
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdlops_v3r3.hpp:151
Definition: sequence.hpp:43
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v6r3.hpp:40
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:20
Definition: is_known_at_compile_time.hpp:14
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:334