Composable Kernel: include/ck/wrapper/utils/tensor_partition.hpp Source File

tensor_partition.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "tensor_utils.hpp"
#include "layout_utils.hpp"

// Disable from doxygen docs generation
namespace ck {
namespace wrapper {

// Disable from doxygen docs generation
namespace {

namespace detail {

template <typename... Ts, typename... Ls>
__host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts...>& shape,
                                                                const Tuple<Ls...>& thread_lengths)
{
    static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
    return generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            const auto slice_len =
                ck::math::integer_divide_ceil(size<num_i>(shape), thread_lengths.At(num_i));
            return slice_len;
        },
        Number<Tuple<Ls...>::Size()>{});
}
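
// Worked example (illustrative values, not from the source): for a tensor
// shape of (256, 128) and thread_lengths of (4, 2), each thread's partition
// shape is (ceil(256 / 4), ceil(128 / 2)) = (64, 64).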

template <typename MultiIndex, typename ProjectionTuple>
__host__ __device__ constexpr auto
ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
                [[maybe_unused]] const ProjectionTuple& projection)
{
    if constexpr(is_same_v<ProjectionTuple, Tuple<>>)
    {
        return Tuple<>{};
    }
    else
    {
        auto base_tuple_after_projection = generate_tuple(
            [&](auto i) {
                const auto i_num = Number<i.value>{};
                static_assert(
                    is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value ||
                    is_same_v<tuple_element_t<i_num, ProjectionTuple>, Number<1>>);
                if constexpr(is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value)
                {
                    // For a sliced (removed) dimension insert an empty tuple,
                    // which is stripped out in the next step.
                    return Tuple<>{};
                }
                else
                {
                    return make_tuple(base_tuple.At(i_num));
                }
            },
            Number<MultiIndex::Size()>{});
        // Remove empty tuples
        return UnrollNestedTuple<0, 1>(base_tuple_after_projection);
    }
}
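
// Worked example (illustrative): for base_tuple (b0, b1, b2) and projection
// (Number<1>{}, slice(S), Number<1>{}), the sliced dimension is dropped and
// the result is (b0, b2); an empty ProjectionTuple yields an empty tuple.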

template <typename... Ts, typename... Ps>
__host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts...>& shape,
                                                                const Tuple<Ps...>& projection)
{
    return generate_tuple(
        [&](auto i) {
            if constexpr(is_detected<is_slice, tuple_element_t<i, Tuple<Ps...>>>::value)
            {
                return size<i>(projection).to_;
            }
            else
            {
                // Number of shape elements consumed by the preceding fragment
                // of shape and projection (used to compute the shape index).
                constexpr index_t shape_i =
                    detail::ApplyProjection(TupleSlice<0, i>(Tuple<Ts...>{}),
                                            TupleSlice<0, i>(Tuple<Ps...>{}))
                        .Size();
                return size<shape_i>(shape);
            }
        },
        Number<Tuple<Ps...>::Size()>{});
}
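
// Worked example (illustrative): for shape (M, K) and projection
// (Number<1>{}, slice(N), Number<1>{}), the result is (M, N, K): a sliced
// entry contributes its `to_` extent, while every other entry pulls the next
// unconsumed element of `shape`.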

template <typename... Ts, typename... Ls, typename... Ps>
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
                                                     const Tuple<Ls...>& tile_shape)
{
    return generate_tuple(
        [&](auto i) { return ck::math::integer_divide_ceil(size<i>(shape), size<i>(tile_shape)); },
        Number<Tuple<Ls...>::Size()>{});
}
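
// Worked example (illustrative): for shape (512, 256) and tile_shape
// (128, 64), the grid size is (ceil(512 / 128), ceil(256 / 64)) = (4, 4).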

template <typename ThreadIdxs, typename PartitionLengthsSeq, typename OldOffsetIdxs>
__host__ __device__ constexpr auto
CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
                         const PartitionLengthsSeq& partition_lengths_seq,
                         const OldOffsetIdxs& old_offset_idxs)
{
    return thread_idxs * partition_lengths_seq + old_offset_idxs;
}
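
// Worked example (illustrative): the arithmetic is elementwise, so for
// thread_idxs (2, 3), partition lengths (64, 64) and old offsets (0, 0) the
// new offsets are (2 * 64 + 0, 3 * 64 + 0) = (128, 192).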

template <typename BlockIdxs>
__host__ __device__ constexpr auto GetDimsToPartition([[maybe_unused]] const BlockIdxs& block_idxs)
{
    const auto dims_to_partition = generate_tuple(
        [&](auto i) {
            if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
            {
                return Number<i>{};
            }
            else
            {
                return Tuple<>{};
            }
        },
        Number<BlockIdxs::Size()>{});
    // Remove empty tuples
    return UnrollNestedTuple<0, 1>(dims_to_partition);
}
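
// Worked example (illustrative): for block_idxs (idx, slice(S), idx) the
// partitioned dims are those holding a real index, so the result is
// (Number<0>{}, Number<2>{}).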

template <typename BlockIdxs>
__host__ __device__ constexpr auto ReplaceSlicesWithZeros(const BlockIdxs& block_idxs)
{
    return generate_tuple(
        [&](auto i) {
            if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
            {
                return block_idxs.At(i);
            }
            else
            {
                return Number<0>{};
            }
        },
        Number<BlockIdxs::Size()>{});
}
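
// Worked example (illustrative): for block_idxs (bx, slice(S), by) the
// result is (bx, Number<0>{}, by).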

template <typename TileShape>
__host__ __device__ constexpr auto
GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
{
    return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{});
}
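
// Worked example: for a 3-element tile shape the default projection is
// (Number<1>{}, Number<1>{}, Number<1>{}), i.e. no dimension is removed.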

template <typename ThreadShape, typename ThreadUnrolledDesc>
__host__ __device__ constexpr auto CalculateThreadMultiIdx(
    [[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
    const index_t thread_id)
{
    static_assert(ThreadUnrolledDesc::GetNumOfTransform() == 1,
                  "Thread layout should not be transformed.");
    constexpr auto embed_transform = ThreadUnrolledDesc{}.GetTransforms().At(Number<0>{});
    constexpr auto shape           = ThreadShape{};
    constexpr auto strides         = embed_transform.coefficients_;

    return generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            return (thread_id / strides.At(num_i)) % shape.At(num_i);
        },
        Number<ThreadShape::Size()>{});
}
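
// Worked example (illustrative): for a packed 2x4 thread layout the embed
// strides are (4, 1), so thread_id 5 maps to
// ((5 / 4) % 2, (5 / 1) % 4) = (1, 1).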
} // namespace detail
} // namespace

/// @brief Create a local partition for a thread (currently only packed
/// partitions are supported).
template <typename TensorType,
          typename ThreadShape,
          typename ThreadUnrolledDesc,
          typename ProjectionTuple>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
                     [[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
                     const index_t thread_id,
                     const ProjectionTuple& projection)
{
    static_assert(!IsNestedTuple(ThreadShape{}));
    // Calculate new partition shape
    const auto& tensor_shape = shape(tensor);
    // Calculate projected thread lengths
    constexpr auto projected_thread_lengths =
        detail::ApplyProjection(ThreadShape{}, ProjectionTuple{});
    constexpr auto partition_shape =
        detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
    constexpr auto partition_shape_seq =
        generate_sequence_v2([&](auto I) { return size<I>(partition_shape); },
                             Number<decltype(partition_shape)::Size()>{});
    // Calculate thread idxs and offsets
    const auto thread_idxs = detail::CalculateThreadMultiIdx(thread_layout, thread_id);
    // Apply the projection to the thread idxs to drop the unneeded ones
    const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
    const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
        projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
    // Create new layout and tensor
    auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
    // Slice descriptor
    const auto transforms = generate_tuple(
        [&](auto i) {
            return make_slice_transform(partition_shape.At(i),
                                        offset_multi_idxs.At(i),
                                        partition_shape.At(i) + offset_multi_idxs.At(i));
        },
        Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
    const auto lower_upper_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; },
                       Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
    auto sliced_desc =
        transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims);
    // Create layout
    const auto partition_layout =
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
            partition_shape, sliced_desc);
    auto partition_tensor =
        make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
    // Offsets are already applied through the slice transform
    return partition_tensor;
}

/// @brief Overload of make_local_partition using the default (identity)
/// projection.
template <typename TensorType, typename ThreadShape, typename ThreadUnrolledDesc>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
                     const Layout<ThreadShape, ThreadUnrolledDesc>& thread_lengths,
                     const index_t thread_id)
{
    const auto projection = detail::GenerateDefaultProjection(ThreadShape{});
    return make_local_partition(tensor, thread_lengths, thread_id, projection);
}
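
// Usage sketch (a minimal illustration; the tensor and the packed
// make_layout call are assumptions, not taken from this file). Partition a
// 256x128 tensor across a 4x64 thread layout, giving each thread a 64x2
// slice:
//   const auto thread_layout =
//       make_layout(make_tuple(Number<4>{}, Number<64>{}));
//   auto partition =
//       make_local_partition(tensor, thread_layout, get_thread_local_1d_id());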

/// @brief Create a local tile for a thread block (currently only packed
/// tiles are supported).
template <typename TensorType,
          typename BlockShapeTuple,
          typename BlockIdxs,
          typename ProjectionTuple>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
                                                   const BlockShapeTuple& tile_shape,
                                                   const BlockIdxs& block_idxs,
                                                   const ProjectionTuple& projection)
{
    static_assert(!IsNestedTuple(BlockShapeTuple{}));
    static_assert(!IsNestedTuple(BlockIdxs{}));

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();

    constexpr auto projected_tile_shape =
        detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
    // Number of dims which are partitioned
    constexpr auto dims_to_partition = detail::GetDimsToPartition(BlockIdxs{});
    const auto parsed_block_idxs = detail::ReplaceSlicesWithZeros(block_idxs);
    if constexpr(decltype(dims_to_partition)::Size() == I2)
    {
        const auto shape_with_projection_dims =
            detail::CalculateShapeWithProjection(shape(tensor), projection);
        // Set values for the M, N partition
        const auto M = shape_with_projection_dims.At(dims_to_partition.At(I0));
        const auto N = shape_with_projection_dims.At(dims_to_partition.At(I1));
        constexpr auto MPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I0));
        constexpr auto NPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I1));
        auto m_n_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N));
        // Get 1D block id
        const auto grid_size = detail::CalculateGridSize(shape_with_projection_dims, tile_shape);
        const auto block_lengths_desc = make_naive_tensor_descriptor_packed(grid_size);
        const index_t block_id_1d = block_lengths_desc.CalculateOffset(parsed_block_idxs);
        // Optimized version for a 2D tile shape [M x N]
        const auto block_2_tile_map =
            BlockToCTileMap_M00_N0_M01Adapt<MPerBlock,
                                            NPerBlock,
                                            remove_cvref_t<decltype(m_n_desc)>>(m_n_desc);
        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id_1d));
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
        // Use 0 for the non-partitioned dims
        const auto offset_multi_idxs = generate_tuple(
            [&](auto i) {
                if constexpr(i == dims_to_partition.At(I0))
                {
                    return m_block_data_idx_on_grid;
                }
                else if constexpr(i == dims_to_partition.At(I1))
                {
                    return n_block_data_idx_on_grid;
                }
                else
                {
                    return Number<0>{};
                }
            },
            Number<BlockShapeTuple::Size()>{});
        const auto projected_offset_multi_idxs =
            detail::ApplyProjection(offset_multi_idxs, projection);
        // Create new layout and tensor
        const auto tile_layout =
            Layout<remove_reference_t<decltype(projected_tile_shape)>, decltype(aligned_desc)>(
                projected_tile_shape, aligned_desc);
        auto tile_tensor =
            make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
        // Apply offsets
        tile_tensor.SetMultiIdxOffset(to_multi_index(projected_offset_multi_idxs));
        return tile_tensor;
    }
    else
    {
        // Calculate offsets
        // Sequence with the amount of data to process per block
        using ProjectedTileShapeTuple = decltype(projected_tile_shape);
        constexpr auto projected_tile_shape_seq =
            generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
                                 Number<ProjectedTileShapeTuple::Size()>{});
        // Tuple with the projected block indexes
        const auto projected_block_idxs =
            to_multi_index(detail::ApplyProjection(parsed_block_idxs, projection));
        const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
            projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
        // Create new layout and tensor
        const auto tile_layout =
            Layout<remove_reference_t<decltype(projected_tile_shape)>, decltype(aligned_desc)>(
                projected_tile_shape, aligned_desc);
        auto tile_tensor =
            make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
        // Apply offsets
        tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
        return tile_tensor;
    }
}

/// @brief Overload of make_local_tile using the default (identity)
/// projection.
template <typename TensorType, typename BlockShapeTuple, typename BlockIdxs>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
                                                   const BlockShapeTuple& tile_shape,
                                                   const BlockIdxs& block_idxs)
{
    const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
    return make_local_tile(tensor, tile_shape, block_idxs, projection);
}
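
// Usage sketch (a minimal illustration; the tensor and the block-index
// sources are assumptions, not taken from this file). Cut a 128x64 tile for
// the current thread block out of a 2D tensor:
//   const auto tile_shape = make_tuple(Number<128>{}, Number<64>{});
//   const auto block_idxs = make_tuple(static_cast<index_t>(blockIdx.x),
//                                      static_cast<index_t>(blockIdx.y));
//   auto tile = make_local_tile(tensor, tile_shape, block_idxs);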

} // namespace wrapper
} // namespace ck