Composable Kernel: gridwise_permute.hpp (include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <functional>
#include <numeric>
#include <iterator>

#include "ck/utility/data_type.hpp"

namespace ck {

template <typename GridwisePermute,
          typename InGridDesc,
          typename OutGridDesc,
          typename InDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          typename Block2TileMap>
__global__ void kernel_nd_permute(const InGridDesc in_grid_desc,
                                  const OutGridDesc out_grid_desc,
                                  const InDataType* p_in_global,
                                  OutDataType* p_out_global,
                                  const ElementwiseOperation elementwise_op,
                                  const Block2TileMap block_2_tile_map)
{
    __shared__ char p_shared[GridwisePermute::GetSharedMemoryNumberOfByte()];

    GridwisePermute::Run(in_grid_desc,
                         out_grid_desc,
                         p_in_global,
                         p_out_global,
                         p_shared,
                         elementwise_op,
                         block_2_tile_map);
}

template <typename InGridDesc,
          typename OutGridDesc,
          typename InDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          index_t BlockSize,
          index_t NPerBlock,
          index_t HPerBlock,
          index_t WPerBlock,
          index_t InBlockLdsExtraW,
          typename InBlockTransferThreadClusterLengths,
          typename InBlockTransferThreadClusterArrangeOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t SrcScalarPerVector,
          index_t DstScalarPerVector>
struct GridwisePermute
{
    static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension());
    static_assert(3 <= InGridDesc::GetNumOfDimension());
    static_assert((InGridDesc::GetNumOfDimension() - 2) <= SrcVectorDim &&
                  SrcVectorDim < InGridDesc::GetNumOfDimension());
    static_assert((OutGridDesc::GetNumOfDimension() - 2) <= DstVectorDim &&
                  DstVectorDim < OutGridDesc::GetNumOfDimension());
    static_assert(SrcVectorDim != DstVectorDim);

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    struct Block2TileMap
    {
        static constexpr index_t NumDim = InGridDesc::GetNumOfDimension();
        static_assert(3 <= NumDim);

        static constexpr auto I0 = Number<0>{};

        Block2TileMap()                     = delete;
        Block2TileMap(const Block2TileMap&) = default;
        Block2TileMap(Block2TileMap&&)      = delete;

        ~Block2TileMap() = default;

        Block2TileMap& operator=(const Block2TileMap&) = delete;
        Block2TileMap& operator=(Block2TileMap&&)      = delete;

        explicit Block2TileMap(const InGridDesc& desc) : desc_(desc) {}

        __host__ constexpr index_t CalculateGridSize(const InGridDesc& desc) const
        {
            const auto N0 =
                math::integer_divide_ceil(desc.GetLength(Number<NumDim - 3>{}), NPerBlock);
            const auto H0 =
                math::integer_divide_ceil(desc.GetLength(Number<NumDim - 2>{}), HPerBlock);
            const auto W0 =
                math::integer_divide_ceil(desc.GetLength(Number<NumDim - 1>{}), WPerBlock);

            const index_t grid_size = N0 * H0 * W0;

            return grid_size;
        }

        template <typename TopIdx>
        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
        {
            static_assert(TopIdx::Size() == 1);

            auto block_1d_id = idx_top[I0];

            const auto N0 =
                math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 3>{}), NPerBlock);
            const auto H0 =
                math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 2>{}), HPerBlock);
            const auto W0 =
                math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 1>{}), WPerBlock);

            block_1d_id = block_1d_id % (N0 * H0 * W0);

            index_t idx_N0 = block_1d_id / (H0 * W0);
            index_t idx_H0 = (block_1d_id % (H0 * W0)) / W0;
            index_t idx_W0 = block_1d_id % W0;

            return make_tuple(idx_N0, idx_H0, idx_W0);
        }

        private:
        const InGridDesc desc_;
    };
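
    // Worked example of the mapping above (values illustrative): with N0 = 2, H0 = 3, W0 = 4 the
    // grid has 2 * 3 * 4 = 24 tiles, and block_1d_id = 17 decomposes into
    // idx_N0 = 17 / (3 * 4) = 1, idx_H0 = (17 % 12) / 4 = 1, idx_W0 = 17 % 4 = 1,
    // i.e. that workgroup owns tile (1, 1, 1).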

    using DefaultBlock2TileMap = Block2TileMap;

    // use an [NPerBlock, HPerBlock, WPerBlock] tensor as element-copy relay
    __host__ __device__ static constexpr auto GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock()
    {
        return make_naive_tensor_descriptor(
            make_tuple(Number<NPerBlock>{}, Number<HPerBlock>{}, Number<WPerBlock>{}),
            make_tuple(Number<HPerBlock * (WPerBlock + InBlockLdsExtraW)>{},
                       Number<WPerBlock + InBlockLdsExtraW>{},
                       I1));
    }
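
    // Note on the descriptor above: each row of the LDS tile is padded by InBlockLdsExtraW
    // elements (row pitch WPerBlock + InBlockLdsExtraW), a common way to reduce LDS bank
    // conflicts for transpose-like access patterns.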

    // for an N-dimension descriptor, preserve its last 2 dimensions, then merge its leading
    // dimensions into a single one. finally, form a 3D descriptor:
    // [d(0), d(1), ..., d(N - 2), d(N - 1)] -> [(d(0) x d(1) x ...), d(N - 2), d(N - 1)]
    template <typename GridDesc>
    __host__ __device__ static constexpr auto GetMergedDesc(const GridDesc& desc)
    {
        constexpr index_t NumDim = GridDesc::GetNumOfDimension();
        static_assert(3 <= NumDim);

        const auto merged_desc = transform_tensor_descriptor(
            desc,
            make_tuple(make_merge_transform(generate_tuple(
                           [&](auto I) { return desc.GetLength(I); }, Number<NumDim - 2>{})),
                       make_pass_through_transform(desc.GetLength(Number<NumDim - 2>{})),
                       make_pass_through_transform(desc.GetLength(Number<NumDim - 1>{}))),
            make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
                       Sequence<NumDim - 2>{},
                       Sequence<NumDim - 1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        return merged_desc;
    }
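
    // Worked example of the merge above: a 5-D descriptor with lengths [2, 3, 4, 8, 16] becomes
    // the 3-D descriptor [2 * 3 * 4, 8, 16] = [24, 8, 16]; the last two dimensions stay separate
    // so they can be swapped on the output side.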

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto in_block_desc_nperblock_hperblock_wperblock =
            GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();

        return in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize() *
               sizeof(InDataType);
    }

    __host__ __device__ static constexpr auto MakeDefaultBlock2TileMap(const InGridDesc& desc)
    {
        return DefaultBlock2TileMap{desc};
    }

    __host__ __device__ static constexpr bool CheckValidity(const InGridDesc& in_grid_desc,
                                                            const OutGridDesc& out_grid_desc)
    {
        constexpr index_t NumDim = InGridDesc::GetNumOfDimension();

        // check that we only swap the last 2 dimensions
        bool valid = true;
        static_for<0, NumDim - 2, 1>{}([&](auto I) {
            if(valid && in_grid_desc.GetLength(I) != out_grid_desc.GetLength(I))
            {
                valid = false;
            }
        });

        return valid &&
               (in_grid_desc.GetLength(Number<NumDim - 1>{}) ==
                out_grid_desc.GetLength(Number<NumDim - 2>{})) &&
               (in_grid_desc.GetLength(Number<NumDim - 2>{}) ==
                out_grid_desc.GetLength(Number<NumDim - 1>{}));
    }
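
    // Example: a 4-D input of shape [2, 5, 8, 16] passes this check only against an output of
    // shape [2, 5, 16, 8], i.e. identical leading dimensions with the last two swapped.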

    template <typename Block2TileMap>
    __device__ static void Run(const InGridDesc in_grid_desc,
                               const OutGridDesc out_grid_desc,
                               const InDataType* p_in_global,
                               OutDataType* p_out_global,
                               void* __restrict__ p_shared,
                               const ElementwiseOperation elementwise_op,
                               const Block2TileMap& block_2_tile_map)
    {
        auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global, in_grid_desc.GetElementSpaceSize());

        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_global, out_grid_desc.GetElementSpaceSize());

        // each workgroup handles an [NPerBlock, HPerBlock, WPerBlock] slice-transpose problem
        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock);

        const index_t h_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * HPerBlock);

        const index_t w_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * WPerBlock);

        // create an [NPerBlock, HPerBlock, WPerBlock] shaped LDS buffer
        constexpr auto in_block_desc_nperblock_hperblock_wperblock =
            GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();

        auto in_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<InDataType*>(p_shared),
            in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize());

        using BlockSliceLengths          = Sequence<NPerBlock, HPerBlock, WPerBlock>;
        using InBlockTransferAccessOrder = Sequence<0, 1, 2>;

        constexpr index_t SrcVectorDimAfterMerge =
            SrcVectorDim - (InGridDesc::GetNumOfDimension() - 3);
        constexpr index_t DstVectorDimAfterMerge = SrcVectorDimAfterMerge;

        using PassThrough = tensor_operation::element_wise::PassThrough;

        // merge the input descriptor into [(in_grid_desc.GetLength(0) x in_grid_desc.GetLength(1)
        // x ...), in_grid_desc.GetLength(NumDim - 2), in_grid_desc.GetLength(NumDim - 1)]
        const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc);

        // a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from global memory to LDS
        auto in_global_load = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            ElementwiseOperation,
            PassThrough,
            InMemoryDataOperationEnum::Set,
            BlockSliceLengths,
            InBlockTransferThreadClusterLengths,
            InBlockTransferThreadClusterArrangeOrder,
            InDataType,
            InDataType,
            decltype(in_grid_desc_n_h_w),
            decltype(in_block_desc_nperblock_hperblock_wperblock),
            InBlockTransferAccessOrder,
            InBlockTransferAccessOrder,
            SrcVectorDimAfterMerge,
            2,
            SrcScalarPerVector,
            1,
            1,
            1,
            true,
            true>(in_grid_desc_n_h_w,
                  make_multi_index(
                      n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
                  PassThrough{},
                  in_block_desc_nperblock_hperblock_wperblock,
                  make_multi_index(0, 0, 0),
                  PassThrough{});

        // merge the output descriptor into [(out_grid_desc.GetLength(0) x out_grid_desc.GetLength(1)
        // x ...), out_grid_desc.GetLength(NumDim - 2), out_grid_desc.GetLength(NumDim - 1)]
        const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc);

        // create a transposed view of the output tensor
        const auto out_grid_desc_n_h_w = transform_tensor_descriptor(
            out_grid_desc_n_w_h,
            make_tuple(make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I0)),
                       make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I1)),
                       make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I2))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}));

        // a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from LDS to global memory
        auto out_global_store = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            ElementwiseOperation,
            PassThrough,
            InMemoryDataOperationEnum::Set,
            BlockSliceLengths,
            InBlockTransferThreadClusterLengths,
            InBlockTransferThreadClusterArrangeOrder,
            InDataType,
            OutDataType,
            decltype(in_block_desc_nperblock_hperblock_wperblock),
            decltype(out_grid_desc_n_h_w),
            InBlockTransferAccessOrder,
            InBlockTransferAccessOrder,
            2,
            DstVectorDimAfterMerge,
            1,
            DstScalarPerVector,
            1,
            1,
            true,
            true>(in_block_desc_nperblock_hperblock_wperblock,
                  make_multi_index(0, 0, 0),
                  PassThrough{},
                  out_grid_desc_n_h_w,
                  make_multi_index(
                      n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
                  elementwise_op);

        in_global_load.Run(in_grid_desc_n_h_w,
                           in_global_buf,
                           in_block_desc_nperblock_hperblock_wperblock,
                           in_block_buf,
                           I0);

        out_global_store.Run(in_block_desc_nperblock_hperblock_wperblock,
                             in_block_buf,
                             out_grid_desc_n_h_w,
                             out_global_buf,
                             I0);
    }
};

} // namespace ck
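
For orientation, the sketch below shows one way a host-side wrapper could drive this gridwise kernel: build naive 3-D descriptors, verify them with CheckValidity, create the default block-to-tile map, and launch kernel_nd_permute with the grid size it reports. This is an illustrative sketch rather than code from the library (in the repository a device-level permute operator is expected to wrap this struct); the header paths, the fp16 data type, and all tuning parameters (block size, tile sizes, thread-cluster shape, vector widths) are assumptions chosen only to satisfy the static_asserts above.

// Illustrative host-side launcher (assumed parameters; not part of gridwise_permute.hpp).
#include "ck/ck.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"

using F16         = ck::half_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Permute a [N, H, W] tensor into [N, W, H], i.e. swap the last two dimensions.
void permute_last_two_dims(const F16* p_in, F16* p_out, ck::index_t N, ck::index_t H, ck::index_t W)
{
    const auto in_desc  = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(N, H, W));
    const auto out_desc = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(N, W, H));

    using Gridwise = ck::GridwisePermute<decltype(in_desc),
                                         decltype(out_desc),
                                         F16,
                                         F16,
                                         PassThrough,
                                         256,                     // BlockSize
                                         8,                       // NPerBlock
                                         32,                      // HPerBlock
                                         32,                      // WPerBlock
                                         1,                       // InBlockLdsExtraW
                                         ck::Sequence<1, 16, 16>, // thread cluster lengths
                                         ck::Sequence<0, 1, 2>,   // thread cluster arrange order
                                         2,                       // SrcVectorDim
                                         1,                       // DstVectorDim
                                         1,                       // SrcScalarPerVector
                                         1>;                      // DstScalarPerVector

    if(!Gridwise::CheckValidity(in_desc, out_desc))
    {
        return; // only swapping the last two dimensions is supported
    }

    const auto block_2_tile_map = Gridwise::MakeDefaultBlock2TileMap(in_desc);
    const ck::index_t grid_size = block_2_tile_map.CalculateGridSize(in_desc);

    ck::kernel_nd_permute<Gridwise,
                          decltype(in_desc),
                          decltype(out_desc),
                          F16,
                          F16,
                          PassThrough,
                          decltype(block_2_tile_map)>
        <<<grid_size, 256>>>(in_desc, out_desc, p_in, p_out, PassThrough{}, block_2_tile_map);
}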