/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp Source File
gridwise_gemm_xdlops_v3r1.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
18 
19 namespace ck {
20 
// GEMM kernel entry point; this page's definition index lists it as
// kernel_gemm_xdlops_v3r1. It declares the block's static __shared__ buffer and
// forwards every argument to GridwiseGemm::Run<HasMainK0BlockLoop>.
//
// Template parameters:
//   GridwiseGemm          - type supplying IsValidCompilationParameter(),
//                           GetSharedMemoryNumberOfByte() and Run().
//   FloatAB / FloatC      - element types of the A/B input pointers and the C output pointer.
//   ...GridDesc... types  - descriptor types, taken by value and forwarded to Run().
//   ...ElementwiseOperation - functor types, taken by value and forwarded to Run().
//   Block2CTileMap        - type of the block-to-C-tile map forwarded to Run().
//   HasMainK0BlockLoop    - compile-time flag passed on as Run's template argument.
//
// NOTE(review): listing lines 34 and 36 (the line guarded by CK_USE_LAUNCH_BOUNDS and
// the line carrying the kernel name) are missing from this rendering.
21 template <typename GridwiseGemm,
22  typename FloatAB,
23  typename FloatC,
24  typename AGridDesc_AK0_M_AK1,
25  typename BGridDesc_BK0_N_BK1,
26  typename CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
27  typename AElementwiseOperation,
28  typename BElementwiseOperation,
29  typename CElementwiseOperation,
30  typename Block2CTileMap,
31  bool HasMainK0BlockLoop>
32 __global__ void
33 #if CK_USE_LAUNCH_BOUNDS
35 #endif
37  const FloatAB* __restrict__ p_a_grid,
38  const FloatAB* __restrict__ p_b_grid,
39  FloatC* __restrict__ p_c_grid,
40  const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
41  const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
42  const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
43  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
44  const AElementwiseOperation a_element_op,
45  const BElementwiseOperation b_element_op,
46  const CElementwiseOperation c_element_op,
47  const Block2CTileMap block_2_ctile_map)
48 {
 // The real body is compiled only when one of the gfx target macros below is defined;
 // for any other target the #else branch just discards the arguments.
49 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
50  defined(__gfx12__)
 // When the compile-time parameter check fails, the kernel body is empty.
51  if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
52  {
 // Static shared-memory scratch; its size in bytes comes from the gridwise GEMM.
53  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
54 
55  GridwiseGemm::template Run<HasMainK0BlockLoop>(
56  p_a_grid,
57  p_b_grid,
58  p_c_grid,
59  p_shared,
60  a_grid_desc_ak0_m_ak1,
61  b_grid_desc_bk0_n_bk1,
62  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
63  a_element_op,
64  b_element_op,
65  c_element_op,
66  block_2_ctile_map);
67  }
68 #else
 // Other targets: assign every parameter to `ignore`; the kernel does no work.
69  ignore = p_a_grid;
70  ignore = p_b_grid;
71  ignore = p_c_grid;
72  ignore = a_grid_desc_ak0_m_ak1;
73  ignore = b_grid_desc_bk0_n_bk1;
74  ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
75  ignore = a_element_op;
76  ignore = b_element_op;
77  ignore = c_element_op;
78  ignore = block_2_ctile_map;
79 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__))
80 }
81 
82 template <
83  index_t BlockSize,
84  typename FloatAB,
85  typename FloatAcc,
86  typename FloatCShuffle,
87  typename FloatC,
88  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
89  typename AGridDesc_AK0_M_AK1,
90  typename BGridDesc_BK0_N_BK1,
91  typename CGridDesc_M_N,
92  typename AElementwiseOperation,
93  typename BElementwiseOperation,
94  typename CElementwiseOperation,
95  index_t MPerBlock,
96  index_t NPerBlock,
97  index_t KPerBlock,
98  index_t AK1Value,
99  index_t BK1Value,
100  index_t MPerXdl,
101  index_t NPerXdl,
102  index_t MXdlPerWave,
103  index_t NXdlPerWave,
104  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
105  typename ABlockTransferThreadClusterArrangeOrder,
106  typename ABlockTransferSrcAccessOrder,
107  index_t ABlockTransferSrcVectorDim,
108  index_t ABlockTransferSrcScalarPerVector,
109  index_t ABlockTransferDstScalarPerVector_K1,
110  bool AThreadTransferSrcResetCoordinateAfterRun,
111  bool ABlockLdsExtraM,
112  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
113  typename BBlockTransferThreadClusterArrangeOrder,
114  typename BBlockTransferSrcAccessOrder,
115  index_t BBlockTransferSrcVectorDim,
116  index_t BBlockTransferSrcScalarPerVector,
117  index_t BBlockTransferDstScalarPerVector_K1,
118  bool BThreadTransferSrcResetCoordinateAfterRun,
119  bool BBlockLdsExtraN,
120  index_t CShuffleMXdlPerWavePerShuffle,
121  index_t CShuffleNXdlPerWavePerShuffle,
122  typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
123  index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
124  index_t NumGemmKPrefetchStage = 1,
125  PipelineVersion PipelineVer = PipelineVersion::v1>
127 {
// Compile-time index constants Number<0> .. Number<7>, used below as arguments to
// GetLength() and as tuple/array subscripts.
128  static constexpr auto I0 = Number<0>{};
129  static constexpr auto I1 = Number<1>{};
130  static constexpr auto I2 = Number<2>{};
131  static constexpr auto I3 = Number<3>{};
132  static constexpr auto I4 = Number<4>{};
133  static constexpr auto I5 = Number<5>{};
134  static constexpr auto I6 = Number<6>{};
135  static constexpr auto I7 = Number<7>{};
136 
137  // K1 should be Number<...>
 // AK0 / BK0 are KPerBlock divided (integer division) by the respective K1 value;
 // AK1 / BK1 wrap the K1 template values themselves as Number<> constants.
138  static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
139  static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
140  static constexpr auto AK1 = Number<AK1Value>{};
141  static constexpr auto BK1 = Number<BK1Value>{};
142 
144 
146  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
147 
// Returns the compile-time descriptor of the A tile as stored in LDS (the destination
// of the A blockwise copy). The layout is selected at compile time by ABlockLdsExtraM.
//
// NOTE(review): the descriptor-construction calls (listing lines 156-158 and 162) are
// missing from this rendering; only the lengths tuple (AK0, MPerBlock, AK1) and the
// alignment argument of the else-branch are visible.
148  __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
149  {
 // Alignment passed to the descriptor in the else-branch: AK1 elements.
150  constexpr auto max_lds_align = AK1;
151 
152  // A matrix in LDS memory, dst of blockwise copy
153  constexpr auto a_block_desc_ak0_m_ak1 = [&]() {
154  if constexpr(ABlockLdsExtraM)
155  {
159  }
160  else
161  {
163  make_tuple(AK0, Number<MPerBlock>{}, AK1), max_lds_align);
164  }
165  }();
166 
167  return a_block_desc_ak0_m_ak1;
168  }
169 
// Returns the compile-time descriptor of the B tile as stored in LDS (the destination
// of the B blockwise copy). The layout is selected at compile time by BBlockLdsExtraN.
//
// NOTE(review): the descriptor-construction calls (listing lines 178-180 and 184) are
// missing from this rendering; only the lengths tuple (BK0, NPerBlock, BK1) and the
// alignment argument of the else-branch are visible.
170  __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
171  {
 // Alignment passed to the descriptor in the else-branch: BK1 elements.
172  constexpr auto max_lds_align = BK1;
173 
174  // B matrix in LDS memory, dst of blockwise copy
175  constexpr auto b_block_desc_bk0_n_bk1 = [&]() {
176  if constexpr(BBlockLdsExtraN)
177  {
181  }
182  else
183  {
185  make_tuple(BK0, Number<NPerBlock>{}, BK1), max_lds_align);
186  }
187  }();
188 
189  return b_block_desc_bk0_n_bk1;
190  }
191 
// Returns the compile-time descriptor of the C tile used for the C shuffle; the value
// returned is c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl.
//
// NOTE(review): the line carrying the function name (listing line 193) and most of the
// descriptor construction (lines 200, 202-203, 205-206) are missing from this rendering.
// The page index names this function
// GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl.
192  __host__ __device__ static constexpr auto
194  {
 // MWave / NWave: MPerBlock / (MXdlPerWave * MPerXdl) and NPerBlock / (NXdlPerWave * NPerXdl),
 // using integer division.
195  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
196  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
197 
198  constexpr auto
199  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
201  make_tuple(I1,
204  I1,
207 
208  return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
209  }
210 
// Returns the number of bytes of LDS the kernel allocates for this GEMM: the larger of
//   (A tile size rounded up to AK1 + B tile size rounded up to BK1) * sizeof(FloatAB)
// and
//   C-shuffle tile size * sizeof(FloatCShuffle).
// A maximum (not a sum) is taken; Run() places both the A buffer and the C-shuffle
// buffer at the start of p_shared.
//
// NOTE(review): here the A size is rounded up to AK1, whereas Run() offsets the B
// buffer by the A size rounded up to lcm(AK1, BK1). If the two roundings differ, B
// would end past the A+B byte count computed here -- TODO confirm whether that can
// happen for supported configurations.
// NOTE(review): listing line 226 (the initializer of the C block descriptor) is
// missing from this rendering.
211  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
212  {
213  // LDS allocation for A and B: be careful of alignment
214  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
215 
216  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
217 
 // Element counts (not bytes), each rounded up to a multiple of its own K1.
218  constexpr auto a_block_space_size_aligned =
219  math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), AK1);
220 
221  constexpr auto b_block_space_size_aligned =
222  math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), BK1);
223 
224  // LDS allocation for C shuffle in LDS
225  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
227 
228  constexpr auto c_block_size =
229  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
230  .GetElementSpaceSize();
231 
 // Convert both element counts to bytes and keep the larger requirement.
