/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp Source File
gridwise_gemm_xdlops_v3r1.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
18 
19 namespace ck {
20 
// kernel_gemm_xdlops_v3r1 -- __global__ entry point for the v3r1 gridwise XDL GEMM.
// (The kernel name comes from the Doxygen index at the end of this page; the declarator line,
// source line 36, and the launch-bounds line, source line 34, are not visible in this listing.
// NOTE(review): the hidden line under CK_USE_LAUNCH_BOUNDS presumably uses
// CK_MAX_THREAD_PER_BLOCK / CK_MIN_BLOCK_PER_CU -- confirm against the full header.)
//
// Every argument is forwarded unchanged to GridwiseGemm::Run<HasMainK0BlockLoop>.
//  - When compiled for gfx908, gfx90a or gfx94*, the body declares a static __shared__ (LDS)
//    byte array sized by GridwiseGemm::GetSharedMemoryNumberOfByte() and passes it to Run()
//    as the workspace for the A/B tiles and the C shuffle.
//  - For every other compilation target the body only assigns each argument to `ignore`
//    (silencing unused-parameter warnings); nothing is computed or written.
21 template <typename GridwiseGemm,
22  typename FloatAB,
23  typename FloatC,
24  typename AGridDesc_AK0_M_AK1,
25  typename BGridDesc_BK0_N_BK1,
26  typename CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
27  typename AElementwiseOperation,
28  typename BElementwiseOperation,
29  typename CElementwiseOperation,
30  typename Block2CTileMap,
// HasMainK0BlockLoop: compile-time flag forwarded to GridwiseGemm::Run; presumably the host
// computes it with GridwiseGemm::CalculateHasMainKBlockLoop(K) -- confirm at launch sites.
31  bool HasMainK0BlockLoop>
32 __global__ void
33 #if CK_USE_LAUNCH_BOUNDS
35 #endif
37  const FloatAB* __restrict__ p_a_grid,
38  const FloatAB* __restrict__ p_b_grid,
39  FloatC* __restrict__ p_c_grid,
40  const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
41  const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
42  const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
43  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
44  const AElementwiseOperation a_element_op,
45  const BElementwiseOperation b_element_op,
46  const CElementwiseOperation c_element_op,
47  const Block2CTileMap block_2_ctile_map)
48 {
49 #if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
// One LDS buffer per workgroup; its size is a compile-time constant from the gridwise GEMM.
50  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
51 
52  GridwiseGemm::template Run<HasMainK0BlockLoop>(
53  p_a_grid,
54  p_b_grid,
55  p_c_grid,
56  p_shared,
57  a_grid_desc_ak0_m_ak1,
58  b_grid_desc_bk0_n_bk1,
59  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
60  a_element_op,
61  b_element_op,
62  c_element_op,
63  block_2_ctile_map);
64 #else
// Unsupported target: keep the kernel symbol but do no work; only mark arguments as used.
65  ignore = p_a_grid;
66  ignore = p_b_grid;
67  ignore = p_c_grid;
68  ignore = a_grid_desc_ak0_m_ak1;
69  ignore = b_grid_desc_bk0_n_bk1;
70  ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
71  ignore = a_element_op;
72  ignore = b_element_op;
73  ignore = c_element_op;
74  ignore = block_2_ctile_map;
75 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
76 }
77 
// Gridwise GEMM using AMD XDL (MFMA) instructions with an LDS "C shuffle" epilogue (v3r1).
// NOTE: the line carrying the struct name (source line 122) is not visible in this listing.
//
// One workgroup (ThisThreadBlock<BlockSize>) computes one MPerBlock x NPerBlock tile of C,
// following the definition spelled out in Run(): c_mtx += transpose(a_mtx) * b_mtx, with A
// described as [AK0, M, AK1] and B as [BK0, N, BK1] (K = K0 * K1 on each side).
// Run() proceeds in three phases:
//   1. blockwise copies stage A/B K-slices from global memory into LDS (the A/B elementwise
//      ops are handed to these copies);
//   2. GridwiseGemmPipe (chosen from PipelineVer / NumGemmKPrefetchStage) loops over K while a
//      blockwise XDL GEMM accumulates FloatAcc results in per-thread registers;
//   3. the accumulators are written back through LDS -- reusing the A/B LDS space -- in chunks
//      of CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle XDL tiles, and each
//      chunk is copied to global memory with CGlobalMemoryDataOperation and the C elementwise op.
78 template <
79  index_t BlockSize,
// Element types: A/B inputs, accumulator, LDS shuffle staging, C output.
80  typename FloatAB,
81  typename FloatAcc,
82  typename FloatCShuffle,
83  typename FloatC,
84  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
// Global-memory descriptors of A, B and C, and the elementwise operations for each.
85  typename AGridDesc_AK0_M_AK1,
86  typename BGridDesc_BK0_N_BK1,
87  typename CGridDesc_M_N,
88  typename AElementwiseOperation,
89  typename BElementwiseOperation,
90  typename CElementwiseOperation,
// Per-workgroup tile sizes, K1 vector widths, and the XDL tile shape / count per wave.
91  index_t MPerBlock,
92  index_t NPerBlock,
93  index_t KPerBlock,
94  index_t AK1Value,
95  index_t BK1Value,
96  index_t MPerXdl,
97  index_t NPerXdl,
98  index_t MXdlPerWave,
99  index_t NXdlPerWave,
// A global->LDS blockwise copy tuning.
100  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
101  typename ABlockTransferThreadClusterArrangeOrder,
102  typename ABlockTransferSrcAccessOrder,
103  index_t ABlockTransferSrcVectorDim,
104  index_t ABlockTransferSrcScalarPerVector,
105  index_t ABlockTransferDstScalarPerVector_K1,
106  bool AThreadTransferSrcResetCoordinateAfterRun,
107  bool ABlockLdsExtraM,
// B global->LDS blockwise copy tuning.
108  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
109  typename BBlockTransferThreadClusterArrangeOrder,
110  typename BBlockTransferSrcAccessOrder,
111  index_t BBlockTransferSrcVectorDim,
112  index_t BBlockTransferSrcScalarPerVector,
113  index_t BBlockTransferDstScalarPerVector_K1,
114  bool BThreadTransferSrcResetCoordinateAfterRun,
115  bool BBlockLdsExtraN,
// C-shuffle epilogue tuning: XDL tiles per wave written per shuffle step, and LDS->global copy.
116  index_t CShuffleMXdlPerWavePerShuffle,
117  index_t CShuffleNXdlPerWavePerShuffle,
118  typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
119  index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
// K-loop pipeline selection (fed to GridwiseGemmPipeline_Selector).
120  index_t NumGemmKPrefetchStage = 1,
121  PipelineVersion PipelineVer = PipelineVersion::v1>
123 {
// Compile-time index constants used to address descriptor dimensions.
124  static constexpr auto I0 = Number<0>{};
125  static constexpr auto I1 = Number<1>{};
126  static constexpr auto I2 = Number<2>{};
127  static constexpr auto I3 = Number<3>{};
128  static constexpr auto I4 = Number<4>{};
129  static constexpr auto I5 = Number<5>{};
130  static constexpr auto I6 = Number<6>{};
131  static constexpr auto I7 = Number<7>{};
132 
133  // K1 should be Number<...>
// NOTE(review): AK0/BK0 use integer division, so KPerBlock is assumed to be a multiple of both
// AK1Value and BK1Value; nothing in the visible code asserts this.
134  static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
135  static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
136  static constexpr auto AK1 = Number<AK1Value>{};
137  static constexpr auto BK1 = Number<BK1Value>{};
138 
// The declaration lines (source lines 139 and 141) are hidden in this listing. Per the Doxygen
// index they define ThisThreadBlock = ThisThreadBlock<BlockSize> and GridwiseGemmPipe, the
// pipeline type returned by GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>().
140 
142  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
143 
// Compile-time LDS descriptor for one A tile, shaped [AK0, MPerBlock, AK1].
// Without ABlockLdsExtraM it is built from make_tuple(AK0, MPerBlock, AK1) with alignment AK1
// (the call's first line, source line 158, is hidden; presumably
// make_naive_tensor_descriptor_aligned -- confirm).
144  __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
145  {
146  constexpr auto max_lds_align = AK1;
147 
148  // A matrix in LDS memory, dst of blockwise copy
149  constexpr auto a_block_desc_ak0_m_ak1 = [&]() {
150  if constexpr(ABlockLdsExtraM)
151  {
// Body (source lines 152-154) hidden in this listing; presumably pads the M stride to
// reduce LDS bank conflicts -- confirm against the full header.
155  }
156  else
157  {
159  make_tuple(AK0, Number<MPerBlock>{}, AK1), max_lds_align);
160  }
161  }();
162 
163  return a_block_desc_ak0_m_ak1;
164  }
165 
// Compile-time LDS descriptor for one B tile, shaped [BK0, NPerBlock, BK1]; mirrors the A
// version above, with BBlockLdsExtraN controlling the (hidden, source lines 174-176) padded
// variant and alignment BK1 otherwise.
166  __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
167  {
168  constexpr auto max_lds_align = BK1;
169 
170  // B matrix in LDS memory, dst of blockwise copy
171  constexpr auto b_block_desc_bk0_n_bk1 = [&]() {
172  if constexpr(BBlockLdsExtraN)
173  {
177  }
178  else
179  {
181  make_tuple(BK0, Number<NPerBlock>{}, BK1), max_lds_align);
182  }
183  }();
184 
185  return b_block_desc_bk0_n_bk1;
186  }
187 
// GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl (name per
// the Doxygen index; the declarator, source line 189, is hidden). The second component reads
// "NXdlPerWave" although the dimension is MXdlPerWave; the name is kept as-is for callers.
// Returns the 6-D LDS descriptor of the C-shuffle staging tile with MBlock = NBlock = 1 (the
// visible I1 entries). The remaining lengths (source lines 198-199, 201-202) are hidden;
// presumably CShuffleMXdlPerWavePerShuffle, MWave * MPerXdl, CShuffleNXdlPerWavePerShuffle and
// NWave * NPerXdl, matching the slice lengths of the LDS-to-global copy in Run().
188  __host__ __device__ static constexpr auto
190  {
191  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
192  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
193 
194  constexpr auto
195  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
197  make_tuple(I1,
200  I1,
203 
204  return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
205  }
206 
// Number of LDS bytes the kernel must allocate. It is the larger of two quantities, because
// Run() places both at the start of the same p_shared buffer:
//   - (A tile + B tile) * sizeof(FloatAB), used during the K loop;
//   - C-shuffle tile * sizeof(FloatCShuffle), used by the epilogue.
// NOTE(review): here the A tile is padded to a multiple of AK1 (and B to BK1), whereas Run()
// offsets the B tile by the A size padded to lcm(AK1, BK1). If those paddings ever differ, the
// B tile used in Run() would extend past the size computed here. Confirm this cannot happen
// for the instantiated AK1/BK1/KPerBlock combinations, or align A to lcm(AK1, BK1) here too.
207  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
208  {
209  // LDS allocation for A and B: be careful of alignment
210  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
211 
212  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
213 
214  constexpr auto a_block_space_size_aligned =
215  math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), AK1);
216 
217  constexpr auto b_block_space_size_aligned =
218  math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), BK1);
219 
220  // LDS allocation for C shuffle in LDS
221  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
223 
224  constexpr auto c_block_size =
225  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
226  .GetElementSpaceSize();
227 
228  return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
229  sizeof(FloatAB),
230  c_block_size * sizeof(FloatCShuffle));
231  }
232 
// Checks whether this kernel configuration can run the given problem. Returns false when any
// of the following holds:
//   - M (A dim 1) or N (B dim 1) disagrees with C's lengths;
//   - M, N or K is not a multiple of MPerBlock / NPerBlock / KPerBlock;
//   - GridwiseGemmPipe does not support K / KPerBlock loop iterations;
//   - block_2_ctile_map rejects c_grid_desc_m_n.
// K is taken from the A descriptor only (AK0 * AK1); B's K extent is not compared against it.
// The static_assert rejects block tiles that are not multiples of one wave's XDL footprint.
// NOTE(review): MakeDefaultBlock2CTileMap in this file ignores M01/N01, so the comment below
// may be stale.
233  // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
234  template <typename Block2CTileMap>
235  __host__ __device__ static constexpr bool
236  CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
237  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
238  const CGridDesc_M_N& c_grid_desc_m_n,
239  const Block2CTileMap& block_2_ctile_map)
240  {
241  // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
242  // is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
243  // "wrong! K1 need to be known at compile-time");
244 
245  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
246  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
247  "Invalid tuning param!");
248 
249  const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
250  const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
251  const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
252 
253  if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
254  return false;
255 
256  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
257  return false;
258 
259  // check gridwise gemm pipeline
260  const auto num_k_loop = K / KPerBlock;
261 
262  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
263  {
264  return false;
265  }
266 
267  if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
268  {
269  return false;
270  }
271 
272  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
273  return true;
274  }
275 
// Returns GridwiseGemmPipe::CalculateHasMainLoop(K / KPerBlock), i.e. whether the pipeline's
// main K loop is needed for total reduction length K. The division truncates; CheckValidity
// requires K % KPerBlock == 0.
276  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
277  {
278  const index_t num_loop = K / KPerBlock;
279 
280  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
281  }
282 
// MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl (name per
// the Doxygen index; the declarator, source line 284, is hidden).
// Re-expresses the 2-D C[M, N] global descriptor as the 6-D view used by the epilogue's
// LDS-to-global copy. MBlock = M / MPerBlock and NBlock = N / NPerBlock truncate, so M and N
// are expected to be tile multiples (enforced by CheckValidity). The transform arguments
// (source lines 297, 299-304) are hidden. Presumably M is unmerged into
// (MBlock, MXdlPerWave, MWave * MPerXdl) and N into (NBlock, NXdlPerWave, NWave * NPerXdl)
// -- confirm against the full header.
283  __host__ __device__ static constexpr auto
285  const CGridDesc_M_N& c_grid_desc_m_n)
286  {
287  const auto M = c_grid_desc_m_n.GetLength(I0);
288  const auto N = c_grid_desc_m_n.GetLength(I1);
289 
290  const auto MBlock = M / MPerBlock;
291  const auto NBlock = N / NPerBlock;
292 
293  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
294  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
295 
296  const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
298  c_grid_desc_m_n,
305 
306  return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
307  }
308 
// Builds the default block-id -> C-tile map from c_grid_desc_m_n alone. The M01 / N01
// parameters are unnamed and unused; they remain only for interface compatibility. The map
// type is named on source line 313, which is hidden in this listing.
309  // return block_id to C matrix tile idx (m0, n0) mapping
310  __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
311  const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
312  {
314  c_grid_desc_m_n);
315  }
// Type aliases whose leading lines (source lines 316-318 and 321) are hidden here:
//   CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl -- the
//     result type of MakeCGridDescriptor_...(CGridDesc_M_N{});
//   DefaultBlock2CTileMap -- the result type of MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1).
319  CGridDesc_M_N{}))>;
320 
322  remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
323 
// Device-side GEMM body for one workgroup.
//  p_a_grid / p_b_grid / p_c_grid: global A, B, C buffers, sized from the matching descriptors.
//  p_shared: LDS workspace of at least GetSharedMemoryNumberOfByte() bytes. It holds the A and
//            B tiles during the K loop and is then reused as the FloatCShuffle staging buffer
//            for the epilogue.
//  HasMainK0BlockLoop: forwarded to GridwiseGemmPipe::Run.
//  block_2_ctile_map: maps get_block_1d_id() to this workgroup's (m, n) C-tile index;
//            workgroups whose tile index is invalid return immediately.
// (The declaration line of the C grid descriptor parameter, source line 332, is hidden.)
324  template <bool HasMainK0BlockLoop, typename Block2CTileMap>
325  __device__ static void
326  Run(const FloatAB* __restrict__ p_a_grid,
327  const FloatAB* __restrict__ p_b_grid,
328  FloatC* __restrict__ p_c_grid,
329  void* __restrict__ p_shared,
330  const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
331  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
333  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
334  const AElementwiseOperation& a_element_op,
335  const BElementwiseOperation& b_element_op,
336  const CElementwiseOperation& c_element_op,
337  const Block2CTileMap& block_2_ctile_map)
338  {
// Global-memory buffer views over the raw pointers, each sized by its descriptor's element space.
339  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
340  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
341  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
342  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
343  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
344  p_c_grid,
345  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
346  .GetElementSpaceSize());
347 
348  // divide block work by [M, N]
349  const auto block_work_idx =
350  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
351 
// Compare against (MBlock, NBlock) = dims 0 and 3 of the 6-D C grid descriptor. The test
// depends only on the block id, so the whole workgroup returns together; no thread is left
// waiting at a later block_sync_lds().
352  if(!block_2_ctile_map.ValidCTileIndex(
353  block_work_idx,
354  make_tuple(
355  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
356  .GetLength(I0),
357  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
358  .GetLength(I3))))
359  {
360  return;
361  }
362 
363  // HACK: this force m/n_block_data_idx_on_grid into SGPR
364  const index_t m_block_data_idx_on_grid =
365  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
366 
367  const index_t n_block_data_idx_on_grid =
368  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
369 
370  // lds max alignment
371  constexpr auto max_lds_align = math::lcm(AK1, BK1);
372 
373  // A matrix in LDS memory, dst of blockwise copy
374  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
375 
376  // B matrix in LDS memory, dst of blockwise copy
377  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
378 
// Global->LDS copy of A: the source window starts at (0, m_block_data_idx_on_grid, 0) in the
// grid, and the destination at (0, 0, 0) in the LDS tile. Several template-argument lines of
// the copy type (source lines 381, 383-385, 393, 408) are hidden in this listing.
379  // A matrix blockwise copy
380  auto a_blockwise_copy =
382  AElementwiseOperation,
386  ABlockTransferThreadClusterLengths_AK0_M_AK1,
387  ABlockTransferThreadClusterArrangeOrder,
388  FloatAB,
389  FloatAB,
390  decltype(a_grid_desc_ak0_m_ak1),
391  decltype(a_block_desc_ak0_m_ak1),
392  ABlockTransferSrcAccessOrder,
394  ABlockTransferSrcVectorDim,
395  2,
396  ABlockTransferSrcScalarPerVector,
397  ABlockTransferDstScalarPerVector_K1,
398  1,
399  1,
400  AThreadTransferSrcResetCoordinateAfterRun,
401  true,
402  NumGemmKPrefetchStage>(
403  a_grid_desc_ak0_m_ak1,
404  make_multi_index(0, m_block_data_idx_on_grid, 0),
405  a_element_op,
406  a_block_desc_ak0_m_ak1,
407  make_multi_index(0, 0, 0),
409 
// Global->LDS copy of B, mirroring A with the window at (0, n_block_data_idx_on_grid, 0).
410  // B matrix blockwise copy
411  auto b_blockwise_copy =
413  BElementwiseOperation,
417  BBlockTransferThreadClusterLengths_BK0_N_BK1,
418  BBlockTransferThreadClusterArrangeOrder,
419  FloatAB,
420  FloatAB,
421  decltype(b_grid_desc_bk0_n_bk1),
422  decltype(b_block_desc_bk0_n_bk1),
423  BBlockTransferSrcAccessOrder,
425  BBlockTransferSrcVectorDim,
426  2,
427  BBlockTransferSrcScalarPerVector,
428  BBlockTransferDstScalarPerVector_K1,
429  1,
430  1,
431  BThreadTransferSrcResetCoordinateAfterRun,
432  true,
433  NumGemmKPrefetchStage>(
434  b_grid_desc_bk0_n_bk1,
435  make_multi_index(0, n_block_data_idx_on_grid, 0),
436  b_element_op,
437  b_block_desc_bk0_n_bk1,
438  make_multi_index(0, 0, 0),
440 
441  // GEMM definition
442  // c_mtx += transpose(a_mtx) * b_mtx
443  // a_mtx[K0PerBlock, MPerBlock] is in LDS
444  // b_mtx[K0PerBlock, NPerBlock] is in LDS
445  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
446  // register
447  // sanity check
// k_pack (presumably the K packing width of each blockwise-GEMM step) is
// max(lcm(AK1, BK1), selected_mfma.k_per_blk). The line declaring selected_mfma (source line
// 460) and parts of the is_single_rate_mfma condition (source lines 450, 453) are hidden.
// NOTE(review): is_single_rate_mfma / is_scale_mfma appear to feed that hidden MFMA selection.
448  constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
449  constexpr bool is_single_rate_mfma =
451  lcm_AK1_BK1 <= 4) ||
452  (is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
454  lcm_AK1_BK1 < 32))
455  ? true
456  : false;
457  constexpr auto is_scale_mfma = false;
458  constexpr index_t k_pack = math::max(
459  lcm_AK1_BK1,
461  selected_mfma.k_per_blk);
462 
463  auto blockwise_gemm =
465  FloatAB,
466  FloatAB,
467  FloatAcc,
468  decltype(a_block_desc_ak0_m_ak1),
469  decltype(b_block_desc_bk0_n_bk1),
470  MPerXdl,
471  NPerXdl,
472  MXdlPerWave,
473  NXdlPerWave,
474  k_pack>{};
475 
// Per-thread accumulator storage owned by the blockwise GEMM.
476  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
477 
// LDS layout: A tile at the start of p_shared, then B at an offset of the A element space
// rounded up to lcm(AK1, BK1).
478  // LDS allocation for A and B: be careful of alignment
479  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
480  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
481 
482  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
483  static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
484 
485  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
486  static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
487  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
488 
// Each K iteration advances the source window along dim 0 (K0) by KPerBlock / K1.
489  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
490  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
491 
// Number of K-block iterations = (AK0 * AK1 of the A grid) / KPerBlock; readfirstlane makes the
// value wave-uniform, as with the block offsets above.
492  // gridwise GEMM pipeline
493  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
494  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
495  KPerBlock);
496 
497  GridwiseGemmPipe::template Run<HasMainK0BlockLoop>(a_grid_desc_ak0_m_ak1,
498  a_block_desc_ak0_m_ak1,
499  a_blockwise_copy,
500  a_grid_buf,
501  a_block_buf,
502  a_block_slice_copy_step,
503  b_grid_desc_bk0_n_bk1,
504  b_block_desc_bk0_n_bk1,
505  b_blockwise_copy,
506  b_grid_buf,
507  b_block_buf,
508  b_block_slice_copy_step,
509  blockwise_gemm,
510  c_thread_buf,
511  num_k_block_main_loop);
512 
// Epilogue: stage the accumulators through LDS chunk by chunk and write them to global memory.
513  // shuffle C and write out
514  {
515  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
516  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
517  "wrong!");
518 
519  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
520  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
521 
522  // TODO: hacky, fix it!
523  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
524  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
525 
526  // TODO: hacky, fix it!
527  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
528  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
529  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
530 
531  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
532  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
533  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
534  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
535  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
536  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
537  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
538  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
539 
540  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
542 
// The C staging buffer starts at p_shared, i.e. it aliases the A/B tiles used above.
543  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
544  static_cast<FloatCShuffle*>(p_shared),
545  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
546  .GetElementSpaceSize());
547 
// 8-D view of the 6-D LDS staging tile that matches the per-thread accumulator layout
// (M0, N0, M1, N1, M2, M3, M4, N2), so the VGPR->LDS copy can address it directly.
548  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
549  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
550  make_tuple(
551  make_freeze_transform(I0), // freeze mblock
553  Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per shuffle
555  make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
556  make_freeze_transform(I0), // freeze nblock
558  Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per shuffle
560  make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl (mirrors the M-side split above)
562  Sequence<1>{},
563  Sequence<2>{},
564  Sequence<3>{},
565  Sequence<4>{},
566  Sequence<5>{}),
568  Sequence<0>{},
570  Sequence<>{},
571  Sequence<1>{},
572  Sequence<3, 7>{}));
573 
574  // calculate origin of thread output tensor on global memory
575  // blockwise GEMM c matrix starting index
576  const auto c_thread_mtx_on_block =
577  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
578 
579  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
580  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
581 
// Decompose this thread's flat in-block M / N origin into (M0..M4) / (N0..N2) coordinates.
582  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
584  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
587 
588  const auto m_thread_data_on_block_idx =
589  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
590  make_multi_index(m_thread_data_on_block));
591 
592  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
597 
598  const auto n_thread_data_on_block_idx =
599  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
600  make_multi_index(n_thread_data_on_block));
601 
// Per-thread copy of one shuffle chunk (CShuffleMXdlPerWavePerShuffle x
// CShuffleNXdlPerWavePerShuffle XDL tiles) from accumulator registers into the LDS view. Its
// destination origin is this thread's decomposed position, with the M0/N0 slots at 0.
602  // VGPR to LDS
603  auto c_thread_copy_vgpr_to_lds =
605  FloatCShuffle,
606  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
607  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
609  Sequence<CShuffleMXdlPerWavePerShuffle,
610  CShuffleNXdlPerWavePerShuffle,
611  I1,
612  I1,
613  M2,
614  I1,
615  M4,
616  I1>,
618  7,
619  1,
621  1,
622  true>{
623  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
625  0,
626  m_thread_data_on_block_idx[I1],
627  n_thread_data_on_block_idx[I1],
628  m_thread_data_on_block_idx[I2],
629  m_thread_data_on_block_idx[I3],
630  m_thread_data_on_block_idx[I4],
631  n_thread_data_on_block_idx[I2]),
633 
// Workgroup-wide copy of one staged chunk from LDS to the C grid. It applies c_element_op and
// CGlobalMemoryDataOperation, and its destination window starts at this workgroup's
// (MBlock, NBlock) tile.
634  // LDS to global
635  auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
636  ThisThreadBlock, // ThreadGroup
637  CElementwiseOperation, // ElementwiseOperation,
638  CGlobalMemoryDataOperation, // DstInMemOp,
639  Sequence<1,
640  CShuffleMXdlPerWavePerShuffle,
641  MWave * MPerXdl,
642  1,
643  CShuffleNXdlPerWavePerShuffle,
644  NWave * NPerXdl>, // BlockSliceLengths,
645  CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
646  Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
647  FloatCShuffle, // typename SrcData,
648  FloatC, // typename DstData,
649  decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
650  decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
651  Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
652  5, // index_t VectorDim,
653  CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
654  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
655  false> // bool ThreadTransferDstResetCoordinateAfterRun>
656  {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
657  make_multi_index(0, 0, 0, 0, 0, 0),
658  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
659  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
660  c_element_op};
661 
// Destination-window steps between chunks: +M chunk, +N chunk, -N chunk.
662  constexpr auto mxdlperwave_forward_step =
663  make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0);
664  constexpr auto nxdlperwave_forward_step =
665  make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0);
666  constexpr auto nxdlperwave_backward_step =
667  make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0);
668 
// Visit all chunks in serpentine order. Even-numbered M chunks sweep N forward, odd-numbered
// ones sweep N backward. The global destination window therefore moves by a single +/-N step
// between chunks, plus one +M step after each N sweep.
669  static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) {
670  constexpr auto mxdlperwave = mxdlperwave_iter;
671 
672  static_for<0,
673  NXdlPerWave,
674  CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) {
675  constexpr bool nxdlperwave_forward_sweep =
676  (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0);
677 
678  constexpr index_t nxdlperwave_value =
679  nxdlperwave_forward_sweep
680  ? nxdlperwave_iter
681  : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle);
682 
683  constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
684 
// Because c_shuffle_block_buf aliases the A/B LDS tiles, this barrier also orders the
// first chunk's writes after every thread's earlier LDS use by the K-loop pipeline.
685  // make sure it's safe to do ds_write
686  block_sync_lds();
687 
688  // VGPR to LDS
689  c_thread_copy_vgpr_to_lds.Run(
690  c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
691  make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
692  c_thread_buf,
693  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
694  c_shuffle_block_buf);
695 
696  // make sure it's safe to do ds_read
697  block_sync_lds();
698 
699  // LDS to global
700  c_block_copy_lds_to_global.Run(
701  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
702  c_shuffle_block_buf,
703  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
704  c_grid_buf);
705 
706  // move on nxdlperwave dimension
707  if constexpr(nxdlperwave_forward_sweep &&
708  (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle))
709  {
710  c_block_copy_lds_to_global.MoveDstSliceWindow(
711  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
712  nxdlperwave_forward_step);
713  }
714  else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
715  {
716  c_block_copy_lds_to_global.MoveDstSliceWindow(
717  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
718  nxdlperwave_backward_step);
719  }
720  });
721 
722  // move on mxdlperwave dimension
723  if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle)
724  {
725  c_block_copy_lds_to_global.MoveDstSliceWindow(
726  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
727  mxdlperwave_forward_step);
728  }
729  });
730  }
731  }
732 };
733 
734 } // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:30
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition: ck.hpp:276
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition: tensor_descriptor_helper.hpp:132
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:58
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:18
__global__ void kernel_gemm_xdlops_v3r1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r1.hpp:36
Definition: block_to_ctile_map.hpp:260
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
Definition: gridwise_gemm_xdlops_v3r1.hpp:123
static __device__ void Run(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, void *__restrict__ p_shared, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r1.hpp:326
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdlops_v3r1.hpp:276
__host__ static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_xdlops_v3r1.hpp:166
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdlops_v3r1.hpp:139
static constexpr auto I6
Definition: gridwise_gemm_xdlops_v3r1.hpp:130
__host__ static constexpr __device__ auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition: gridwise_gemm_xdlops_v3r1.hpp:310
static constexpr auto I5
Definition: gridwise_gemm_xdlops_v3r1.hpp:129
static constexpr auto I7
Definition: gridwise_gemm_xdlops_v3r1.hpp:131
static constexpr auto I4
Definition: gridwise_gemm_xdlops_v3r1.hpp:128
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r1.hpp:236
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVer, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition: gridwise_gemm_xdlops_v3r1.hpp:142
static constexpr auto I0
Definition: gridwise_gemm_xdlops_v3r1.hpp:124
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
Definition: gridwise_gemm_xdlops_v3r1.hpp:319
static constexpr auto AK0
Definition: gridwise_gemm_xdlops_v3r1.hpp:134
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_xdlops_v3r1.hpp:284
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition: gridwise_gemm_xdlops_v3r1.hpp:322
__host__ static constexpr __device__ auto GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl()
Definition: gridwise_gemm_xdlops_v3r1.hpp:189
static constexpr auto BK0
Definition: gridwise_gemm_xdlops_v3r1.hpp:135
__host__ static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_xdlops_v3r1.hpp:144
static constexpr auto BK1
Definition: gridwise_gemm_xdlops_v3r1.hpp:137
static constexpr auto I1
Definition: gridwise_gemm_xdlops_v3r1.hpp:125
static constexpr auto I3
Definition: gridwise_gemm_xdlops_v3r1.hpp:127
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_v3r1.hpp:207
static constexpr auto AK1
Definition: gridwise_gemm_xdlops_v3r1.hpp:136
static constexpr auto I2
Definition: gridwise_gemm_xdlops_v3r1.hpp:126
Definition: xdlops_gemm.hpp:1126
Definition: sequence.hpp:43
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:334