gridwise_gemm_xdl_waveletmodel_cshuffle.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
20 template <typename ABDataType,
21  typename FloatGemmAcc,
22  typename EDataTypeShuffle,
23  typename EDataType,
24  typename AElementwiseOperation,
25  typename BElementwiseOperation,
26  typename EElementwiseOperation,
27  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
28  typename AGridDesc_M_K,
29  typename BGridDesc_N_K,
30  typename EGridDesc_M_N,
31  index_t NumGemmKPrefetchStage,
32  index_t TileLoadThreadGroupSize,
33  index_t TileMathThreadGroupSize,
34  index_t MPerBlock,
35  index_t NPerBlock,
36  index_t KPerBlock,
37  index_t AK1Value,
38  index_t BK1Value,
39  index_t MPerXdl,
40  index_t NPerXdl,
41  index_t MXdlPerWave,
42  index_t NXdlPerWave,
43  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
44  typename ABlockTransferThreadClusterArrangeOrder,
45  typename ABlockTransferSrcAccessOrder,
46  index_t ABlockTransferSrcVectorDim,
47  index_t ABlockTransferSrcScalarPerVector,
48  index_t ABlockTransferDstScalarPerVector_AK1,
49  bool AThreadTransferSrcResetCoordinateAfterRun,
50  index_t ABlockLdsExtraM,
51  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
52  typename BBlockTransferThreadClusterArrangeOrder,
53  typename BBlockTransferSrcAccessOrder,
54  index_t BBlockTransferSrcVectorDim,
55  index_t BBlockTransferSrcScalarPerVector,
56  index_t BBlockTransferDstScalarPerVector_BK1,
57  bool BThreadTransferSrcResetCoordinateAfterRun,
58  index_t BBlockLdsExtraN,
59  index_t CShuffleMXdlPerWavePerShuffle,
60  index_t CShuffleNXdlPerWavePerShuffle,
61  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
62  index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
64 {
65 
66  static constexpr auto I0 = Number<0>{};
67  static constexpr auto I1 = Number<1>{};
68  static constexpr auto I2 = Number<2>{};
69  static constexpr auto I3 = Number<3>{};
70  static constexpr auto I4 = Number<4>{};
71  static constexpr auto I5 = Number<5>{};
72  static constexpr auto I6 = Number<6>{};
73  static constexpr auto I7 = Number<7>{};
74 
75  // K1 should be Number<...>
76  static constexpr auto AK1 = Number<AK1Value>{};
77  static constexpr auto BK1 = Number<BK1Value>{};
78  static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
79  static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
80  static constexpr auto BlockSize = math::max(TileLoadThreadGroupSize, TileMathThreadGroupSize);
81 
83  {
84  __device__ static constexpr index_t GetNumOfThread() { return TileLoadThreadGroupSize; }
85 
86  __device__ static constexpr bool IsBelong()
87  {
88  return (get_thread_local_1d_id() >= TileLoadThreadGroupSize);
89  }
90 
91  __device__ static index_t GetThreadId()
92  {
93  return get_thread_local_1d_id() - TileMathThreadGroupSize;
94  }
95  };
96 
98  {
99  __device__ static constexpr index_t GetNumOfThread() { return TileMathThreadGroupSize; }
100 
101  __device__ static constexpr bool IsBelong()
102  {
103  return get_thread_local_1d_id() < TileMathThreadGroupSize;
104  }
105 
106  __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
107  };
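// Illustrative sketch: the thread block is split into a load-wave group and a
// math-wave group; each side of the wavelet pipeline branches on its group's
// IsBelong() and uses GetThreadId() as its local id. The names LoadWaveGroup,
// MathWaveGroup, run_load_pipeline and run_math_pipeline below are hypothetical
// placeholders:
//
//   if(LoadWaveGroup::IsBelong())        // producer waves: global -> LDS copies
//       run_load_pipeline(LoadWaveGroup::GetThreadId());
//   else if(MathWaveGroup::IsBelong())   // consumer waves: LDS -> MFMA + C shuffle
//       run_math_pipeline(MathWaveGroup::GetThreadId());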
108 
 109  using CShuffleBlockTransferThreadGroup = ThisThreadBlock<TileMathThreadGroupSize>;
 110 
111  // load and math+store Wave pipelines.
 112  // TODO: build pipeline blocks that schedule parallel tasks
115 
116  __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
117  {
118  // A matrix in LDS memory, dst of blockwise copy
122  }
123 
124  __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
125  {
126  // B matrix in LDS memory, dst of blockwise copy
130  }
131 
132  __host__ __device__ static constexpr auto
134  {
135  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
136  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
137 
 138  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 139  make_naive_tensor_descriptor_packed(
 140  make_tuple(I1,
 141  Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
 142  I1,
 143  Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
 144 
145  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
146  }
147 
148  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
149  {
150  // LDS allocation for A and B: be careful of alignment
151  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
152  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
153 
154  // lds max alignment
155  constexpr auto max_lds_align = math::lcm(AK1, BK1);
156 
157  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
158  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
159 
160  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
161  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
162 
 163  // LDS allocation for C shuffle
164  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 165  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
 166 
167  constexpr auto c_block_size =
168  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
169 
170  return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
171  sizeof(ABDataType),
172  c_block_size * sizeof(EDataTypeShuffle));
173  }
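// Worked example (hypothetical tile configuration; LDS padding from ABlockLdsExtraM
// and BBlockLdsExtraN ignored): with fp16 A/B, MPerBlock = NPerBlock = 128,
// KPerBlock = 32 and AK1 = BK1 = 8:
//   a_block elements = (32 / 8) * 128 * 8 = 4096, and likewise for b_block;
//   A + B LDS bytes  = (4096 + 4096) * sizeof(half) = 16384;
//   the returned size is max(16384, c_shuffle_tile_bytes), since the C-shuffle tile
//   reuses the same LDS once the K loop has finished.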
174 
175  template <
176  InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
177  __device__ static bool constexpr IsValidCompilationParameter()
178  {
179  return ck::tensor_operation::device::IsValidGemmCompilationParameter<
180  BlockSize,
181  MPerBlock,
182  NPerBlock,
183  MPerXdl,
184  NPerXdl,
185  MXdlPerWave,
186  NXdlPerWave,
187  EDataType,
188  CGlobalMemoryDataOperation>();
189  }
190 
 191  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
192  template <typename Block2ETileMap>
193  __host__ __device__ static constexpr bool
194  CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
195  const BGridDesc_N_K& b_grid_desc_n_k,
196  const EGridDesc_M_N& e_grid_desc_m_n,
197  const Block2ETileMap& /*block_2_etile_map*/)
198  {
199  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
200  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
201  "Invalid tuning param!");
202 
203  const auto M = a_grid_desc_m_k.GetLength(I0);
204  const auto N = b_grid_desc_n_k.GetLength(I0);
205  const auto K = a_grid_desc_m_k.GetLength(I1);
206 
207  // check consistency of desc
208  if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
209  K == b_grid_desc_n_k.GetLength(I1)))
210  {
211  return false;
212  }
213 
214  // check tile size
215  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
216  {
217  return false;
218  }
219 
220  // check gridwise gemm pipeline
221  const auto num_k_loop = K / KPerBlock;
222 
223  if(!GridwiseGemmMath::IsSupported(num_k_loop))
224  {
225  return false;
226  }
227 
228  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
229 
230  // check tensor size: cannot be larger than 2GB each
231  constexpr long_index_t TwoGB = (long_index_t{1} << 31);
232 
233  if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
234  b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
235  e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
236  {
237  return false;
238  }
239 
240  return true;
241  }
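// Host-side usage sketch (illustrative only; descriptor construction elided): a kernel
// instance would typically be rejected before launch when CheckValidity fails, e.g.
//   const auto block_2_etile_map = MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
//   if(!CheckValidity(a_grid_desc_m_k, b_grid_desc_n_k, e_grid_desc_m_n, block_2_etile_map))
//       return false; // fall back to another kernel configuration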
242 
243  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
244  {
245  const index_t num_loop = K / KPerBlock;
246 
247  return GridwiseGemmMath::CalculateHasMainLoop(num_loop);
248  }
249 
250  // return block_id to E matrix tile idx (m0, n0) mapping
251  __host__ __device__ static constexpr auto
252  MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
253  {
254  const auto M = e_grid_desc_m_n.GetLength(I0);
255  const auto N = e_grid_desc_m_n.GetLength(I1);
256 
257  constexpr auto M1 = Number<MPerBlock>{};
258  constexpr auto N1 = Number<NPerBlock>{};
259 
260  const auto M0 = M / M1;
261  const auto N0 = N / N1;
262 
263  constexpr auto M01 = I1;
264  constexpr auto N01 = I1;
265 
266  const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
272 
273  const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
275  make_tuple(make_merge_transform(make_tuple(M0, N0, M01, N01))),
278 
279  const auto cblockid_to_m0_n0_block_cluster_adaptor =
280  chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
281  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
282 
283  return cblockid_to_m0_n0_block_cluster_adaptor;
284  }
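// With M01 = N01 = 1 this default map is effectively a row-major tiling of E:
// block id b covers tile (m0, n0) = (b / N0, b % N0). For example, M = N = 1024 with
// MPerBlock = NPerBlock = 256 gives M0 = N0 = 4, so block 6 works on tile (1, 2).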
285 
286  __host__ __device__ static constexpr index_t
287  CalculateGridSize(const EGridDesc_M_N& e_grid_desc_m_n)
288  {
289  const auto M = e_grid_desc_m_n.GetLength(I0);
290  const auto N = e_grid_desc_m_n.GetLength(I1);
291 
292  const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
293 
294  return grid_size;
295  }
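// Example: M = N = 4096 with MPerBlock = NPerBlock = 256 yields
// (4096 / 256) * (4096 / 256) = 16 * 16 = 256 workgroups, one per E tile.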
296 
297  // A desc for source in blockwise copy
298  __host__ __device__ static constexpr auto
299  MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
300  {
301  const auto M = a_grid_desc_m_k.GetLength(I0);
302  const auto K = a_grid_desc_m_k.GetLength(I1);
303 
304  const auto AK0 = K / AK1;
305 
306  return transform_tensor_descriptor(a_grid_desc_m_k,
311  }
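// Example: A of shape (M, K) = (4096, 2048) with AK1 = 8 is re-viewed as a 3-d
// (AK0, M, AK1) = (256, 4096, 8) descriptor; the blockwise copy then advances its
// slice window by KPerBlock / AK1 along AK0 on every main-loop iteration.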
312 
313  // B desc for source in blockwise copy
314  __host__ __device__ static constexpr auto
315  MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
316  {
317  const auto N = b_grid_desc_n_k.GetLength(I0);
318  const auto K = b_grid_desc_n_k.GetLength(I1);
319 
320  const auto BK0 = K / BK1;
321 
322  return transform_tensor_descriptor(b_grid_desc_n_k,
327  }
328 
329  // E desc for destination in blockwise copy
330  template <typename EGridDescriptor_M_N>
331  __host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
332  const EGridDescriptor_M_N& e_grid_desc_m_n)
333  {
334  const auto M = e_grid_desc_m_n.GetLength(I0);
335  const auto N = e_grid_desc_m_n.GetLength(I1);
336 
337  const auto MBlock = M / MPerBlock;
338  const auto NBlock = N / NPerBlock;
339 
340  const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
341  e_grid_desc_m_n,
346 
347  return e_grid_desc_mblock_mperblock_nblock_nperblock;
348  }
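// Example: E of shape (M, N) = (4096, 4096) with 256 x 256 tiles becomes a
// (MBlock, MPerBlock, NBlock, NPerBlock) = (16, 256, 16, 256) view; each workgroup
// writes one (MPerBlock, NPerBlock) tile of this view during the C shuffle.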
349 
 350  using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
 351  remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
 352  EGridDesc_M_N{}))>;
353 
 354  using DefaultBlock2ETileMap =
 355  remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
356 
357  template <bool HasMainKBlockLoop,
358  typename AGridDesc_AK0_M_AK1,
359  typename BGridDesc_BK0_N_BK1,
360  typename Block2ETileMap>
361  __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
362  const ABDataType* __restrict__ p_b_grid,
363  EDataType* __restrict__ p_e_grid,
364  void* __restrict__ p_shared,
365  const AElementwiseOperation& a_element_op,
366  const BElementwiseOperation& b_element_op,
367  const EElementwiseOperation& e_element_op,
368  const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
369  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
 370  const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
 371  e_grid_desc_mblock_mperblock_nblock_nperblock,
372  const Block2ETileMap& block_2_etile_map)
373  {
 374  // build loadWave and MathWave pipelines;
 375  // loadWave and MathWave are synchronized through LDS
376 
377  // A matrix in LDS memory, dst of blockwise copy
378  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
379 
380  // B matrix in LDS memory, dst of blockwise copy
381  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
382 
383  // lds max alignment
384  constexpr auto max_lds_align = math::lcm(AK1, BK1);
385 
386  // LDS allocation for A and B: be careful of alignment
387  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
388  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
389 
390  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
391  static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
392 
393  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
394  static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
395  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
396 
397  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
398  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
399 
400  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
401  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
402  KPerBlock);
403 
404  // divide block work by [M, N]
405  const auto block_work_idx =
406  block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
407 
 408  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
409  const index_t m_block_data_idx_on_grid =
410  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
411 
412  const index_t n_block_data_idx_on_grid =
413  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
414 
416  {
417 
418  // LoadWave
419  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
420  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
421  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
422  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
423 
424  // A matrix blockwise copy
425  auto a_blockwise_copy =
427  AElementwiseOperation,
431  ABlockTransferThreadClusterLengths_AK0_M_AK1,
432  ABlockTransferThreadClusterArrangeOrder,
433  ABDataType,
434  ABDataType,
435  decltype(a_grid_desc_ak0_m_ak1),
436  decltype(a_block_desc_ak0_m_ak1),
437  ABlockTransferSrcAccessOrder,
439  ABlockTransferSrcVectorDim,
440  2,
441  ABlockTransferSrcScalarPerVector,
442  ABlockTransferDstScalarPerVector_AK1,
443  1,
444  1,
445  AThreadTransferSrcResetCoordinateAfterRun,
446  true,
447  NumGemmKPrefetchStage>(
448  a_grid_desc_ak0_m_ak1,
449  make_multi_index(0, m_block_data_idx_on_grid, 0),
450  a_element_op,
451  a_block_desc_ak0_m_ak1,
452  make_multi_index(0, 0, 0),
454 
455  // B matrix blockwise copy
456  auto b_blockwise_copy =
458  BElementwiseOperation,
462  BBlockTransferThreadClusterLengths_BK0_N_BK1,
463  BBlockTransferThreadClusterArrangeOrder,
464  ABDataType,
465  ABDataType,
466  decltype(b_grid_desc_bk0_n_bk1),
467  decltype(b_block_desc_bk0_n_bk1),
468  BBlockTransferSrcAccessOrder,
470  BBlockTransferSrcVectorDim,
471  2,
472  BBlockTransferSrcScalarPerVector,
473  BBlockTransferDstScalarPerVector_BK1,
474  1,
475  1,
476  BThreadTransferSrcResetCoordinateAfterRun,
477  true,
478  NumGemmKPrefetchStage>(
479  b_grid_desc_bk0_n_bk1,
480  make_multi_index(0, n_block_data_idx_on_grid, 0),
481  b_element_op,
482  b_block_desc_bk0_n_bk1,
483  make_multi_index(0, 0, 0),
485 
486  GridwiseGemmLoad::template RunLoadWavePipeline<HasMainKBlockLoop>(
487  a_grid_desc_ak0_m_ak1,
488  a_block_desc_ak0_m_ak1,
489  a_blockwise_copy,
490  a_grid_buf,
491  a_block_buf,
492  a_block_slice_copy_step,
493  b_grid_desc_bk0_n_bk1,
494  b_block_desc_bk0_n_bk1,
495  b_blockwise_copy,
496  b_grid_buf,
497  b_block_buf,
498  b_block_slice_copy_step,
499  num_k_block_main_loop);
500 
501  block_sync_lds();
502  block_sync_lds();
503  }
505  {
506  // branch early for math wave
507  constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
508  constexpr bool is_single_rate_mfma =
510  lcm_AK1_BK1 <= 4) ||
511  (is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
513  lcm_AK1_BK1 < 32))
514  ? true
515  : false;
516  constexpr auto is_scale_mfma = false;
517  constexpr index_t KPack =
518  math::max(lcm_AK1_BK1,
519  MfmaSelector<ABDataType,
520  MPerXdl,
521  NPerXdl,
522  ABDataType,
523  is_single_rate_mfma,
524  is_scale_mfma>::selected_mfma.k_per_blk);
525 
527  TileMathThreadGroupSize,
528  ABDataType,
529  ABDataType,
530  FloatGemmAcc,
531  decltype(a_block_desc_ak0_m_ak1),
532  decltype(b_block_desc_bk0_n_bk1),
533  MPerXdl,
534  NPerXdl,
535  MXdlPerWave,
536  NXdlPerWave,
537  KPack>{};
538 
539  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
540  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
541  p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
542 
543  // TODO re-architect LDS+math stages
 544  // Writing data to GMEM: only the math wave does the cshuffle work
545  GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(
546  a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop);
547 
548  // GEMM definition
549  // c_mtx += transpose(a_mtx) * b_mtx
550  // a_mtx[K0PerBlock, MPerBlock] is in LDS
551  // b_mtx[K0PerBlock, NPerBlock] is in LDS
552  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
553  // register
554  // sanity check
555 
556  // shuffle C and write out
557  {
558  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
559  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
560  "wrong!");
561 
562  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
563  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
564 
565  // TODO: hacky, fix it!
566  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
567  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
568 
569  // TODO: hacky, fix it!
570  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
571  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
572  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
573 
574  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
575  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
576  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
577  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
578  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
579  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
580  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
581  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
582 
583  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 584  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
 585 
586  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
587  static_cast<EDataTypeShuffle*>(p_shared),
588  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
589 
590  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
591  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
592  make_tuple(
595  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
596  M1, // M1 = MWave
597  M2, // M2 * M3 * M4 = MPerXdl
598  M3,
599  M4)),
602  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
603  N1, // N1 = NWave
604  N2))), // N2 = NPerXdl
608  Sequence<>{},
609  Sequence<1, 3, 7>{}));
610 
611  // calculate origin of thread output tensor on global memory
612  // blockwise GEMM c matrix starting index
613  const auto c_thread_mtx_on_block =
614  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
615 
616  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
617  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
618 
619  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
621  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
624 
625  const auto m_thread_data_on_block_idx =
626  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
627  make_multi_index(m_thread_data_on_block));
628 
629  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
634 
635  const auto n_thread_data_on_block_idx =
636  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
637  make_multi_index(n_thread_data_on_block));
638 
639  // shuffle: threadwise copy C from VGPR to LDS
640  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
641  FloatGemmAcc,
642  EDataTypeShuffle,
643  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
644  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
646  Sequence<CShuffleMXdlPerWavePerShuffle,
647  CShuffleNXdlPerWavePerShuffle,
648  I1,
649  I1,
650  M2,
651  I1,
652  M4,
653  I1>,
655  7,
656  1,
658  1,
659  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
661  0,
662  m_thread_data_on_block_idx[I1],
663  n_thread_data_on_block_idx[I1],
664  m_thread_data_on_block_idx[I2],
665  m_thread_data_on_block_idx[I3],
666  m_thread_data_on_block_idx[I4],
667  n_thread_data_on_block_idx[I2]),
669 
670  // shuffle: blockwise copy C from LDS to global
671  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
672  CShuffleBlockTransferThreadGroup, // ThreadGroup
673  EElementwiseOperation, // ElementwiseOperation,
674  CGlobalMemoryDataOperation, // DstInMemOp,
675  Sequence<1,
676  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
677  1,
678  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
679  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
680  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
681  EDataTypeShuffle, // typename SrcData,
682  EDataType, // typename DstData,
683  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
684  decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
685  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
686  3, // index_t VectorDim,
687  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
688  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
689  false> // bool ThreadTransferDstResetCoordinateAfterRun>
690  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
691  make_multi_index(0, 0, 0, 0),
692  e_grid_desc_mblock_mperblock_nblock_nperblock,
693  make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
694  e_element_op};
695 
696  // space filling curve for threadwise C in VGPR
697  constexpr auto sfc_c_vgpr =
700  Sequence<CShuffleMXdlPerWavePerShuffle,
701  CShuffleNXdlPerWavePerShuffle,
702  1,
703  1,
704  M2,
705  1,
706  M4,
707  1>>{};
708 
709  // space filling curve for shuffled blockwise C in global mem
710  constexpr auto sfc_c_global =
713  Sequence<1,
714  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
715  1,
716  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
717 
718  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
719 
720  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
721 
 722  // A different way of getting coalesced writes:
 723  // we could get rid of the cshuffle: instead of reading A rows in a contiguous manner,
 724  // read them interleaved; then mfma can produce a nice c-mat layout, as below:
725  //
726  // TODO
 727  // We would not need an LDS swizzle to align global writes to cache lines:
 728  // v_mfma cmat, amat, bmat, cmat - c-mat register layout is 1xN
 729  // elements (N is the vertical, or strided,
 730  // dimension)
 731  // v_mfma cmat, bmat, amat, cmat - c-mat register layout is Mx1
 732  // elements (M is the coalescing
 733  // dimension); by enumerating the M index in
 734  // amat and bmat you can align cmat
 735  // register(s) to contiguous M elements,
 736  // for example:
 737  // 1st mfma instruction output space : 0 4 8 12 16 ....
 738  // 2nd mfma instruction output space : 1 5 9 13 17 ....
 739  // 3rd mfma instruction output space : 2 6 10 14 18 ....
 740  // 4th mfma instruction output space : 3 7 11 15 19 ....
 741  // you can pack the 4 registers' output space into 2WORD and do a global write
 742  // (no LDS swizzling required)
743 
744  static_for<0, num_access, 1>{}([&](auto access_id) {
745  // make sure it's safe to write to LDS
746  block_sync_lds();
747 
748  // each thread write its data from VGPR to LDS
749  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
750  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
751  c_thread_buf,
752  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
753  c_shuffle_block_buf);
754  // make sure it's safe to read from LDS
755  block_sync_lds();
756 
757  // each block copy its data from LDS to global
758  c_shuffle_block_copy_lds_to_global.Run(
759  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
760  c_shuffle_block_buf,
761  e_grid_desc_mblock_mperblock_nblock_nperblock,
762  c_grid_buf);
763 
764  if constexpr(access_id < num_access - 1)
765  {
766  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
767 
768  // move on C
769  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
770  e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
771  }
772  });
773  }
774  }
775  }
776 };
777 
778 } // namespace ck
__host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: ck.hpp:268
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition: ck.hpp:277
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
int64_t long_index_t
Definition: ck.hpp:300
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:299
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
__host__ constexpr __device__ auto chain_tensor_adaptors(const TensorAdaptor0 &adaptor0, const TensorAdaptor1 &adaptor1)
Definition: tensor_adaptor.hpp:245
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_smfmac_xdlops.hpp:78
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:83
static __device__ index_t GetThreadId()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:91
static constexpr __device__ bool IsBelong()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:86
static constexpr __device__ index_t GetNumOfThread()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:84
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:98
static __device__ index_t GetThreadId()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:106
static constexpr __device__ index_t GetNumOfThread()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:99
static constexpr __device__ bool IsBelong()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:101
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:64
static constexpr auto BlockSize
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:80
__host__ static constexpr __device__ auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:252
__host__ static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:133
static constexpr auto I3
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:69
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:148
static constexpr auto AK0PerBlock
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:78
__host__ static constexpr __device__ auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDescriptor_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:331
__host__ static constexpr __device__ auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:315
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const EGridDesc_M_N &e_grid_desc_m_n, const Block2ETileMap &)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:194
static constexpr auto BK0PerBlock
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:79
static constexpr auto I7
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:73
__host__ static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:116
static constexpr auto I0
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:66
__host__ static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:124
ThisThreadBlock< TileMathThreadGroupSize > CShuffleBlockTransferThreadGroup
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:109
remove_cvref_t< decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> DefaultBlock2ETileMap
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:355
static __device__ void Run(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const EElementwiseOperation &e_element_op, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap &block_2_etile_map)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:361
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:243
static constexpr auto AK1
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:76
static constexpr auto I2
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:68
__host__ static constexpr __device__ index_t CalculateGridSize(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:287
remove_cvref_t< decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:352
static constexpr auto I5
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:71
static constexpr auto I6
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:72
static constexpr auto I4
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:70
static constexpr auto BK1
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:77
static constexpr auto I1
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:67
static __device__ constexpr bool IsValidCompilationParameter()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:177
__host__ static constexpr __device__ auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:299
Definition: gridwise_gemm_waveletmodel.hpp:11
Definition: gridwise_gemm_waveletmodel.hpp:103
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition: xdlops_gemm.hpp:1208
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Definition: thread_group.hpp:12
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:334