232  return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
233  sizeof(FloatAB),
234  c_block_size * sizeof(FloatCShuffle));
235  }
236 
// Compile-time predicate: forwards this GEMM's block size, tile sizes, XDL sizes,
// per-wave repeat counts, FloatC and CGlobalMemoryDataOperation to
// ck::tensor_operation::device::IsValidGemmCompilationParameter and returns its result.
//
// NOTE(review): the function's own template parameter CGlobalMemoryDataOperation_
// (default InMemoryDataOperationEnum::Set) is not referenced in the body; the value
// forwarded is the enclosing template's CGlobalMemoryDataOperation (no trailing
// underscore). Confirm this is intended.
237  template <
238  InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
239  __device__ static bool constexpr IsValidCompilationParameter()
240  {
241  return ck::tensor_operation::device::IsValidGemmCompilationParameter<
242  BlockSize,
243  MPerBlock,
244  NPerBlock,
245  MPerXdl,
246  NPerXdl,
247  MXdlPerWave,
248  NXdlPerWave,
249  FloatC,
250  CGlobalMemoryDataOperation>();
251  }
252 
253  // the mapping from block_id to matrix tile idx (m0, n0) is controlled by {M01, N01}
// Checks whether this GEMM instance can handle the given problem.
//
// Parameters:
//   a_grid_desc_ak0_m_ak1 - A descriptor; lengths I0, I1, I2 are read as AK0, M, AK1.
//   b_grid_desc_bk0_n_bk1 - B descriptor; only length I1 (N) is read.
//   c_grid_desc_m_n       - C descriptor; lengths I0 and I1 must equal M and N.
//   block_2_ctile_map     - its CheckValidity(c_grid_desc_m_n) must return true.
//
// Returns true only if all of the following hold:
//   * C's extents equal (M, N);
//   * M, N and K are multiples of MPerBlock, NPerBlock and KPerBlock;
//   * GridwiseGemmPipe::IsSupported(K / KPerBlock) is true;
//   * block_2_ctile_map.CheckValidity(c_grid_desc_m_n) is true.
//
// NOTE(review): K is derived from the A descriptor only; the K extent of the B
// descriptor is never compared against it here.
254  template <typename Block2CTileMap>
255  __host__ __device__ static constexpr bool
256  CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
257  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
258  const CGridDesc_M_N& c_grid_desc_m_n,
259  const Block2CTileMap& block_2_ctile_map)
260  {
261  // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
262  // is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
263  // "wrong! K1 need to be known at compile-time");
264 
 // Compile-time requirement: the block tile must split evenly into per-wave XDL tiles.
265  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
266  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
267  "Invalid tuning param!");
268 
 // K is the product of the A descriptor's first and last lengths (AK0 * AK1).
