Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp Source File

gridwise_gemm_xdlops_v3r2.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/utility/common_header.hpp"
7 #include "ck/tensor_description/multi_index_transform_helper.hpp"
8 #include "ck/tensor_description/tensor_descriptor.hpp"
9 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
10 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
11 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
12 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
13 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
14 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp"
15 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
16 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
17 
18 namespace ck {
19 
20 template <typename GridwiseGemm,
21  typename FloatAB,
22  typename FloatC,
23  typename AGridDesc_K0_M_K1,
24  typename BGridDesc_K0_N_K1,
25  typename CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
26  typename C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
27  typename AElementwiseOperation,
28  typename BElementwiseOperation,
29  typename CElementwiseOperation,
30  typename Block2CTileMap,
31  bool HasMainKBlockLoop>
32 __global__ void
33 #if CK_USE_LAUNCH_BOUNDS
34  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
35 #endif
36  kernel_gemm_xdlops_v3r2(
37  const FloatAB* __restrict__ p_a_grid,
38  const FloatAB* __restrict__ p_b_grid,
39  FloatC* __restrict__ p_c_grid,
40  const FloatC* __restrict__ p_c0_grid,
41  const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
42  const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
43  const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
44  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
45  const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
46  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
47  const AElementwiseOperation a_element_op,
48  const BElementwiseOperation b_element_op,
49  const CElementwiseOperation c_element_op,
50  const Block2CTileMap block_2_ctile_map)
51 {
52 #if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
53  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
54 
55  GridwiseGemm::template Run<HasMainKBlockLoop>(
56  p_a_grid,
57  p_b_grid,
58  p_c_grid,
59  p_c0_grid,
60  p_shared,
61  a_grid_desc_k0_m_k1,
62  b_grid_desc_k0_n_k1,
63  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
64  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
65  a_element_op,
66  b_element_op,
67  c_element_op,
68  block_2_ctile_map);
69 #else
70  ignore = p_a_grid;
71  ignore = p_b_grid;
72  ignore = p_c_grid;
73  ignore = p_c0_grid;
74  ignore = a_grid_desc_k0_m_k1;
75  ignore = b_grid_desc_k0_n_k1;
76  ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
77  ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
78  ignore = a_element_op;
79  ignore = b_element_op;
80  ignore = c_element_op;
81  ignore = block_2_ctile_map;
82 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
83 }
84 
85 template <
86  index_t BlockSize,
87  typename FloatAB,
88  typename FloatAcc,
89  typename FloatC,
90  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
91  typename AGridDesc_K0_M_K1,
92  typename BGridDesc_K0_N_K1,
93  typename CGridDesc_M_N,
94  typename C0GridDesc_M_N,
95  typename AElementwiseOperation,
96  typename BElementwiseOperation,
97  typename CElementwiseOperation,
98  index_t MPerBlock,
99  index_t NPerBlock,
100  index_t K0PerBlock,
101  index_t MPerXdl,
102  index_t NPerXdl,
103  index_t K1Value,
104  index_t MXdlPerWave,
105  index_t NXdlPerWave,
106  typename ABlockTransferThreadClusterLengths_K0_M_K1,
107  typename ABlockTransferThreadClusterArrangeOrder,
108  typename ABlockTransferSrcAccessOrder,
109  index_t ABlockTransferSrcVectorDim,
110  index_t ABlockTransferSrcScalarPerVector,
111  index_t ABlockTransferDstScalarPerVector_K1,
112  bool AThreadTransferSrcResetCoordinateAfterRun,
113  bool ABlockLdsExtraM,
114  typename BBlockTransferThreadClusterLengths_K0_N_K1,
115  typename BBlockTransferThreadClusterArrangeOrder,
116  typename BBlockTransferSrcAccessOrder,
117  index_t BBlockTransferSrcVectorDim,
118  index_t BBlockTransferSrcScalarPerVector,
119  index_t BBlockTransferDstScalarPerVector_K1,
120  bool BThreadTransferSrcResetCoordinateAfterRun,
121  bool BBlockLdsExtraN,
122  index_t CShuffleMXdlPerWavePerShuffle,
123  index_t CShuffleNXdlPerWavePerShuffle,
124  typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
125  index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
126  index_t NumGemmKPrefetchStage = 1,
127  PipelineVersion PipelineVer = PipelineVersion::v1>
128 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
129 {
130  static constexpr auto I0 = Number<0>{};
131  static constexpr auto I1 = Number<1>{};
132  static constexpr auto I2 = Number<2>{};
133  static constexpr auto I3 = Number<3>{};
134  static constexpr auto I4 = Number<4>{};
135  static constexpr auto I5 = Number<5>{};
136  static constexpr auto I6 = Number<6>{};
137  static constexpr auto I7 = Number<7>{};
138 
139  // K1 should be Number<...>
140  static constexpr auto K1 = Number<K1Value>{};
141 
142  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
143 
144  using GridwiseGemmPipe = remove_cvref_t<
145  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
146 
147  __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
148  {
149  constexpr auto max_lds_align = K1;
150 
151  // A matrix in LDS memory, dst of blockwise copy
152  constexpr auto a_block_desc_k0_m_k1 = [&]() {
153  if constexpr(ABlockLdsExtraM)
154  {
155  return make_naive_tensor_descriptor(
156  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
157  make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
158  }
159  else
160  {
161  return make_naive_tensor_descriptor_aligned(
162  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
163  }
164  }();
165 
166  return a_block_desc_k0_m_k1;
167  }
168 
169  __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
170  {
171  constexpr auto max_lds_align = K1;
172 
173  // B matrix in LDS memory, dst of blockwise copy
174  constexpr auto b_block_desc_k0_n_k1 = [&]() {
175  if constexpr(BBlockLdsExtraN)
176  {
177  return make_naive_tensor_descriptor(
178  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
179  make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
180  }
181  else
182  {
183  return make_naive_tensor_descriptor_aligned(
184  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
185  }
186  }();
187 
188  return b_block_desc_k0_n_k1;
189  }
190 
191  __host__ __device__ static constexpr auto
192  GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl()
193  {
194  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
195  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
196 
197  constexpr auto
198  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
199  make_naive_tensor_descriptor_packed(
200  make_tuple(I1,
201  Number<CShuffleMXdlPerWavePerShuffle>{},
202  Number<MWave * MPerXdl>{},
203  I1,
204  Number<CShuffleNXdlPerWavePerShuffle>{},
205  Number<NWave * NPerXdl>{}));
206 
207  return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
208  }
209 
210  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
211  {
212  // LDS allocation for A and B: be careful of alignment
213  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
214 
215  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
216 
217  constexpr auto max_lds_align = K1;
218 
219  constexpr auto a_block_space_size_aligned =
220  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
221 
222  constexpr auto b_block_space_size_aligned =
223  math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
224 
225  // LDS allocation for C shuffle in LDS
226  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
227  GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl();
228 
229  constexpr auto c_block_size =
230  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
231  .GetElementSpaceSize();
232 
233  return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
234  sizeof(FloatAB),
235  c_block_size * sizeof(FloatC));
236  }
237 
238  // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
239  template <typename Block2CTileMap>
240  __host__ __device__ static constexpr bool
241  CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
242  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
243  const CGridDesc_M_N& c_grid_desc_m_n,
244  const Block2CTileMap& block_2_ctile_map)
245  {
246  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
247  "wrong! K1 needs to be known at compile-time");
248 
249  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
250  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
251  "Invalid tuning param!");
252 
253  const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
254  const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
255  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
256 
257  if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
258  K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
259  K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
260  return false;
261 
262  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
263  return false;
264 
265  // check gridwise gemm pipeline
266  const auto num_k_loop = K0 / K0PerBlock;
267 
268  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
269  {
270  return false;
271  }
272 
273  if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
274  {
275  return false;
276  }
277 
278  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
279  return true;
280  }
281 
282  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
283  {
284  const index_t num_loop = K / (K0PerBlock * K1);
285 
286  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
287  }
288 
289  template <typename CGridDesc_M_N_>
290  __host__ __device__ static constexpr auto
291  MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
292  const CGridDesc_M_N_& c_grid_desc_m_n)
293  {
294  const auto M = c_grid_desc_m_n.GetLength(I0);
295  const auto N = c_grid_desc_m_n.GetLength(I1);
296 
297  const auto MBlock = M / MPerBlock;
298  const auto NBlock = N / NPerBlock;
299 
300  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
301  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
302 
303  const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
304  transform_tensor_descriptor(
305  c_grid_desc_m_n,
306  make_tuple(make_unmerge_transform(
307  make_tuple(MBlock, Number<MXdlPerWave>{}, Number<MWave * MPerXdl>{})),
308  make_unmerge_transform(
309  make_tuple(NBlock, Number<NXdlPerWave>{}, Number<NWave * NPerXdl>{}))),
310  make_tuple(Sequence<0>{}, Sequence<1>{}),
311  make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
312 
313  return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
314  }
315 
316  // return block_id to C matrix tile idx (m0, n0) mapping
317  __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
318  const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
319  {
320  return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
321  c_grid_desc_m_n);
322  }
323 
324  using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
325  remove_cvref_t<decltype(
326  MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
327  CGridDesc_M_N{}))>;
328 
329  using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
330  remove_cvref_t<decltype(
331  MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
332  C0GridDesc_M_N{}))>;
333 
334  using DefaultBlock2CTileMap =
335  remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
336 
337  template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
338  __device__ static void
339  Run(const FloatAB* __restrict__ p_a_grid,
340  const FloatAB* __restrict__ p_b_grid,
341  FloatC* __restrict__ p_c_grid,
342  const FloatC* __restrict__ p_c0_grid,
343  void* __restrict__ p_shared,
344  const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
345  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
346  const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl&
347  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
348  const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl&
349  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
350  const AElementwiseOperation& a_element_op,
351  const BElementwiseOperation& b_element_op,
352  const CElementwiseOperation& c_element_op,
353  const Block2CTileMap& block_2_ctile_map)
354  {
355  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
356  p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
357  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
358  p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
359  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
360  p_c_grid,
361  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
362  .GetElementSpaceSize());
363  auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
364  p_c0_grid,
365  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
366  .GetElementSpaceSize());
367 
368  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
369 
370  // divide block work by [M, N]
371  const auto block_work_idx =
372  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
373 
374  if(!block_2_ctile_map.ValidCTileIndex(
375  block_work_idx,
376  make_tuple(
377  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
378  .GetLength(I0),
379  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
380  .GetLength(I3))))
381  {
382  return;
383  }
384 
385  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
386  const index_t m_block_data_idx_on_grid =
387  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
388 
389  const index_t n_block_data_idx_on_grid =
390  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
391 
392  // lds max alignment
393  constexpr auto max_lds_align = K1;
394 
395  // A matrix in LDS memory, dst of blockwise copy
396  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
397 
398  // B matrix in LDS memory, dst of blockwise copy
399  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
400 
401  // A matrix blockwise copy
402  auto a_blockwise_copy =
403  ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
404  AElementwiseOperation,
405  ck::tensor_operation::element_wise::PassThrough,
406  InMemoryDataOperationEnum::Set,
407  Sequence<K0PerBlock, MPerBlock, K1>,
408  ABlockTransferThreadClusterLengths_K0_M_K1,
409  ABlockTransferThreadClusterArrangeOrder,
410  FloatAB,
411  FloatAB,
412  decltype(a_grid_desc_k0_m_k1),
413  decltype(a_block_desc_k0_m_k1),
414  ABlockTransferSrcAccessOrder,
415  Sequence<1, 0, 2>,
416  ABlockTransferSrcVectorDim,
417  2,
418  ABlockTransferSrcScalarPerVector,
419  ABlockTransferDstScalarPerVector_K1,
420  1,
421  1,
422  AThreadTransferSrcResetCoordinateAfterRun,
423  true,
424  NumGemmKPrefetchStage>(
425  a_grid_desc_k0_m_k1,
426  make_multi_index(0, m_block_data_idx_on_grid, 0),
427  a_element_op,
428  a_block_desc_k0_m_k1,
429  make_multi_index(0, 0, 0),
430  ck::tensor_operation::element_wise::PassThrough{});
431 
432  // B matrix blockwise copy
433  auto b_blockwise_copy =
434  ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
435  BElementwiseOperation,
436  ck::tensor_operation::element_wise::PassThrough,
437  InMemoryDataOperationEnum::Set,
438  Sequence<K0PerBlock, NPerBlock, K1>,
439  BBlockTransferThreadClusterLengths_K0_N_K1,
440  BBlockTransferThreadClusterArrangeOrder,
441  FloatAB,
442  FloatAB,
443  decltype(b_grid_desc_k0_n_k1),
444  decltype(b_block_desc_k0_n_k1),
445  BBlockTransferSrcAccessOrder,
446  Sequence<1, 0, 2>,
447  BBlockTransferSrcVectorDim,
448  2,
449  BBlockTransferSrcScalarPerVector,
450  BBlockTransferDstScalarPerVector_K1,
451  1,
452  1,
453  BThreadTransferSrcResetCoordinateAfterRun,
454  true,
455  NumGemmKPrefetchStage>(
456  b_grid_desc_k0_n_k1,
457  make_multi_index(0, n_block_data_idx_on_grid, 0),
458  b_element_op,
459  b_block_desc_k0_n_k1,
460  make_multi_index(0, 0, 0),
461  ck::tensor_operation::element_wise::PassThrough{});
462 
463  // GEMM definition
464  // c_mtx += transpose(a_mtx) * b_mtx
465  // a_mtx[K0PerBlock, MPerBlock] is in LDS
466  // b_mtx[K0PerBlock, NPerBlock] is in LDS
467  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
468  // register
469  // sanity check
470 
471  auto blockwise_gemm =
472  BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
473  FloatAB,
474  FloatAB,
475  FloatAcc,
476  decltype(a_block_desc_k0_m_k1),
477  decltype(b_block_desc_k0_n_k1),
478  MPerXdl,
479  NPerXdl,
480  MXdlPerWave,
481  NXdlPerWave,
482  K1>{};
483 
484  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
485 
486  // LDS allocation for A and B: be careful of alignment
487  constexpr auto a_block_space_size_aligned =
488  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
489 
490  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
491  static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
492 
493  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
494  static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
495  b_block_desc_k0_n_k1.GetElementSpaceSize());
496 
497  constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
498  constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
499 
500  // gridwise GEMM pipeline
501  const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
502 
503  GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
504  a_block_desc_k0_m_k1,
505  a_blockwise_copy,
506  a_grid_buf,
507  a_block_buf,
508  a_block_slice_copy_step,
509  b_grid_desc_k0_n_k1,
510  b_block_desc_k0_n_k1,
511  b_blockwise_copy,
512  b_grid_buf,
513  b_block_buf,
514  b_block_slice_copy_step,
515  blockwise_gemm,
516  c_thread_buf,
517  K0BlockMainLoop);
518 
519  // shuffle C and write out
520  {
521  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
522  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
523  "wrong!");
524 
525  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
526  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
527 
528  // TODO: hacky, fix it!
529  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
530  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
531 
532  // TODO: hacky, fix it!
533  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
534  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
535  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
536 
537  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
538  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
539  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
540  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
541  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
542  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
543  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
544  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
545 
546  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
547  GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl();
548 
549  auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
550  static_cast<FloatC*>(p_shared),
551  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
552  .GetElementSpaceSize());
553 
554  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
555  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
556  make_tuple(
557  make_freeze_transform(I0), // freeze mblock
558  make_pass_through_transform(
559  Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per shuffle
560  make_unmerge_transform(
561  make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
562  make_freeze_transform(I0), // freeze nblock
563  make_pass_through_transform(
564  Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per shuffle
565  make_unmerge_transform(
566  make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl
567  make_tuple(Sequence<0>{},
568  Sequence<1>{},
569  Sequence<2>{},
570  Sequence<3>{},
571  Sequence<4>{},
572  Sequence<5>{}),
573  make_tuple(Sequence<>{},
574  Sequence<0>{},
575  Sequence<2, 4, 5, 6>{},
576  Sequence<>{},
577  Sequence<1>{},
578  Sequence<3, 7>{})
579 
580  );
581 
582  // calculate origin of thread output tensor on global memory
583  // blockwise GEMM c matrix starting index
584  const auto c_thread_mtx_on_block =
585  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
586 
587  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
588  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
589 
590  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
591  make_single_stage_tensor_adaptor(
592  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
593  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
594  make_tuple(Sequence<0>{}));
595 
596  const auto m_thread_data_on_block_idx =
597  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
598  make_multi_index(m_thread_data_on_block));
599 
600  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
601  make_single_stage_tensor_adaptor(
602  make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
603  make_tuple(Sequence<0, 1, 2>{}),
604  make_tuple(Sequence<0>{}));
605 
606  const auto n_thread_data_on_block_idx =
607  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
608  make_multi_index(n_thread_data_on_block));
609 
610  // VGPR to LDS
611  auto c_thread_copy_vgpr_to_lds =
612  ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
613  FloatC,
614  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
615  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
616  ck::tensor_operation::element_wise::PassThrough,
617  Sequence<CShuffleMXdlPerWavePerShuffle,
618  CShuffleNXdlPerWavePerShuffle,
619  I1,
620  I1,
621  M2,
622  I1,
623  M4,
624  I1>,
625  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
626  7,
627  1,
628  InMemoryDataOperationEnum::Set,
629  1,
630  true>{
631  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
632  make_multi_index(0,
633  0,
634  m_thread_data_on_block_idx[I1],
635  n_thread_data_on_block_idx[I1],
636  m_thread_data_on_block_idx[I2],
637  m_thread_data_on_block_idx[I3],
638  m_thread_data_on_block_idx[I4],
639  n_thread_data_on_block_idx[I2]),
640  ck::tensor_operation::element_wise::PassThrough{}};
641 
642  auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r2<
643  ThisThreadBlock, // index_t BlockSize,
644  CElementwiseOperation, // ElementwiseOperation,
645  CGlobalMemoryDataOperation, // DstInMemOp,
646  Sequence<1,
647  CShuffleMXdlPerWavePerShuffle,
648  MWave * MPerXdl,
649  1,
650  CShuffleNXdlPerWavePerShuffle,
651  NWave * NPerXdl>, // BlockSliceLengths,
652  CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
653  Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
654  FloatC, // typename Src0Data,
655  FloatC, // typename Src1Data,
656  FloatC, // typename DstData,
657  decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
658  decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
659  decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
660  Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
661  5, // index_t VectorDim,
662  CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
663  true, // bool ThreadTransferSrc0ResetCoordinateAfterRun,
664  false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
665  false> // bool ThreadTransferDstResetCoordinateAfterRun>
666  {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
667  make_multi_index(0, 0, 0, 0, 0, 0),
668  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
669  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
670  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
671  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
672  c_element_op};
673 
674  constexpr auto mxdlperwave_forward_step =
675  make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0);
676  constexpr auto nxdlperwave_forward_step =
677  make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0);
678  constexpr auto nxdlperwave_backward_step =
679  make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0);
680 
681  static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) {
682  constexpr auto mxdlperwave = mxdlperwave_iter;
683 
684  static_for<0,
685  NXdlPerWave,
686  CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) {
687  constexpr bool nxdlperwave_forward_sweep =
688  (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0);
689 
690  constexpr index_t nxdlperwave_value =
691  nxdlperwave_forward_sweep
692  ? nxdlperwave_iter
693  : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle);
694 
695  constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
696 
697  // make sure it's safe to do ds_write
698  block_sync_lds();
699 
700  // VGPR to LDS
701  c_thread_copy_vgpr_to_lds.Run(
702  c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
703  make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
704  c_thread_buf,
705  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
706  c_block_buf);
707 
708  // make sure it's safe to do ds_read
709  block_sync_lds();
710 
711  // LDS to global
712  c_block_copy_lds_to_global.Run(
713  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
714  c_block_buf,
715  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
716  c0_grid_buf,
717  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
718  c_grid_buf);
719 
720  // move on nxdlperwave dimension
721  if constexpr(nxdlperwave_forward_sweep &&
722  (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle))
723  {
724  c_block_copy_lds_to_global.MoveSrc1SliceWindow(
725  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
726  nxdlperwave_forward_step);
727 
728  c_block_copy_lds_to_global.MoveDstSliceWindow(
729  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
730  nxdlperwave_forward_step);
731  }
732  else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
733  {
734  c_block_copy_lds_to_global.MoveSrc1SliceWindow(
735  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
736  nxdlperwave_backward_step);
737 
738  c_block_copy_lds_to_global.MoveDstSliceWindow(
739  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
740  nxdlperwave_backward_step);
741  }
742  });
743 
744  // move on mxdlperwave dimension
745  if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle)
746  {
747  c_block_copy_lds_to_global.MoveSrc1SliceWindow(
748  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
749  mxdlperwave_forward_step);
750 
751  c_block_copy_lds_to_global.MoveDstSliceWindow(
752  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
753  mxdlperwave_forward_step);
754  }
755  });
756  }
757  }
758 };
759 
760 } // namespace ck
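
Note on the write-out loop: the C-shuffle phase above visits the per-wave C tiles in a back-and-forth ("snake") order. The nxdlperwave index sweeps forward when mxdlperwave is an even multiple of CShuffleMXdlPerWavePerShuffle and backward otherwise, so the destination slice window only ever moves by one shuffle step between iterations. The following standalone C++ sketch (not part of the header; the loop bounds are illustrative assumptions) prints the resulting visit order:

#include <cstdio>

int main()
{
    // Illustrative values only; they are not taken from the header above.
    constexpr int MXdlPerWave = 4, NXdlPerWave = 4;
    constexpr int MShuffle = 2, NShuffle = 2; // CShuffle[M|N]XdlPerWavePerShuffle

    for(int m = 0; m < MXdlPerWave; m += MShuffle)
    {
        // Forward sweep on even m-steps, backward sweep on odd m-steps.
        const bool forward = (m % (2 * MShuffle)) == 0;

        for(int n_iter = 0; n_iter < NXdlPerWave; n_iter += NShuffle)
        {
            const int n = forward ? n_iter : (NXdlPerWave - n_iter - NShuffle);
            std::printf("write C tile (mxdlperwave=%d, nxdlperwave=%d)\n", m, n);
        }
    }
}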
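
Note on LDS sizing: GetSharedMemoryNumberOfByte() returns the larger of the A-plus-B tile footprint used by the GEMM main loop and the C-shuffle tile used during write-out, because the two phases reuse the same LDS allocation. A minimal standalone sketch of that sizing rule, assuming hypothetical fp16 data and illustrative tuning parameters (none of the values below come from the header):

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main()
{
    // Hypothetical tuning parameters, for illustration only.
    constexpr std::size_t K0PerBlock = 4, MPerBlock = 256, NPerBlock = 128, K1 = 8;
    constexpr std::size_t MPerXdl = 32, NPerXdl = 32, MXdlPerWave = 4, NXdlPerWave = 2;
    constexpr std::size_t CShuffleMXdlPerWavePerShuffle = 1, CShuffleNXdlPerWavePerShuffle = 1;
    constexpr std::size_t sizeof_ab = 2, sizeof_c = 2; // e.g. fp16 A/B and C

    constexpr std::size_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
    constexpr std::size_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

    // A/B LDS tiles with the extra row of padding (MPerBlock + 1, NPerBlock + 1)
    // that the ABlockLdsExtraM / BBlockLdsExtraN options add to reduce bank conflicts.
    constexpr std::size_t a_elems = K0PerBlock * (MPerBlock + 1) * K1;
    constexpr std::size_t b_elems = K0PerBlock * (NPerBlock + 1) * K1;

    // C-shuffle LDS tile: one MXdlPerWave-per-shuffle by NXdlPerWave-per-shuffle slice.
    constexpr std::size_t c_elems = CShuffleMXdlPerWavePerShuffle * (MWave * MPerXdl) *
                                    CShuffleNXdlPerWavePerShuffle * (NWave * NPerXdl);

    // Main loop and shuffle phase reuse the same LDS, so take the larger footprint.
    const std::size_t lds_bytes = std::max((a_elems + b_elems) * sizeof_ab, c_elems * sizeof_c);

    std::printf("LDS bytes: %zu\n", lds_bytes);
}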