gridwise_ab_transfer_thread_tiles.hpp
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

// ... (include directives elided in this listing)

namespace ck {

template <typename ABLayout,
          typename ABMajorLayout,
          typename LDSTypeAB,
          index_t BlockSize,
          index_t MNPerBlock,
          index_t KPerBlock,
          index_t MNPerWmma,
          index_t ABK1Value,
          index_t KPack,
          index_t KInner,
          index_t KPerWmmaBlk,
          bool UseBlockPaddingAB,
          bool PermuteAB,
          typename ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
          typename ABBlockTransferThreadClusterArrangeOrder,
          typename ABBlockTransferSrcAccessOrder,
          index_t ABBlockTransferSrcVectorDim,
          index_t ABBlockTransferSrcScalarPerVector,
          index_t ABBlockTransferDstScalarPerVector_ABK1,
          bool ABThreadTransferSrcResetCoordinateAfterRun>
struct GridwiseABTransferThreadTiles // NOTE: the struct name was elided in the listing; assumed from the file name
{
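    // This transfer policy stages the A/B block tile through LDS; IsLDSNeeded() below reports
    // that to the calling pipeline.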
    __device__ static constexpr bool IsLDSNeeded() { return true; }

    static constexpr auto ABK0Number = Number<KPerBlock / ABK1Value>{};
    static constexpr auto ABK1Number = Number<ABK1Value>{};

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr index_t ABPackedSize = []() {
        // Assumption (condition elided in the listing): packed int4 (pk_i4_t) stores two
        // elements per byte, hence a packed size of 2; all other element types use 1.
        if constexpr(is_same_v<remove_cvref_t<LDSTypeAB>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    template <bool PadMN, bool PadK, typename GridDescriptorBase>
    __host__ __device__ static auto MakeGridDescriptor(const GridDescriptorBase& ab_grid_desc,
                                                       index_t MN,
                                                       index_t MNPad,
                                                       index_t K,
                                                       index_t KPad,
                                                       index_t StrideAB,
                                                       index_t ABK0)
    {
        // NOTE: the index-mapping tuples and a few transform lines below were elided in this
        // listing; they are reconstructed following the usual CK (K0, MN, K1) descriptor pattern.
        if constexpr(PadMN && PadK)
        {
            // pad both MN and K
            const auto ab_grid_desc_n_k =
                transform_tensor_descriptor(ab_grid_desc,
                                            make_tuple(make_right_pad_transform(MN, MNPad - MN),
                                                       make_right_pad_transform(K, KPad - K)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                ab_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)),
                           make_pass_through_transform(MNPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return ab_grid_desc_bk0_n_bk1;
        }
        else if constexpr(PadMN && !PadK)
        {
            // pad MN, but not K
            const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                ab_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)),
                           make_right_pad_transform(MN, MNPad - MN)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return ab_grid_desc_bk0_n_bk1;
        }
        else if constexpr(!PadMN && PadK)
        {
            // pad K, but not MN
            const auto ab_grid_desc_n_k = transform_tensor_descriptor(
                ab_grid_desc,
                make_tuple(make_pass_through_transform(MN),
                           make_right_pad_transform(K, KPad - K)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                ab_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)),
                           make_pass_through_transform(MN)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return ab_grid_desc_bk0_n_bk1;
        }
        else
        {
            if constexpr(!PermuteAB)
            {
                // no padding on MN or K
                const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                    ab_grid_desc,
                    make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)),
                               make_pass_through_transform(MN)),
                    make_tuple(Sequence<1>{}, Sequence<0>{}),
                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

                return ab_grid_desc_bk0_n_bk1;
            }
            else
            {
                // Pre-shuffled weight:
                // BGlobal[K / KPerBlock, MN, KPerBlock / K1, K1] -> BTile[K / K1, MN, K1]
                constexpr index_t ABK01 = KPerBlock / ABK1Value;
                const index_t ABK0_ = StrideAB / ABK1Value;
                const index_t ABK00 = ABK0_ / ABK01;

                const auto ab_grid_desc_abk00_mn_abk01_abk1_permute =
                    make_naive_tensor_descriptor_packed(make_tuple(ABK00, MN, ABK01, ABK1Value));

                const auto ab_grid_desc_abk0_mn_abk1_permute = transform_tensor_descriptor(
                    ab_grid_desc_abk00_mn_abk01_abk1_permute,
                    make_tuple(make_merge_transform(make_tuple(ABK00, ABK01)),
                               make_pass_through_transform(MN),
                               make_pass_through_transform(ABK1Value)),
                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

                return ab_grid_desc_abk0_mn_abk1_permute;
            }
        }
    }
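    // Illustrative example (assumed values): for MN = 1000, MNPad = 1024, K = 500, KPad = 512,
    // ABK1Value = 8 and ABK0 = KPad / ABK1Value = 64, the PadMN && PadK path produces a
    // descriptor of shape (ABK0, MN, ABK1) = (64, 1024, 8).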

    __device__ static constexpr auto GetBlockDescriptor()
    {
        // A/B matrix in LDS memory, destination of the blockwise copy
        if constexpr(UseBlockPaddingAB)
        {
            // Writing the data into LDS causes bank conflicts, but the whole v4 main loop is
            // available to hide them; the padded layout may also save some VALU instructions
            // in address computation.
            // Assumption (the return statement was elided in this listing): the usual MN-padded
            // LDS descriptor of shape (ABK0, MNPerBlock, ABK1) with an ABK1-wide pad appended
            // to each ABK0 row.
            return make_naive_tensor_descriptor(
                make_tuple(ABK0Number, Number<MNPerBlock>{}, ABK1Number),
                make_tuple(ABK1Number * Number<MNPerBlock + 1>{}, ABK1Number, I1));
        }
        // The XOR tensor transformation requires extra, sometimes unnecessary, VGPRs and can
        // cause register spills in some cases.
        else if constexpr(is_same<ABMajorLayout, ABLayout>::value)
        {
            constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeAB) / ABPackedSize;
            constexpr auto MNLdsLayer = LdsSize < 1 ? 1 : LdsSize;
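            // Illustrative arithmetic (assumed values): with KPerBlock = 32, fp16 data
            // (sizeof(LDSTypeAB) = 2) and ABPackedSize = 1, LdsSize = 32 * 4 / 32 / 2 / 1 = 2,
            // so MNLdsLayer = 2 and two MN sub-blocks are interleaved to fill a 128-byte LDS row.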
            constexpr auto ab_lds_block_desc = make_naive_tensor_descriptor(
                // ...
                Number<MNPerBlock / MNLdsLayer>{},
                ABK1Number),
                // ...

            constexpr auto ab_lds_block_desc_permuted = transform_tensor_descriptor(
                ab_lds_block_desc,
                make_tuple(
                    // ...

            constexpr auto ab_lds_block_desc_abk0_mnldslayer_mn_abk1 = transform_tensor_descriptor(
                ab_lds_block_desc_permuted,
                // ...

            constexpr auto ab_lds_block_desc_abk0_mn_abk1 = transform_tensor_descriptor(
                ab_lds_block_desc_abk0_mnldslayer_mn_abk1,
                // ...

            return ab_lds_block_desc_abk0_mn_abk1;
        }
        else
        {
            // The kfold and mpair dimensions are not always required; more dimensions in
            // merge_transform make it harder for the compiler to generate immarg offsets.
            constexpr auto MN0 = ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1{}.At(I1);
            constexpr auto MN1 = MNPerBlock / MN0;

            constexpr auto KThreadWrite = ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1{}.At(I0);
            constexpr auto K0PerThreadWrite = ABK0Number / KThreadWrite;
            constexpr auto KThreadRead = 64 / MNPerWmma;
            constexpr auto K0PerThreadRead = ABK0Number / KThreadRead;

            constexpr auto kfold = (ABK1Number * MN0 * sizeof(LDSTypeAB) > 128)
                                       ? 1
                                       : 128 / (ABK1Number * MN0 * sizeof(LDSTypeAB));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1 <= mpair <= MN0
            constexpr auto mpair = (ABK1Number * MNPerWmma * sizeof(LDSTypeAB) > 128)
                                       ? 1
                                       : ((128 / (ABK1Number * MNPerWmma * sizeof(LDSTypeAB))) > MN0
                                              ? MN0
                                              : 128 / (ABK1Number * MNPerWmma * sizeof(LDSTypeAB)));

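            // Illustrative arithmetic (assumed values): for fp16 data (sizeof = 2),
            // ABK1Number = 8, MN0 = 4 and MNPerWmma = 16: ABK1Number * MN0 * sizeof = 64 <= 128,
            // so kfold = 128 / 64 = 2; ABK1Number * MNPerWmma * sizeof = 256 > 128, so mpair = 1.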
            constexpr auto ab_lds_block_desc = make_naive_tensor_descriptor_packed(
                // ...
                Number<kfold * MN0 / mpair>{},
                Number<mpair>{},
                ABK1Number));

            constexpr auto ab_lds_block_desc_permuted = transform_tensor_descriptor(
                ab_lds_block_desc,
                make_tuple(
                    // ...
                    make_tuple(Number<KThreadReadPerm * MN1>{}, Number<kfold * MN0 / mpair>{})),
                make_tuple(
                    // ...
                make_tuple(
                    // ...

            constexpr auto ab_lds_block_desc_unmerged = transform_tensor_descriptor(
                ab_lds_block_desc_permuted,
                make_tuple(
                    // ...
                    Sequence<1>{},
                    Sequence<2>{},
                    Sequence<3>{},
                    Sequence<4>{},
                    Sequence<5>{}),
                // ...
                    Sequence<2>{},
                    Sequence<0, 3>{},
                    Sequence<4, 5>{},
                    Sequence<6>{},
                    Sequence<7>{}));

            constexpr auto ab_lds_block_desc_abk0_mn_abk1 = transform_tensor_descriptor(
                ab_lds_block_desc_unmerged,
                // ...
                Number<KThreadWrite / kfold / KThreadReadPerm>{},
                Number<kfold>{},
                // ...

            return ab_lds_block_desc_abk0_mn_abk1;
        }
    }
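    // Each branch above returns an LDS block descriptor whose logical shape is
    // (ABK0Number, MNPerBlock, ABK1Number); the branches differ only in the physical layout
    // used to reduce bank conflicts (explicit padding, XOR permutation, or kfold/mpair folding).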

    template <typename GridDescriptor,
              typename BlockDescriptor,
              typename ABsDataType,
              typename ABElementwiseOperation,
              index_t GlobalBufferNum>
    __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor,
                                            BlockDescriptor& block_descriptor,
                                            ABElementwiseOperation& ab_element_op,
                                            const index_t block_mn_id,
                                            const index_t)
    {
        constexpr index_t NumABTensor = ABsDataType::Size();
        const index_t mn_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_mn_id * MNPerBlock);
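        // __builtin_amdgcn_readfirstlane broadcasts lane 0's value, so the block offset is kept
        // in a scalar register (uniform across the wavefront) instead of per-lane VGPRs.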
        // workaround because v7r2 is not as general as v4r1
        if constexpr(NumABTensor > 1)
        {
            const auto idx_as_block_begin = generate_tuple(
                [&](auto) { return make_multi_index(0, mn_block_data_idx_on_grid, 0); },
                Number<NumABTensor>{});

            // Constructs and returns a multi-source ThreadGroupTensorSliceTransfer_v7r2
            // (class name inferred from the comment above; several template arguments were
            // elided in this listing).
            return ThreadGroupTensorSliceTransfer_v7r2<
                // ...
                ABsDataType,
                // ...
                GridDescriptor,
                decltype(tie(block_descriptor)),
                ABElementwiseOperation,
                // ...
                ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
                ABBlockTransferThreadClusterArrangeOrder,
                ABBlockTransferSrcAccessOrder,
                // ...
                ABBlockTransferSrcVectorDim,
                2,
                ABBlockTransferSrcScalarPerVector,
                ABBlockTransferDstScalarPerVector_ABK1,
                // ...
                GlobalBufferNum>{grid_descriptor,
                                 idx_as_block_begin,
                                 tie(block_descriptor),
                                 make_tuple(make_multi_index(0, 0, 0)),
                                 ab_element_op};
        }
        else
        {
            // Constructs and returns a single-source ThreadGroupTensorSliceTransfer_v4r1
            // blockwise copy (class name inferred from the comment above; several template
            // arguments and the final constructor argument were elided in this listing).
            return ThreadGroupTensorSliceTransfer_v4r1<
                // ...
                ABElementwiseOperation,
                // ...
                ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
                ABBlockTransferThreadClusterArrangeOrder,
                // ...
                decltype(grid_descriptor[I0]),
                decltype(block_descriptor),
                ABBlockTransferSrcAccessOrder,
                // ...
                ABBlockTransferSrcVectorDim,
                2,
                ABBlockTransferSrcScalarPerVector,
                ABBlockTransferDstScalarPerVector_ABK1,
                1,
                1,
                ABThreadTransferSrcResetCoordinateAfterRun,
                true,
                GlobalBufferNum>(grid_descriptor[I0],
                                 make_multi_index(0, mn_block_data_idx_on_grid, 0),
                                 ab_element_op,
                                 block_descriptor,
                                 make_multi_index(0, 0, 0),
                                 // ...
                                 );
        }
    }
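    // Depending on ABsDataType::Size(), the caller receives either a multi-source v7r2 transfer
    // or a single-source v4r1 blockwise copy; both start reading the grid tile at
    // (0, block_mn_id * MNPerBlock, 0) and write it to (0, 0, 0) of the LDS block descriptor.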

    template <index_t MNRepeat, index_t MNWaves>
    __host__ __device__ static constexpr auto MakeWmmaTileDescriptor()
    {
        // This is a block descriptor used to read LDS memory into registers.
        // It is defined in a way consistent with the existing implementation to
        // avoid changes in the pipelines.
        using BlockDesc = decltype(GetBlockDescriptor());
        // ABK0_MN_ABK1 -> ABK0_MNRepeat_MNWaves_KRow_MNPerWmma_ABK1
        constexpr auto ABK0 = BlockDesc{}.GetLength(I0);
        constexpr auto ABK1 = BlockDesc{}.GetLength(I2);
#ifdef __gfx12__
        constexpr auto KRow = I2;
#else
        constexpr auto KRow = I1;
#endif
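        // KRow is 2 on gfx12 and 1 otherwise; presumably this reflects that gfx12 WMMA
        // distributes the K slice across two lane rows of the wave, while earlier WMMA
        // targets keep a single K row.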
        if constexpr(KInner > 1)
        {
            // KPack = KInner * KPerWmma
            // K1 = KInner * KPerWmmaBlk
            // Each thread loads multiple tiles with one instruction.
            // 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
            return transform_tensor_descriptor(
                BlockDesc{},
                make_tuple(
                    // ...
        }
        else
        {
            // KPack = KPerWmma (KInner == 1)
            if constexpr(ABK1 <= KPerWmmaBlk)
            {
                // K1 <= single tile (KPerWmmaBlk)
                // Each thread will load KPerWmmaBlk for the WMMA instruction.
                // Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1
                // (the rest of the single WMMA tile for a single thread) and then over KRow
                // (the rest of the single WMMA tile for a single wave).
                // KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
                return transform_tensor_descriptor(
                    BlockDesc{},
                    make_tuple(
                        // ...
                        Number<ABK0 / (KPack / ABK1)>{}, KRow, Number<KPack / KRow / ABK1>{})),
                        // ...
            }
            else
            {
                // K1 > single tile (KPerWmmaBlk)
                // Each thread will load KPerWmmaBlk for the WMMA instruction.
                // Since K1 > single tile, each thread loads KPerWmmaBlk and the next
                // KPerWmmaBlk chunk is loaded by a different thread in the same wave
                // (WMMA layout). This layout is needed to support, for example,
                // AK1 > single tile and BK1 <= single tile in the same GEMM.
                // KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow -
                // MNPerWmma - K1
                constexpr auto desc1 = transform_tensor_descriptor(
                    BlockDesc{},
                    make_tuple(
                        // ...
                        Number<KPack / KPerWmmaBlk / KRow>{},
                        Number<KRow>{},
                        Number<KPerWmmaBlk>{}))),
                        // ...

                return transform_tensor_descriptor(
                    desc1,
                    make_tuple(
                        // ...
                        make_merge_transform(make_tuple(Number<ABK0>{}, Number<ABK1 / KPack>{})),
                        // ...
                        Sequence<1>{},
                        Sequence<2, 3>{},
                        Sequence<4>{},
                        Sequence<5>{},
                        Sequence<6>{},
                        Sequence<7>{}),
                    // ...
                        Sequence<1>{},
                        Sequence<2>{},
                        Sequence<3>{},
                        Sequence<4>{},
                        Sequence<5>{},
                        Sequence<6>{}));
            }
        }
    }
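    // All branches above produce the register-read tile descriptor with the seven dimensions
    // listed in the comments: (leading K factor, MNRepeat, K0 / KRow, MNWaves, KRow,
    // MNPerWmma, K1).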

    __device__ static constexpr auto GetBlockStep()
    {
        // Grid descriptor step (MoveSrcSliceWindow)
        return make_multi_index(KPerBlock / ABK1Number, 0, 0);
    }
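    // Example (assumed values): with KPerBlock = 64 and ABK1Value = 8 the step is (8, 0, 0),
    // i.e. each main-loop iteration advances the source window by 8 ABK0 slices, which is one
    // KPerBlock chunk along K.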

    template <typename GridDescriptor>
    __device__ static constexpr index_t GetKDimension(const GridDescriptor& grid_desc)
    {
        // K dimension size. This should always be called with the A-matrix grid descriptor,
        // because it does not work for the B matrix when packed int4 is used.
        return grid_desc.GetLength(I0) * grid_desc.GetLength(I2);
    }
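    // The grid descriptor is laid out as (ABK0, MN, ABK1), so K = ABK0 * ABK1; e.g. (assumed
    // values) ABK0 = 64 and ABK1 = 8 give K = 512.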

    template <typename LDSType, typename IndexType>
    __device__ static auto GetBuffer(LDSType* p_shared_AB, const IndexType& size)
    {
        return make_dynamic_buffer<AddressSpaceEnum::Lds>(p_shared_AB, size);
    }
};

} // namespace ck