269  const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
270  const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
271  const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
272 
273  if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
274  return false;
275 
 // No partial tiles: every problem dimension must be a multiple of its block size.
276  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
277  return false;
278 
279  // check gridwise gemm pipeline
280  const auto num_k_loop = K / KPerBlock;
281 
282  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
283  {
284  return false;
285  }
286 
287  if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
288  {
289  return false;
290  }
291 
292  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
293  return true;
294  }
295 
// Returns GridwiseGemmPipe::CalculateHasMainLoop(K / KPerBlock) for the given K extent.
// The division is an integer division; any remainder of K is dropped here.
296  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
297  {
298  const index_t num_loop = K / KPerBlock;
299 
300  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
301  }
302 
// Builds, from the 2-D C descriptor c_grid_desc_m_n, the descriptor returned as
// c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl.
//
// NOTE(review): the line carrying the function name (listing line 304) and the
// transform arguments (lines 317, 319-324) are missing from this rendering. The page
// index names this function
// MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl.
303  __host__ __device__ static constexpr auto
305  const CGridDesc_M_N& c_grid_desc_m_n)
306  {
307  const auto M = c_grid_desc_m_n.GetLength(I0);
308  const auto N = c_grid_desc_m_n.GetLength(I1);
309 
 // Number of block tiles along M and N (integer division).
310  const auto MBlock = M / MPerBlock;
311  const auto NBlock = N / NPerBlock;
312 
313  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
314  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
315 
316  const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
318  c_grid_desc_m_n,
325 
326  return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
327  }
328 
329  // return block_id to C matrix tile idx (m0, n0) mapping
 // The two index_t parameters are unnamed (their names M01 / N01 are commented out),
 // so they cannot be used in the body; only c_grid_desc_m_n is passed on.
 // NOTE(review): listing line 333 (the start of the return expression) is missing
 // from this rendering.
