Composable Kernel — Doxygen source listing
File: include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp
gridwise_ab_transfer_wave_tiles.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
8 #include "ck/utility/math.hpp"
9 
10 namespace ck {
11 
// Policy struct for transferring A/B wave tiles from global memory to LDS,
// apparently for WMMA-based gridwise GEMM pipelines (inferred from MNPerWmma
// and the file name — confirm against the repository source).
// Template parameters (from the visible list):
//   ABLayout / ABMajorLayout - actual vs. preferred-major layout; a mismatch
//                              enables the transpose path (see ABDoTranspose).
//   LDSTypeAB                - element type staged in LDS (pk_i4_t rejected
//                              by the static_assert below).
//   BlockSize / WaveSize     - threads per block / threads per wave.
//   MNPerBlock, KPerBlock, MNPerWmma, KPack, ABK1Value - tile geometry.
// NOTE(review): the struct-name line (original line 22) was lost in the
// documentation extraction, so the declaration below is incomplete as shown.
12 template <typename ABLayout,
13  typename ABMajorLayout,
14  typename LDSTypeAB,
15  index_t BlockSize,
16  index_t MNPerBlock,
17  index_t KPerBlock,
18  index_t MNPerWmma,
19  index_t KPack,
20  index_t ABK1Value,
21  index_t WaveSize>
23 {
// This transfer strategy always stages tiles through LDS (shared memory);
// pipelines query this to decide whether to allocate the LDS block.
__device__ static constexpr bool IsLDSNeeded()
{
    return true;
}
25 
// Packed-int4 elements are not supported by this transfer method.
26  static_assert(!(is_same_v<remove_cvref_t<LDSTypeAB>, pk_i4_t>),
27  "wave tile transfer method does not support pk_i4_t");
    // Shorthand compile-time dimension indices used throughout this struct.
28  static constexpr auto I0 = Number<0>{};
29  static constexpr auto I1 = Number<1>{};
30  static constexpr auto I2 = Number<2>{};
31  static constexpr auto I3 = Number<3>{};
32 
    // Number of sub-tile rows each tile is split into (the "divide each tile
    // in 2 16x8 subtile" step described in MakeGridDescriptor).
33  static constexpr index_t MNKRow = 2;
34 
36 
37  // Tiles distribution for global memory loading
38  // Notes: support for not power of 2 needs to be reviewed later on
39  // The tiles are distributed along the non-contiguous matrix dimension
40  // Example 4 waves A row-major MPerBlock = 64, KPerBlock = 64
41  // MRepeat = 1, KRepeat = 4
42  // -------------
43  // |W0| | | |
44  // -------------
45  // |W1| | | |
46  // -------------
47  // |W2| | | |
48  // -------------
49  // |W3| | | |
50  // -------------
51  // Example 4 waves A column-major MPerBlock = 64, KPerBlock = 64
52  // MRepeat = 4, KRepeat = 1
53  // -------------
54  // |W0|W1|W2|W3|
55  // -------------
56  // | | | | |
57  // -------------
58  // | | | | |
59  // -------------
60  // | | | | |
61  // -------------
    // Waves available in the thread block.
62  static constexpr index_t NumberOfWaves = BlockSize / WaveSize;
    // Candidate wave counts along the MN-major and K-major directions:
    // use as many waves as evenly divide the tile count; otherwise fall
    // back to 2 (if the tile count is even) or 1.
63  static constexpr index_t MNMajorWaves_ =
64  MNPerBlock / MNPerWmma % std::min(MNPerBlock / MNPerWmma, NumberOfWaves) == 0
65  ? std::min(MNPerBlock / MNPerWmma, NumberOfWaves)
66  : (MNPerBlock / MNPerWmma % 2 == 0 ? 2 : 1);
67  static constexpr index_t KMajorWaves_ =
68  KPerBlock / KPack % std::min(KPerBlock / KPack, NumberOfWaves) == 0
69  ? std::min(KPerBlock / KPack, NumberOfWaves)
70  : (KPerBlock / KPack % 2 == 0 ? 2 : 1);
71 
    // Transpose path is taken when the actual layout differs from the major one.
72  static constexpr bool ABDoTranspose = !is_same_v<ABLayout, ABMajorLayout>;
73 
    // NOTE(review): the initializers of MNWaves_ and KWaves_ (original lines
    // 75-76) were lost in the documentation extraction; presumably they pick
    // between MNMajorWaves_/KMajorWaves_ depending on ABDoTranspose — confirm
    // against the repository source. The repeat counts below derive from them.
74  static constexpr index_t MNWaves_ =
77  static constexpr index_t KRepeat_ = KPerBlock / (KWaves_ * KPack);
78  static constexpr index_t MNRepeat_ = MNPerBlock / (MNWaves_ * MNPerWmma);
79 
// Build the tiled global-memory descriptor: the base MN x K descriptor is
// divided into (MNTiles, KTiles) tiles, each tile split into MNKRow 16x8
// sub-tiles, and the vector dimension frozen to the first element of each
// loading chunk. Two branches keep global indices identical for both layouts.
// Only sizeMN/sizeK of the index parameters are named; their uses sit in the
// elided lines below. Padding is rejected at compile time.
// NOTE(review): many interior lines of the transform_tensor_descriptor calls
// (original lines 96-101, 112-138, 143-168) were lost in the documentation
// extraction; the bodies below are incomplete as shown.
80  template <bool PadMN, bool PadK, typename GridDescriptorBase>
81  __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc,
82  index_t sizeMN,
83  index_t,
84  index_t sizeK,
85  index_t,
86  index_t,
87  index_t)
88  {
89  // Notes: padding is currently not supported
90  static_assert(!PadMN && !PadK, "padding is currently not supported");
91 
92  // Divide the base descriptor MN_K into tiles
93  const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor(
94  base_desc,
95  make_tuple(
99  Number<KPack>{}))),
102 
103  // The distinction is needed to get the same global indices for both layouts
104  // Divide each tile in 2 16x8 subtile
105  // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
106  // MNKRow = 0-1
107  // LaneLocal = 0-15
108  // VectorSize must be 8
109  if constexpr(!ABDoTranspose)
110  {
111  const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 =
113  ab_grid_desc_mntiles_ktiles,
120  make_tuple(Number<MNKRow>{}, Number<KPack / MNKRow>{}))),
123 
124  // Freeze VectorSize to first element of the loading chunk (for convenience)
125  // Swap MNPerWmma and MNKRow for consistency with transpose descriptor
127  ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
128  make_tuple(
135  make_tuple(
137  make_tuple(
139  }
140  else
141  {
142  const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 =
144  ab_grid_desc_mntiles_ktiles,
150  make_tuple(Number<MNKRow>{}, Number<MNPerWmma / MNKRow>{})),
154 
155  // Freeze VectorSize to first element of the loading chunk (for convenience)
157  ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
158  make_tuple(
165  make_tuple(
167  make_tuple(
169  }
170  }
171 
// Build the LDS (shared-memory) block descriptor. Per the comments below,
// lanes within a tile are stored contiguously in chunks of 8 elements and
// tiles are laid out K-dimension-first, matching the grid descriptor's
// MNTiles - KTiles - MNKRow - LaneLocal - VectorSize ordering.
// NOTE(review): interior lines of the naive-descriptor construction and the
// final transform (original lines 179-188, 193-201) were lost in the
// documentation extraction; the body below is incomplete as shown.
172  __device__ static constexpr auto GetBlockDescriptor()
173  {
174  // LDS memory layouts:
175  // lanes within tiles stored contiguously in chunks of 8 elements
176  // tiles are then stored first in K dimension
177  // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
178  const auto a_grid_desc_mraw_kraw = [&]() {
182  Number<MNKRow>{},
189  I1));
190  }();
191 
192  // Freeze VectorSize to first element of the chunk (for convenience)
194  a_grid_desc_mraw_kraw,
202  }
203 
// Map the flat thread id within the block to a wave index; callers read
// component I0 as the MN-wave id and I1 as the K-wave id (see GetBlockTransfer).
// NOTE(review): the adaptor's transform arguments (original lines 209-211)
// were lost in the documentation extraction; the call below is incomplete
// as shown.
204  __device__ static auto GetWaveIdx()
205  {
206  const index_t thread_id = ThisThreadBlock::GetThreadId();
207 
208  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
212 
213  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
214  }
215 
// Decompose the hardware lane id into (sub-tile row, lane-within-sub-tile)
// for LDS addressing: a merge over (MNKRow, LanesPerSubTile), where the
// sub-tile width follows the transpose setting (KPack when transposed,
// MNPerWmma otherwise).
// NOTE(review): the adaptor's Sequence arguments (original lines 224-225)
// were lost in the documentation extraction; the call below is incomplete
// as shown.
216  __device__ static auto GetBlockLaneIdx()
217  {
218  const index_t lane_id = __lane_id();
219 
220  constexpr index_t LanesPerSubTile = ABDoTranspose ? KPack : MNPerWmma;
221 
222  constexpr auto laneid_to_block_lane_idx_adaptor = make_single_stage_tensor_adaptor(
223  make_tuple(make_merge_transform(make_tuple(MNKRow, LanesPerSubTile))),
226 
227  return laneid_to_block_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id));
228  }
229 
// Decompose the hardware lane id into a (sub-tile row group, lane column)
// pair for GLOBAL-memory addressing. The column count per sub-tile depends
// on the element size (SubTilesCol = 4 / sizeof(ABDataType)), and the
// decomposition order is swapped between the transposed and non-transposed
// layouts so both produce consistent global indices.
// NOTE(review): the adaptor construction (original lines 244-246) was lost
// in the documentation extraction; the statement below is incomplete as shown.
230  template <typename ABDataType>
231  __device__ static auto GetGridLaneIdx()
232  {
233  const index_t lane_id = __lane_id();
234 
235  constexpr index_t SubTilesRow = MNKRow;
236  constexpr index_t SubTilesCol = 4 / sizeof(ABDataType);
237  constexpr index_t LanesPerSubTile =
238  ABDoTranspose ? KPack / SubTilesCol : MNPerWmma / SubTilesCol;
239  constexpr auto dims_tuple = ABDoTranspose
240  ? make_tuple(SubTilesCol, SubTilesRow, LanesPerSubTile)
241  : make_tuple(SubTilesRow, SubTilesCol, LanesPerSubTile);
242 
243  constexpr auto laneid_to_grid_lane_idx_adaptor =
247 
248  const auto indices =
249  laneid_to_grid_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id));
250 
    // Fold the middle dimension into the lane column; the row/column roles of
    // indices[I0]/indices[I1] swap with the transpose setting.
