gridwise_gemm_xdl_waveletmodel_cshuffle.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
20 template <typename ABDataType,
21  typename FloatGemmAcc,
22  typename EDataTypeShuffle,
23  typename EDataType,
24  typename AElementwiseOperation,
25  typename BElementwiseOperation,
26  typename EElementwiseOperation,
27  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
28  typename AGridDesc_M_K,
29  typename BGridDesc_N_K,
30  typename EGridDesc_M_N,
31  index_t NumGemmKPrefetchStage,
32  index_t TileLoadThreadGroupSize,
33  index_t TileMathThreadGroupSize,
34  index_t MPerBlock,
35  index_t NPerBlock,
36  index_t KPerBlock,
37  index_t AK1Value,
38  index_t BK1Value,
39  index_t MPerXdl,
40  index_t NPerXdl,
41  index_t MXdlPerWave,
42  index_t NXdlPerWave,
43  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
44  typename ABlockTransferThreadClusterArrangeOrder,
45  typename ABlockTransferSrcAccessOrder,
46  index_t ABlockTransferSrcVectorDim,
47  index_t ABlockTransferSrcScalarPerVector,
48  index_t ABlockTransferDstScalarPerVector_AK1,
49  bool AThreadTransferSrcResetCoordinateAfterRun,
50  index_t ABlockLdsExtraM,
51  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
52  typename BBlockTransferThreadClusterArrangeOrder,
53  typename BBlockTransferSrcAccessOrder,
54  index_t BBlockTransferSrcVectorDim,
55  index_t BBlockTransferSrcScalarPerVector,
56  index_t BBlockTransferDstScalarPerVector_BK1,
57  bool BThreadTransferSrcResetCoordinateAfterRun,
58  index_t BBlockLdsExtraN,
59  index_t CShuffleMXdlPerWavePerShuffle,
60  index_t CShuffleNXdlPerWavePerShuffle,
61  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
62  index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
64 {
65 
66  static constexpr auto I0 = Number<0>{};
67  static constexpr auto I1 = Number<1>{};
68  static constexpr auto I2 = Number<2>{};
69  static constexpr auto I3 = Number<3>{};
70  static constexpr auto I4 = Number<4>{};
71  static constexpr auto I5 = Number<5>{};
72  static constexpr auto I6 = Number<6>{};
73  static constexpr auto I7 = Number<7>{};
74 
75  // K1 should be Number<...>
76  static constexpr auto AK1 = Number<AK1Value>{};
77  static constexpr auto BK1 = Number<BK1Value>{};
78  static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
79  static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
80 
82  {
83  __device__ static constexpr index_t GetNumOfThread() { return TileLoadThreadGroupSize; }
84 
85  __device__ static constexpr bool IsBelong()
86  {
87  return (get_thread_local_1d_id() >= TileLoadThreadGroupSize);
88  }
89 
90  __device__ static index_t GetThreadId()
91  {
92  return get_thread_local_1d_id() - TileMathThreadGroupSize;
93  }
94  };
95 
97  {
98  __device__ static constexpr index_t GetNumOfThread() { return TileMathThreadGroupSize; }
99 
100  __device__ static constexpr bool IsBelong()
101  {
102  return get_thread_local_1d_id() < TileMathThreadGroupSize;
103  }
104 
105  __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
106  };
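 // The two groups above partition the thread block into a math wave and a load
 // wave that communicate through LDS. For illustration only (the sizes are
 // template parameters, not fixed by this file): with TileMathThreadGroupSize ==
 // 256 and TileLoadThreadGroupSize == 256, a 512-thread block is split so that
 // threads 0..255 form the math group (GetThreadId() == tid) and threads
 // 256..511 form the load group (GetThreadId() == tid - 256).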
107 
109 
110  // LoadWave and MathWave (math+store) pipelines.
111  // TODO: build pipeline blocks that schedule the parallel tasks
114 
115  __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
116  {
117  // A matrix in LDS memory, dst of blockwise copy
121  }
122 
123  __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
124  {
125  // B matrix in LDS memory, dst of blockwise copy
129  }
130 
131  __host__ __device__ static constexpr auto
133  {
134  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
135  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
136 
137  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
139  make_tuple(I1,
141  I1,
143 
144  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
145  }
146 
147  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
148  {
149  // LDS allocation for A and B: be careful of alignment
150  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
151  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
152 
153  // lds max alignment
154  constexpr auto max_lds_align = math::lcm(AK1, BK1);
155 
156  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
157  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
158 
159  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
160  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
161 
162  // LDS allocation for C shuffle
163  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
165 
166  constexpr auto c_block_size =
167  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
168 
169  return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
170  sizeof(ABDataType),
171  c_block_size * sizeof(EDataTypeShuffle));
172  }
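 // Worked example (illustrative numbers only; padding from ABlockLdsExtraM /
 // BBlockLdsExtraN ignored): with MPerBlock = NPerBlock = 128, KPerBlock = 32,
 // AK1 = BK1 = 8 and a 16-bit ABDataType, the A tile holds roughly
 // (32/8) * 128 * 8 = 4096 elements and the B tile the same, so the A+B stage
 // needs about (4096 + 4096) * 2 bytes = 16 KiB of LDS. The C-shuffle stage
 // reuses the same LDS, which is why the function returns the maximum of the
 // two stage sizes rather than their sum.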
173 
174  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
175  template <typename Block2ETileMap>
176  __host__ __device__ static constexpr bool
177  CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
178  const BGridDesc_N_K& b_grid_desc_n_k,
179  const EGridDesc_M_N& e_grid_desc_m_n,
180  const Block2ETileMap& /*block_2_etile_map*/)
181  {
182  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
183  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
184  "Invalid tuning param!");
185 
186  const auto M = a_grid_desc_m_k.GetLength(I0);
187  const auto N = b_grid_desc_n_k.GetLength(I0);
188  const auto K = a_grid_desc_m_k.GetLength(I1);
189 
190  // check consistency of desc
191  if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
192  K == b_grid_desc_n_k.GetLength(I1)))
193  {
194  return false;
195  }
196 
197  // check tile size
198  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
199  {
200  return false;
201  }
202 
203  // check gridwise gemm pipeline
204  const auto num_k_loop = K / KPerBlock;
205 
206  if(!GridwiseGemmMath::IsSupported(num_k_loop))
207  {
208  return false;
209  }
210 
211  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
212 
213  // check tensor size: cannot be larger than 2GB each
214  constexpr long_index_t TwoGB = (long_index_t{1} << 31);
215 
216  if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
217  b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
218  e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
219  {
220  return false;
221  }
222 
223  return true;
224  }
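 // Example of the 2 GB limit above: TwoGB is 2^31 bytes, so a 16-bit A tensor
 // may occupy at most 2^30 elements; an fp16 problem with M = K = 32768
 // (M * K = 2^30 elements, exactly 2 GiB) is still accepted (the check is <=),
 // and anything larger is rejected.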
225 
226  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
227  {
228  const index_t num_loop = K / KPerBlock;
229 
230  return GridwiseGemmMath::CalculateHasMainLoop(num_loop);
231  }
232 
233  // return block_id to E matrix tile idx (m0, n0) mapping
234  __host__ __device__ static constexpr auto
235  MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
236  {
237  const auto M = e_grid_desc_m_n.GetLength(I0);
238  const auto N = e_grid_desc_m_n.GetLength(I1);
239 
240  constexpr auto M1 = Number<MPerBlock>{};
241  constexpr auto N1 = Number<NPerBlock>{};
242 
243  const auto M0 = M / M1;
244  const auto N0 = N / N1;
245 
246  constexpr auto M01 = I1;
247  constexpr auto N01 = I1;
248 
249  const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
255 
256  const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
258  make_tuple(make_merge_transform(make_tuple(M0, N0, M01, N01))),
261 
262  const auto cblockid_to_m0_n0_block_cluster_adaptor =
263  chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
264  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
265 
266  return cblockid_to_m0_n0_block_cluster_adaptor;
267  }
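 // With the default M01 == N01 == 1 used above, the chained adaptors reduce to
 // a plain tiling of the E matrix: block b is expected to map to tile
 // (m0, n0) = (b / N0, b % N0), where M0 = M / MPerBlock and N0 = N / NPerBlock.
 // (Illustration only; the exact traversal order is defined by the merge /
 // unmerge transforms above.)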
268 
269  __host__ __device__ static constexpr index_t
270  CalculateGridSize(const EGridDesc_M_N& e_grid_desc_m_n)
271  {
272  const auto M = e_grid_desc_m_n.GetLength(I0);
273  const auto N = e_grid_desc_m_n.GetLength(I1);
274 
275  const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
276 
277  return grid_size;
278  }
279 
280  // A desc for source in blockwise copy
281  __host__ __device__ static constexpr auto
282  MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
283  {
284  const auto M = a_grid_desc_m_k.GetLength(I0);
285  const auto K = a_grid_desc_m_k.GetLength(I1);
286 
287  const auto AK0 = K / AK1;
288 
289  return transform_tensor_descriptor(a_grid_desc_m_k,
294  }
295 
296  // B desc for source in blockwise copy
297  __host__ __device__ static constexpr auto
298  MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
299  {
300  const auto N = b_grid_desc_n_k.GetLength(I0);
301  const auto K = b_grid_desc_n_k.GetLength(I1);
302 
303  const auto BK0 = K / BK1;
304 
305  return transform_tensor_descriptor(b_grid_desc_n_k,
310  }
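 // Both helpers split the K dimension of the source descriptors into (K0, K1)
 // for vectorized blockwise copies. Illustration: for K = 256 and AK1 = 8, A is
 // viewed as [AK0 = 32, M, AK1 = 8]; likewise for B with BK1.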
311 
312  // E desc for destination in blockwise copy
313  template <typename EGridDescriptor_M_N>
314  __host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
315  const EGridDescriptor_M_N& e_grid_desc_m_n)
316  {
317  const auto M = e_grid_desc_m_n.GetLength(I0);
318  const auto N = e_grid_desc_m_n.GetLength(I1);
319 
320  const auto MBlock = M / MPerBlock;
321  const auto NBlock = N / NPerBlock;
322 
323  const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
324  e_grid_desc_m_n,
329 
330  return e_grid_desc_mblock_mperblock_nblock_nperblock;
331  }
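 // Illustration: for M = N = 1024 with MPerBlock = NPerBlock = 128, the E
 // descriptor is viewed as [MBlock = 8, MPerBlock = 128, NBlock = 8,
 // NPerBlock = 128], so each workgroup addresses one (MBlock, NBlock) tile.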
332 
335  EGridDesc_M_N{}))>;
336 
338  remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
339 
340  template <bool HasMainKBlockLoop,
341  typename AGridDesc_AK0_M_AK1,
342  typename BGridDesc_BK0_N_BK1,
343  typename Block2ETileMap>
344  __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
345  const ABDataType* __restrict__ p_b_grid,
346  EDataType* __restrict__ p_e_grid,
347  void* __restrict__ p_shared,
348  const AElementwiseOperation& a_element_op,
349  const BElementwiseOperation& b_element_op,
350  const EElementwiseOperation& e_element_op,
351  const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
352  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
354  e_grid_desc_mblock_mperblock_nblock_nperblock,
355  const Block2ETileMap& block_2_etile_map)
356  {
357  // build LoadWave and MathWave pipelines
358  // LoadWave and MathWave are synchronized through LDS
359 
360  // A matrix in LDS memory, dst of blockwise copy
361  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
362 
363  // B matrix in LDS memory, dst of blockwise copy
364  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
365 
366  // lds max alignment
367  constexpr auto max_lds_align = math::lcm(AK1, BK1);
368 
369  // LDS allocation for A and B: be careful of alignment
370  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
371  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
372 
373  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
374  static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
375 
376  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
377  static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
378  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
379 
380  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
381  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
382 
383  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
384  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
385  KPerBlock);
386 
387  // divide block work by [M, N]
388  const auto block_work_idx =
389  block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
390 
391  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
392  const index_t m_block_data_idx_on_grid =
393  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
394 
395  const index_t n_block_data_idx_on_grid =
396  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
397 
399  {
400 
401  // LoadWave
402  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
403  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
404  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
405  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
406 
407  // A matrix blockwise copy
408  auto a_blockwise_copy =
410  AElementwiseOperation,
414  ABlockTransferThreadClusterLengths_AK0_M_AK1,
415  ABlockTransferThreadClusterArrangeOrder,
416  ABDataType,
417  ABDataType,
418  decltype(a_grid_desc_ak0_m_ak1),
419  decltype(a_block_desc_ak0_m_ak1),
420  ABlockTransferSrcAccessOrder,
422  ABlockTransferSrcVectorDim,
423  2,
424  ABlockTransferSrcScalarPerVector,
425  ABlockTransferDstScalarPerVector_AK1,
426  1,
427  1,
428  AThreadTransferSrcResetCoordinateAfterRun,
429  true,
430  NumGemmKPrefetchStage>(
431  a_grid_desc_ak0_m_ak1,
432  make_multi_index(0, m_block_data_idx_on_grid, 0),
433  a_element_op,
434  a_block_desc_ak0_m_ak1,
435  make_multi_index(0, 0, 0),
437 
438  // B matrix blockwise copy
439  auto b_blockwise_copy =
441  BElementwiseOperation,
445  BBlockTransferThreadClusterLengths_BK0_N_BK1,
446  BBlockTransferThreadClusterArrangeOrder,
447  ABDataType,
448  ABDataType,
449  decltype(b_grid_desc_bk0_n_bk1),
450  decltype(b_block_desc_bk0_n_bk1),
451  BBlockTransferSrcAccessOrder,
453  BBlockTransferSrcVectorDim,
454  2,
455  BBlockTransferSrcScalarPerVector,
456  BBlockTransferDstScalarPerVector_BK1,
457  1,
458  1,
459  BThreadTransferSrcResetCoordinateAfterRun,
460  true,
461  NumGemmKPrefetchStage>(
462  b_grid_desc_bk0_n_bk1,
463  make_multi_index(0, n_block_data_idx_on_grid, 0),
464  b_element_op,
465  b_block_desc_bk0_n_bk1,
466  make_multi_index(0, 0, 0),
468 
469  GridwiseGemmLoad::template RunLoadWavePipeline<HasMainKBlockLoop>(
470  a_grid_desc_ak0_m_ak1,
471  a_block_desc_ak0_m_ak1,
472  a_blockwise_copy,
473  a_grid_buf,
474  a_block_buf,
475  a_block_slice_copy_step,
476  b_grid_desc_bk0_n_bk1,
477  b_block_desc_bk0_n_bk1,
478  b_blockwise_copy,
479  b_grid_buf,
480  b_block_buf,
481  b_block_slice_copy_step,
482  num_k_block_main_loop);
483 
484  block_sync_lds();
485  block_sync_lds();
486  }
488  {
489  // branch early for math wave
490  constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
491  constexpr bool is_single_rate_mfma =
493  lcm_AK1_BK1 <= 4) ||
494  (is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
496  lcm_AK1_BK1 < 32))
497  ? true
498  : false;
499  constexpr auto is_scale_mfma = false;
500  constexpr index_t KPack =
501  math::max(lcm_AK1_BK1,
502  MfmaSelector<ABDataType,
503  MPerXdl,
504  NPerXdl,
505  ABDataType,
506  is_single_rate_mfma,
507  is_scale_mfma>::selected_mfma.k_per_blk);
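 // KPack is the K-granularity each mfma/xdlops step consumes from LDS: it must
 // cover both AK1 and BK1 (hence lcm_AK1_BK1) and be at least the k_per_blk of
 // the selected mfma instruction, which is why the maximum of the two is taken.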
508 
510  TileMathThreadGroupSize,
511  ABDataType,
512  ABDataType,
513  FloatGemmAcc,
514  decltype(a_block_desc_ak0_m_ak1),
515  decltype(b_block_desc_bk0_n_bk1),
516  MPerXdl,
517  NPerXdl,
518  MXdlPerWave,
519  NXdlPerWave,
520  KPack>{};
521 
522  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
523  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
524  p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
525 
526  // TODO: re-architect the LDS+math stages
527  // Writing data to GMEM: only the math wave does the work in the C-shuffle stage
528  GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(
529  a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop);
530 
531  // GEMM definition
532  // c_mtx += transpose(a_mtx) * b_mtx
533  // a_mtx[K0PerBlock, MPerBlock] is in LDS
534  // b_mtx[K0PerBlock, NPerBlock] is in LDS
535  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
536  // register
537  // sanity check
538 
539  // shuffle C and write out
540  {
541  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
542  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
543  "wrong!");
544 
545  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
546  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
547 
548  // TODO: hacky, fix it!
549  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
550  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
551 
552  // TODO: hacky, fix it!
553  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
554  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
555  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
556 
557  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
558  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
559  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
560  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
561  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
562  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
563  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
564  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
565 
566  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
568 
569  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
570  static_cast<EDataTypeShuffle*>(p_shared),
571  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
572 
573  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
574  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
575  make_tuple(
578  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
579  M1, // M1 = MWave
580  M2, // M2 * M3 * M4 = MPerXdl
581  M3,
582  M4)),
585  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
586  N1, // N1 = NWave
587  N2))), // N2 = NPerXdl
591  Sequence<>{},
592  Sequence<1, 3, 7>{}));
593 
594  // calculate origin of thread output tensor on global memory
595  // blockwise GEMM c matrix starting index
596  const auto c_thread_mtx_on_block =
597  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
598 
599  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
600  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
601 
602  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
604  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
607 
608  const auto m_thread_data_on_block_idx =
609  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
610  make_multi_index(m_thread_data_on_block));
611 
612  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
617 
618  const auto n_thread_data_on_block_idx =
619  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
620  make_multi_index(n_thread_data_on_block));
621 
622  // shuffle: threadwise copy C from VGPR to LDS
623  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
624  FloatGemmAcc,
625  EDataTypeShuffle,
626  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
627  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
629  Sequence<CShuffleMXdlPerWavePerShuffle,
630  CShuffleNXdlPerWavePerShuffle,
631  I1,
632  I1,
633  M2,
634  I1,
635  M4,
636  I1>,
638  7,
639  1,
641  1,
642  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
644  0,
645  m_thread_data_on_block_idx[I1],
646  n_thread_data_on_block_idx[I1],
647  m_thread_data_on_block_idx[I2],
648  m_thread_data_on_block_idx[I3],
649  m_thread_data_on_block_idx[I4],
650  n_thread_data_on_block_idx[I2]),
652 
653  // shuffle: blockwise copy C from LDS to global
654  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
655  CShuffleBlockTransferThreadGroup, // ThreadGroup
656  EElementwiseOperation, // ElementwiseOperation,
657  CGlobalMemoryDataOperation, // DstInMemOp,
658  Sequence<1,
659  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
660  1,
661  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
662  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
663  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
664  EDataTypeShuffle, // typename SrcData,
665  EDataType, // typename DstData,
666  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
667  decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
668  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
669  3, // index_t VectorDim,
670  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
671  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
672  false> // bool ThreadTransferDstResetCoordinateAfterRun>
673  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
674  make_multi_index(0, 0, 0, 0),
675  e_grid_desc_mblock_mperblock_nblock_nperblock,
676  make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
677  e_element_op};
678 
679  // space filling curve for threadwise C in VGPR
680  constexpr auto sfc_c_vgpr =
683  Sequence<CShuffleMXdlPerWavePerShuffle,
684  CShuffleNXdlPerWavePerShuffle,
685  1,
686  1,
687  M2,
688  1,
689  M4,
690  1>>{};
691 
692  // space filling curve for shuffled blockwise C in global mem
693  constexpr auto sfc_c_global =
696  Sequence<1,
697  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
698  1,
699  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
700 
701  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
702 
703  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
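 // Each access of the space-filling curves handles one
 // CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle sub-tile, so
 // num_access = (MXdlPerWave / CShuffleMXdlPerWavePerShuffle) *
 // (NXdlPerWave / CShuffleNXdlPerWavePerShuffle); e.g. 4x4 Xdl tiles shuffled
 // 1x2 at a time (illustrative values) give 8 VGPR->LDS->global passes.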
704 
705  // A different way of getting coalesced writes:
706  // We could drop the C-shuffle entirely. Instead of reading A rows in a contiguous
707  // manner, read them interleaved; the mfma then produces a convenient c-mat layout:
708  //
709  // TODO
710  // We would not need the LDS swizzle to align global writes to cache lines:
711  // v_mfma cmat, amat, bmat, cmat - c-mat register layout is 1xN
712  // elements (N is the vertical, i.e. strided,
713  // dimension)
714  // v_mfma cmat, bmat, amat, cmat - c-mat register layout is Mx1
715  // elements (M is the coalescing
716  // dimension); by enumerating the M index in
717  // amat and bmat you can align the cmat
718  // register(s) to contiguous M elements,
719  // for example:
720  // 1st mfma instruction output space : 0 4 8 12 16 ....
721  // 2nd mfma instruction output space : 1 5 9 13 17 ....
722  // 3rd mfma instruction output space : 2 6 10 14 18 ....
723  // 4th mfma instruction output space : 3 7 11 15 19 ....
724  // you can then pack the output space of 4 registers into 2 DWORDs and do the
725  // global write (no LDS swizzling required)
726 
727  static_for<0, num_access, 1>{}([&](auto access_id) {
728  // make sure it's safe to write to LDS
729  block_sync_lds();
730 
731  // each thread write its data from VGPR to LDS
732  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
733  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
734  c_thread_buf,
735  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
736  c_shuffle_block_buf);
737  // make sure it's safe to read from LDS
738  block_sync_lds();
739 
740  // each block copy its data from LDS to global
741  c_shuffle_block_copy_lds_to_global.Run(
742  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
743  c_shuffle_block_buf,
744  e_grid_desc_mblock_mperblock_nblock_nperblock,
745  c_grid_buf);
746 
747  if constexpr(access_id < num_access - 1)
748  {
749  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
750 
751  // move on C
752  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
753  e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
754  }
755  });
756  }
757  }
758  }
759 };
760 
761 } // namespace ck