330  __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
331  const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
332  {
334  c_grid_desc_m_n);
335  }
339  CGridDesc_M_N{}))>;
340 
342  remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
343 
// Per-workgroup GEMM body. Steps visible in this listing:
//  1. wrap the A/B/C global pointers in dynamic buffers sized by their descriptors;
//  2. map get_block_1d_id() to a C-tile index through block_2_ctile_map and return
//     early when ValidCTileIndex() rejects it;
//  3. construct the A and B blockwise copies, the blockwise GEMM object and the A/B
//     LDS buffers carved out of p_shared;
//  4. run GridwiseGemmPipe::Run<HasMainK0BlockLoop> over the K dimension;
//  5. stage the accumulators through LDS (p_shared again, viewed as FloatCShuffle)
//     and copy them to the C grid buffer.
//
// Parameters:
//   p_a_grid, p_b_grid, p_c_grid - global pointers for A, B (read) and C (written).
//   p_shared                     - LDS scratch; the kernel on this page passes a buffer
//                                  of GetSharedMemoryNumberOfByte() bytes.
//   *_grid_desc_*                - descriptors for A, B and the blocked C layout.
//   a/b/c_element_op             - elementwise functors handed to the copy objects.
//   block_2_ctile_map            - maps the 1-D block id to a C-tile index.
//
// NOTE(review): many listing lines are missing from this rendering (for example the
// type names of the blockwise-copy, blockwise-GEMM and threadwise-copy objects), so
// the comments below describe only what is visible.
344  template <bool HasMainK0BlockLoop, typename Block2CTileMap>
345  __device__ static void
346  Run(const FloatAB* __restrict__ p_a_grid,
347  const FloatAB* __restrict__ p_b_grid,
348  FloatC* __restrict__ p_c_grid,
349  void* __restrict__ p_shared,
350  const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
351  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
353  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
354  const AElementwiseOperation& a_element_op,
355  const BElementwiseOperation& b_element_op,
356  const CElementwiseOperation& c_element_op,
357  const Block2CTileMap& block_2_ctile_map)
358  {
359  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
360  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
361  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
362  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
363  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
364  p_c_grid,
365  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
366  .GetElementSpaceSize());
367 
368  // divide block work by [M, N]
369  const auto block_work_idx =
370  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
371 
 // Leave immediately when the tile index is rejected against the C descriptor's
 // lengths at I0 and I3 (the mblock and nblock dimensions by name).
372  if(!block_2_ctile_map.ValidCTileIndex(
373  block_work_idx,
374  make_tuple(
375  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
376  .GetLength(I0),
377  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
378  .GetLength(I3))))
379  {
380  return;
381  }
382 
383  // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
384  const index_t m_block_data_idx_on_grid =
385  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
386 
387  const index_t n_block_data_idx_on_grid =
388  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
389 
390  // lds max alignment
391  constexpr auto max_lds_align = math::lcm(AK1, BK1);
392 
393  // A matrix in LDS memory, dst of blockwise copy
394  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
395 
396  // B matrix in LDS memory, dst of blockwise copy
397  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
398 
399  // A matrix blockwise copy
 // Source origin on the grid is (0, m_block_data_idx_on_grid, 0); destination origin
 // in the LDS block descriptor is (0, 0, 0).