251  if constexpr(!ABDoTranspose)
252  {
253  return make_multi_index(indices[I0], indices[I1] * LanesPerSubTile + indices[I2]);
254  }
255  else
256  {
257  return make_multi_index(indices[I1], indices[I0] * LanesPerSubTile + indices[I2]);
258  }
259  }
260 
// Construct the global->LDS transfer object for this thread: computes the
// wave index, the per-lane grid and block coordinates, and hands them to
// ThreadGroupTransferGlobal along with the grid/block descriptors and the
// elementwise op. Only a single global buffer and a single A/B tensor are
// supported (enforced by the static_asserts).
// NOTE(review): original line 279 (presumably the ABDataType alias derived
// from ABsDataType — confirm) and template arguments on lines 298-300 were
// lost in the documentation extraction; the body below is incomplete as shown.
261  template <typename GridDescriptor,
262  typename BlockDescriptor,
263  typename ABsDataType,
264  typename ABElementwiseOperation,
265  index_t GlobalBufferNum>
266  __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor,
267  BlockDescriptor& block_descriptor,
268  ABElementwiseOperation& ab_element_op,
269  const index_t block_mn_id,
270  const index_t)
271  {
272  // Note: GlobalBufferNum is currently not used but it will be needed
273  // once we add other pipelines. It is currently needed only for
274  // consistency with the thread tiles approach
275  static_assert(GlobalBufferNum == 1, "single global buffer is only supported");
276  constexpr index_t NumABTensor = ABsDataType::Size();
277  static_assert(NumABTensor == 1, "multiAB currently not supported");
278 
280 
    // Wave coordinates within the block: I0 = MN direction, I1 = K direction.
281  const auto wave_idx = GetWaveIdx();
282  index_t wave_idK = wave_idx[I1];
283  index_t wave_idMN = wave_idx[I0];
284 
    // Per-lane coordinates in the global (grid) tile layout.
285  const auto grid_lane_id = GetGridLaneIdx<ABDataType>();
286  index_t lane_group_grid = grid_lane_id[I0];
287  index_t lane_local_id_grid = grid_lane_id[I1];
288 
    // Per-lane coordinates in the LDS (block) tile layout.
289  const auto block_lane_id = GetBlockLaneIdx();
290  index_t lane_group_block = block_lane_id[I0];
291  index_t lane_local_id_block = block_lane_id[I1];
292 
293  return ThreadGroupTransferGlobal<decltype(grid_descriptor[I0]),
294  BlockDescriptor,
295  ABDataType,
296  ABDataType,
297  ABElementwiseOperation,
301  ABK1Value,
302  ABDoTranspose>(
303  grid_descriptor[I0],
304  block_descriptor,
305  make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN,
306  wave_idK,
307  lane_group_grid,
308  lane_local_id_grid),
309  make_multi_index(wave_idMN, wave_idK, lane_group_block, lane_local_id_block),
310  ab_element_op);
311  }
312 
// Block descriptor used to read LDS memory into registers, shaped to stay
// consistent with the existing pipeline implementations (per the original
// comment below).
// NOTE(review): most of the descriptor construction (original lines 319-320,
// 324-325, 327-331) was lost in the documentation extraction; the body below
// is incomplete as shown.
313  template <index_t MNRepeat, index_t MNWaves>
314  __host__ __device__ static constexpr auto MakeWmmaTileDescriptor()
315  {
316  // This is a block descriptor used to read LDS memory into register
317  // It's defined in a way consistent with the existing implementation to
318  // avoid changes in the pipelines
321  Number<KPerBlock / KPack>{},
322  Number<MNWaves>{},
323  Number<MNKRow>{},
326  make_tuple(I0,
332  I1));
333  }
334 
// Step applied to the grid descriptor when moving the source slice window
// (MoveSrcSliceWindow): advance only the K-tile dimension by the number of
// K tiles consumed per iteration; all other dimensions are left unchanged.
__device__ static constexpr auto GetBlockStep()
{
    constexpr index_t k_tile_step = KWaves_ * KRepeat_;
    return make_multi_index(I0, k_tile_step, I0, I0);
}
340 
// Total K extent represented by a grid descriptor from MakeGridDescriptor:
// the K-tile count (descriptor dimension I1) times elements per tile (KPack).
template <typename GridDescriptor>
__device__ static constexpr index_t GetKDimension(const GridDescriptor& grid_desc)
{
    const index_t num_k_tiles = grid_desc.GetLength(I1);
    return num_k_tiles * KPack;
}
346 
// Wrap the raw LDS pointer into a dynamic buffer of `size` elements so the
// transfer machinery can address shared memory like any other buffer.
template <typename LDSType, typename IndexType>
__device__ static auto GetBuffer(LDSType* p_shared_AB, const IndexType& size)
{
    auto lds_buffer = make_dynamic_buffer<AddressSpaceEnum::Lds>(p_shared_AB, size);
    return lds_buffer;
}
352 };
353 
354 } // namespace ck
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
auto grid_desc(MatrixPadder< GemmSpec, MPerTileType, NPerTileType, KPerTileType > matrix_padder, CDesc_MRaw_NRaw conv_desc)
Definition: matrix_padder.hpp:190
Definition: ck.hpp:270
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
constexpr bool is_same_v
Definition: type.hpp:283
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:301
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: gridwise_ab_transfer_wave_tiles.hpp:23
static __device__ auto GetWaveIdx()
Definition: gridwise_ab_transfer_wave_tiles.hpp:204
static constexpr __device__ bool IsLDSNeeded()
Definition: gridwise_ab_transfer_wave_tiles.hpp:24
__host__ static __device__ auto MakeGridDescriptor(GridDescriptorBase &base_desc, index_t sizeMN, index_t, index_t sizeK, index_t, index_t, index_t)
Definition: gridwise_ab_transfer_wave_tiles.hpp:81
static constexpr index_t MNRepeat_
Definition: gridwise_ab_transfer_wave_tiles.hpp:78
static __device__ auto GetBuffer(LDSType *p_shared_AB, const IndexType &size)
Definition: gridwise_ab_transfer_wave_tiles.hpp:348
static __device__ auto GetGridLaneIdx()
Definition: gridwise_ab_transfer_wave_tiles.hpp:231
static constexpr __device__ auto GetBlockDescriptor()
Definition: gridwise_ab_transfer_wave_tiles.hpp:172
static __device__ auto GetBlockTransfer(GridDescriptor &grid_descriptor, BlockDescriptor &block_descriptor, ABElementwiseOperation &ab_element_op, const index_t block_mn_id, const index_t)
Definition: gridwise_ab_transfer_wave_tiles.hpp:266
static constexpr auto I2
Definition: gridwise_ab_transfer_wave_tiles.hpp:30
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_ab_transfer_wave_tiles.hpp:35
static constexpr index_t KWaves_
Definition: gridwise_ab_transfer_wave_tiles.hpp:76
__host__ static constexpr __device__ auto MakeWmmaTileDescriptor()
Definition: gridwise_ab_transfer_wave_tiles.hpp:314
static constexpr __device__ index_t GetKDimension(const GridDescriptor &grid_desc)
Definition: gridwise_ab_transfer_wave_tiles.hpp:342
static constexpr index_t KMajorWaves_
Definition: gridwise_ab_transfer_wave_tiles.hpp:67
static constexpr index_t MNMajorWaves_
Definition: gridwise_ab_transfer_wave_tiles.hpp:63
static constexpr auto I1
Definition: gridwise_ab_transfer_wave_tiles.hpp:29
static constexpr auto I3
Definition: gridwise_ab_transfer_wave_tiles.hpp:31
static constexpr index_t MNKRow
Definition: gridwise_ab_transfer_wave_tiles.hpp:33
static constexpr auto I0
Definition: gridwise_ab_transfer_wave_tiles.hpp:28
static constexpr bool ABDoTranspose
Definition: gridwise_ab_transfer_wave_tiles.hpp:72
static constexpr index_t MNWaves_
Definition: gridwise_ab_transfer_wave_tiles.hpp:74
static constexpr __device__ auto GetBlockStep()
Definition: gridwise_ab_transfer_wave_tiles.hpp:335
static constexpr index_t KRepeat_
Definition: gridwise_ab_transfer_wave_tiles.hpp:77
static constexpr index_t NumberOfWaves
Definition: gridwise_ab_transfer_wave_tiles.hpp:62
static __device__ auto GetBlockLaneIdx()
Definition: gridwise_ab_transfer_wave_tiles.hpp:216
Definition: sequence.hpp:43
static __device__ index_t GetThreadId()
Definition: thread_group.hpp:19
Definition: thread_group_tensor_slice_transfer_global.hpp:26
Definition: integral_constant.hpp:20
Definition: data_type.hpp:187