400  auto a_blockwise_copy =
402  AElementwiseOperation,
406  ABlockTransferThreadClusterLengths_AK0_M_AK1,
407  ABlockTransferThreadClusterArrangeOrder,
408  FloatAB,
409  FloatAB,
410  decltype(a_grid_desc_ak0_m_ak1),
411  decltype(a_block_desc_ak0_m_ak1),
412  ABlockTransferSrcAccessOrder,
414  ABlockTransferSrcVectorDim,
415  2,
416  ABlockTransferSrcScalarPerVector,
417  ABlockTransferDstScalarPerVector_K1,
418  1,
419  1,
420  AThreadTransferSrcResetCoordinateAfterRun,
421  true,
422  NumGemmKPrefetchStage>(
423  a_grid_desc_ak0_m_ak1,
424  make_multi_index(0, m_block_data_idx_on_grid, 0),
425  a_element_op,
426  a_block_desc_ak0_m_ak1,
427  make_multi_index(0, 0, 0),
429 
430  // B matrix blockwise copy
 // Same arrangement as A, with source origin (0, n_block_data_idx_on_grid, 0).
431  auto b_blockwise_copy =
433  BElementwiseOperation,
437  BBlockTransferThreadClusterLengths_BK0_N_BK1,
438  BBlockTransferThreadClusterArrangeOrder,
439  FloatAB,
440  FloatAB,
441  decltype(b_grid_desc_bk0_n_bk1),
442  decltype(b_block_desc_bk0_n_bk1),
443  BBlockTransferSrcAccessOrder,
445  BBlockTransferSrcVectorDim,
446  2,
447  BBlockTransferSrcScalarPerVector,
448  BBlockTransferDstScalarPerVector_K1,
449  1,
450  1,
451  BThreadTransferSrcResetCoordinateAfterRun,
452  true,
453  NumGemmKPrefetchStage>(
454  b_grid_desc_bk0_n_bk1,
455  make_multi_index(0, n_block_data_idx_on_grid, 0),
456  b_element_op,
457  b_block_desc_bk0_n_bk1,
458  make_multi_index(0, 0, 0),
460 
461  // GEMM definition
462  // c_mtx += transpose(a_mtx) * b_mtx
463  // a_mtx[K0PerBlock, MPerBlock] is in LDS
464  // b_mtx[K0PerBlock, NPerBlock] is in LDS
465  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
466  // register
467  // sanity check
 // NOTE(review): parts of the is_single_rate_mfma condition and of the k_pack
 // expression (listing lines 470, 473, 480) are missing from this rendering.
468  constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
469  constexpr bool is_single_rate_mfma =
471  lcm_AK1_BK1 <= 4) ||
472  (is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
474  lcm_AK1_BK1 < 32))
475  ? true
476  : false;
477  constexpr auto is_scale_mfma = false;
478  constexpr index_t k_pack = math::max(
479  lcm_AK1_BK1,
481  selected_mfma.k_per_blk);
482 
483  auto blockwise_gemm =
485  FloatAB,
486  FloatAB,
487  FloatAcc,
488  decltype(a_block_desc_ak0_m_ak1),
489  decltype(b_block_desc_bk0_n_bk1),
490  MPerXdl,
491  NPerXdl,
492  MXdlPerWave,
493  NXdlPerWave,
494  k_pack>{};
495 
496  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
497 
498  // LDS allocation for A and B: be careful of alignment
 // The A size is rounded up to max_lds_align = lcm(AK1, BK1); this rounded size is
 // the element offset of the B buffer inside p_shared.
499  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
500  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
501 
502  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
503  static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
504 
505  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
506  static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
507  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
508 
 // Per-iteration window step along the first (K0) dimension of each grid descriptor.
509  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
510  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
511 
512  // gridwise GEMM pipeline
 // Loop count = (A length I0 * A length I2) / KPerBlock, i.e. K / KPerBlock.
513  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
514  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
515  KPerBlock);
516 
517  GridwiseGemmPipe::template Run<HasMainK0BlockLoop>(a_grid_desc_ak0_m_ak1,
518  a_block_desc_ak0_m_ak1,
519  a_blockwise_copy,
520  a_grid_buf,
521  a_block_buf,
522  a_block_slice_copy_step,
523  b_grid_desc_bk0_n_bk1,
524  b_block_desc_bk0_n_bk1,
525  b_blockwise_copy,
526  b_grid_buf,
527  b_block_buf,
528  b_block_slice_copy_step,
529  blockwise_gemm,
530  c_thread_buf,
531  num_k_block_main_loop);
532 
533  // shuffle C and write out
534  {
535  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
536  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
537  "wrong!");
538 
539  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
540  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
541 
542  // TODO: hacky, fix it!
543  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
544  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
545 
546  // TODO: hacky, fix it!
547  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
548  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
549  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
550 
551  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
552  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
553  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
554  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
555  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
556  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
557  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
558  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
559 
560  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
562 
 // The C-shuffle buffer starts at p_shared, i.e. it overlaps the storage that
 // a_block_buf used during the main loop.
563  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
564  static_cast<FloatCShuffle*>(p_shared),
565  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
566  .GetElementSpaceSize());
567 
568  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
569  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
570  make_tuple(
571  make_freeze_transform(I0), // freeze mblock
573  Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per shuffle
575  make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
576  make_freeze_transform(I0), // freeze nblock
578  Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per shuffle
580  make_tuple(N1, N2))), // N1 * N2 = NWave * NPerXdl (presumably N1 = NWave, N2 = NPerXdl -- TODO confirm)
582  Sequence<1>{},
583  Sequence<2>{},
584  Sequence<3>{},
585  Sequence<4>{},
586  Sequence<5>{}),
588  Sequence<0>{},
590  Sequence<>{},
591  Sequence<1>{},
592  Sequence<3, 7>{}));
593 
594  // calculate origin of thread output tensor on global memory
595  // blockwise GEMM c matrix starting index
596  const auto c_thread_mtx_on_block =
597  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
598 
599  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
600  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
601 
 // Split the thread's flat M offset into (M0, M1, M2, M3, M4) components.
602  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
604  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
607 
608  const auto m_thread_data_on_block_idx =
609  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
610  make_multi_index(m_thread_data_on_block));
611 
 // Same decomposition for the thread's flat N offset (adaptor body is missing from
 // this rendering).
612  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
617 
618  const auto n_thread_data_on_block_idx =
619  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
620  make_multi_index(n_thread_data_on_block));
621 
622  // VGPR to LDS
623  auto c_thread_copy_vgpr_to_lds =
625  FloatCShuffle,
626  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
627  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
629  Sequence<CShuffleMXdlPerWavePerShuffle,
630  CShuffleNXdlPerWavePerShuffle,
631  I1,
632  I1,
633  M2,
634  I1,
635  M4,
636  I1>,
638  7,
639  1,
641  1,
642  true>{
643  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
645  0,
646  m_thread_data_on_block_idx[I1],
647  n_thread_data_on_block_idx[I1],
648  m_thread_data_on_block_idx[I2],
649  m_thread_data_on_block_idx[I3],
650  m_thread_data_on_block_idx[I4],
651  n_thread_data_on_block_idx[I2]),
653 
654  // LDS to global
655  auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
656  ThisThreadBlock, // ThreadGroup
657  CElementwiseOperation, // ElementwiseOperation,
658  CGlobalMemoryDataOperation, // DstInMemOp,
659  Sequence<1,
660  CShuffleMXdlPerWavePerShuffle,
661  MWave * MPerXdl,
662  1,
663  CShuffleNXdlPerWavePerShuffle,
664  NWave * NPerXdl>, // BlockSliceLengths,
665  CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
666  Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
667  FloatCShuffle, // typename SrcData,
668  FloatC, // typename DstData,
669  decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
670  decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
671  Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
672  5, // index_t VectorDim,
673  CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
674  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
675  false> // bool ThreadTransferDstResetCoordinateAfterRun>
676  {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
677  make_multi_index(0, 0, 0, 0, 0, 0),
678  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
679  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
680  c_element_op};
681 
 // Destination-window steps: +M per shuffle, and +/-N per shuffle for the two sweep
 // directions used below.
682  constexpr auto mxdlperwave_forward_step =
683  make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0);
684  constexpr auto nxdlperwave_forward_step =
685  make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0);
686  constexpr auto nxdlperwave_backward_step =
687  make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0);
688 
689  static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) {
690  constexpr auto mxdlperwave = mxdlperwave_iter;
691 
 // The N sweep direction alternates with each M step: forward when
 // mxdlperwave is a multiple of 2 * CShuffleMXdlPerWavePerShuffle, backward
 // otherwise, so the destination window only ever moves by a single step.
692  static_for<0,
693  NXdlPerWave,
694  CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) {
695  constexpr bool nxdlperwave_forward_sweep =
696  (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0);
697 
698  constexpr index_t nxdlperwave_value =
699  nxdlperwave_forward_sweep
700  ? nxdlperwave_iter
701  : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle);
702 
703  constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
704 
705  // make sure it's safe to do ds_write
706  block_sync_lds();
707 
708  // VGPR to LDS
709  c_thread_copy_vgpr_to_lds.Run(
710  c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
711  make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
712  c_thread_buf,
713  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
714  c_shuffle_block_buf);
715 
716  // make sure it's safe to do ds_read
717  block_sync_lds();
718 
719  // LDS to global
720  c_block_copy_lds_to_global.Run(
721  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
722  c_shuffle_block_buf,
723  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
724  c_grid_buf);
725 
726  // move on nxdlperwave dimension
 // No move is issued after the last N position of a sweep.
727  if constexpr(nxdlperwave_forward_sweep &&
728  (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle))
729  {
730  c_block_copy_lds_to_global.MoveDstSliceWindow(
731  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
732  nxdlperwave_forward_step);
733  }
734  else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
735  {
736  c_block_copy_lds_to_global.MoveDstSliceWindow(
737  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
738  nxdlperwave_backward_step);
739  }
740  });
741 
742  // move on mxdlperwave dimension
743  if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle)
744  {
745  c_block_copy_lds_to_global.MoveDstSliceWindow(
746  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
747  mxdlperwave_forward_step);
748  }
749  });
750  }
751  }
752 };
753 
754 } // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:30
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: ck.hpp:268
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition: ck.hpp:277
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition: tensor_descriptor_helper.hpp:132
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:299
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:18
__global__ void kernel_gemm_xdlops_v3r1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r1.hpp:36
Definition: block_to_ctile_map.hpp:261
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
Definition: gridwise_gemm_xdlops_v3r1.hpp:127
static __device__ void Run(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, void *__restrict__ p_shared, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r1.hpp:346
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdlops_v3r1.hpp:296
__host__ static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_xdlops_v3r1.hpp:170
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdlops_v3r1.hpp:143
static constexpr auto I6
Definition: gridwise_gemm_xdlops_v3r1.hpp:134
__host__ static constexpr __device__ auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition: gridwise_gemm_xdlops_v3r1.hpp:330
static constexpr auto I5
Definition: gridwise_gemm_xdlops_v3r1.hpp:133
static constexpr auto I7
Definition: gridwise_gemm_xdlops_v3r1.hpp:135
static __device__ constexpr bool IsValidCompilationParameter()
Definition: gridwise_gemm_xdlops_v3r1.hpp:239
static constexpr auto I4
Definition: gridwise_gemm_xdlops_v3r1.hpp:132
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r1.hpp:256
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVer, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition: gridwise_gemm_xdlops_v3r1.hpp:146
static constexpr auto I0
Definition: gridwise_gemm_xdlops_v3r1.hpp:128
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
Definition: gridwise_gemm_xdlops_v3r1.hpp:339
static constexpr auto AK0
Definition: gridwise_gemm_xdlops_v3r1.hpp:138
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_xdlops_v3r1.hpp:304
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition: gridwise_gemm_xdlops_v3r1.hpp:342
__host__ static constexpr __device__ auto GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl()
Definition: gridwise_gemm_xdlops_v3r1.hpp:193
static constexpr auto BK0
Definition: gridwise_gemm_xdlops_v3r1.hpp:139
__host__ static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_xdlops_v3r1.hpp:148
static constexpr auto BK1
Definition: gridwise_gemm_xdlops_v3r1.hpp:141
static constexpr auto I1
Definition: gridwise_gemm_xdlops_v3r1.hpp:129
static constexpr auto I3
Definition: gridwise_gemm_xdlops_v3r1.hpp:131
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_v3r1.hpp:211
static constexpr auto AK1
Definition: gridwise_gemm_xdlops_v3r1.hpp:140
static constexpr auto I2
Definition: gridwise_gemm_xdlops_v3r1.hpp:130
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition: xdlops_gemm.hpp:1208
Definition: sequence.hpp:43
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:334