/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/wrapper/layout.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/wrapper/layout.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/wrapper/layout.hpp Source File
layout.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 // Disable from doxygen docs generation
10 namespace ck {
11 namespace wrapper {
13 
// Layout: pairs a (possibly nested) Shape with a tensor descriptor whose
// dimensions are fully unrolled (UnrolledDescriptorType), and derives from it
// a 1D descriptor and a "merged nests" descriptor.
// Example (from the member comments at the end of the struct):
//   shape ((2, 2), 2) -> unrolled lengths (2, 2, 2), 1D lengths (8),
//   merged-nests lengths (4, 2).
// NOTE(review): this text is a Doxygen source-listing rendering; the leading
// numbers are the original file line numbers, and some source lines
// (e.g. 78, 173, 181, 183, 437) are absent from this view — confirm against
// include/ck/wrapper/layout.hpp before relying on the listing.
22 template <typename Shape, typename UnrolledDescriptorType>
23 struct Layout
24 {
25  // Disable from doxygen docs generation
27  private:
28  static constexpr auto I0 = Number<0>{};
29  static constexpr auto I1 = Number<1>{};
30 
// Build a tuple of zeros with one entry per top-level element of `shape`:
// index_t(0) when the descriptor is runtime, Number<0> (I0) when it is
// known at compile time.
37  template <typename... Ts>
38  __host__ __device__ constexpr static auto
39  GenerateDefaultIdxsTuple([[maybe_unused]] const Tuple<Ts...>& shape)
40  {
41  return generate_tuple(
42  [&](auto) {
43  if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
44  {
45  // runtime layout
46  return index_t(0);
47  }
48  else
49  {
50  // compiletime layout
51  return I0;
52  }
53  },
54  Number<Tuple<Ts...>::Size()>{});
55  }
56 
// Compute the lower-dimension Sequence (in the unrolled descriptor) for the
// top-level shape element with index Idx:
//  - a tuple element maps to a reversed run of consecutive dims whose length
//    is the number of its unrolled sub-elements;
//  - a scalar element maps to a single dim.
// For Idx > 0 the run starts one past the largest dim used by element Idx-1,
// obtained by compile-time recursion on Idx-1.
66  template <typename Idx, typename... Ts>
67  __host__ __device__ constexpr static auto
68  GenerateLowerDim([[maybe_unused]] const Tuple<Ts...>& shape)
69  {
70  if constexpr(Idx::value == 0)
71  {
72  if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
73  {
74  // Return Sequence for the first tuple
75  constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
76  tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
// NOTE(review): original line 78 (the rest of the alias below) is missing from
// this listing; by analogy with lines 96-98 it is presumably
// `typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;` — confirm.
77  using LowerDimsSequence =
79  return LowerDimsSequence::Reverse();
80  }
81  else
82  {
83  // Return first element
84  return Sequence<0>{};
85  }
86  }
87  else
88  {
89  // Get previous element using recursion (at compile time)
90  using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
// Sequences are built reversed, so At(I0) is the largest dim of element Idx-1.
91  const auto next_seq_val = PreviousSeqT::At(I0) + 1;
92  if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
93  {
94  constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
95  tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
96  using LowerDimsSequence =
97  typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
98  type;
99  return LowerDimsSequence::Reverse();
100  }
101  else
102  {
103  return Sequence<next_seq_val>{};
104  }
105  }
106  }
107 
// Restructure `shape` so that its nesting follows `idx`:
//  - once `idx` is no longer nested, return `shape` as is;
//  - otherwise, keep shape elements whose idx counterpart is a tuple, wrap
//    the other shape elements into a one-element tuple, then apply
//    UnrollNestedTuple<0, 1> (presumably one nesting level — confirm) to both
//    the wrapped shape and `idx`, and recurse.
119  template <typename... ShapeDims, typename... IdxDims>
120  __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
121  const Tuple<IdxDims...>& idx)
122  {
123  if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
124  {
125  // Index unrolled to flatten, return shape
126  return shape;
127  }
128  else
129  {
130  // Iterate over shape tuple elements:
131  // 1. If corresponding idx element is tuple then return (will be unrolled)
132  // 2. If no, pack in tuple. It will be restored during unroll.
133  auto aligned_shape = generate_tuple(
134  [&](auto i) {
135  if constexpr(is_detected<is_tuple,
136  tuple_element_t<i, Tuple<IdxDims...>>>::value)
137  {
138  return shape.At(i);
139  }
140  else
141  {
142  return make_tuple(shape.At(i));
143  }
144  },
145  Number<Tuple<IdxDims...>::Size()>{});
146 
147  // Unroll and process next step
148  return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape),
149  UnrollNestedTuple<0, 1>(idx));
150  }
151  }
152 
// Merge every unrolled dimension of `desc` into a single upper dimension
// (the 1D view of the layout). The unrolled shape is reversed and the lower
// dims are listed in reverse so that the merge traverses column-major.
// Runtime descriptors use make_merge_transform; compile-time descriptors use
// make_merge_transform_v1_carry_check (per the comment on lines 178-180).
160  template <typename... ShapeDims, typename DescriptorToMerge>
161  __host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
162  const DescriptorToMerge& desc)
163  {
164  // Unroll the shape and reverse the element order
165  const auto merge_elems = TupleReverse(UnrollNestedTuple(shape));
166  // Generate reversed indexes (column-major traversal)
167  using MergeElemsSequence = typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type;
168  const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
169  const auto upper_dims = make_tuple(Sequence<0>{});
170  // Merge to 1d
171  if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
172  {
// NOTE(review): original line 173 is missing from this listing; given the
// argument list on line 174 it is presumably
// `return transform_tensor_descriptor(` — confirm.
174  desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
175  }
176  else
177  {
178  // If the descriptor is known at the compilation time,
179  // use `make_merge_transform_v1_carry_check` because it doesn't use
180  // memcpy.
// NOTE(review): original lines 181 and 183 are missing from this listing;
// presumably `return transform_tensor_descriptor(` and
// `make_tuple(make_merge_transform_v1_carry_check(merge_elems)),`
// (cf. lines 174 and 226) — confirm.
182  desc,
184  lower_dims,
185  upper_dims);
186  }
187  }
188 
// Transform `desc` so that each top-level element of `shape` becomes one
// upper dimension:
//  - a tuple shape element indexed by a non-tuple idx is merged
//    (reversed, i.e. column-major);
//  - a scalar shape element with a tuple idx fails the static_assert;
//  - every other element is passed through unchanged.
// Lower dims come from GenerateLowerDim; upper dim i is Sequence<i>.
201  template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
202  __host__ __device__ constexpr static auto
203  CreateMergedDescriptor(const Tuple<ShapeDims...>& shape,
204  [[maybe_unused]] const Tuple<IdxDims...>& idxs,
205  DescriptorToMerge& desc)
206  {
207  const auto transforms = generate_tuple(
208  [&](auto i) {
209  // Compare Idx with shape
210  if constexpr(is_detected<is_tuple,
211  tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
212  !is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value)
213  {
214  // If shape element is tuple and idx element is Number, then merge
215  // Unroll and reverse tuple to traverse column-major
216  const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
217  if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
218  {
219  return make_merge_transform(merge_elems);
220  }
221  else
222  {
223  // If the descriptor is known at the compilation time,
224  // use `make_merge_transform_v1_carry_check` because
225  // it doesn't use memcpy.
226  return make_merge_transform_v1_carry_check(merge_elems);
227  }
228  }
229  else
230  {
231  // If shape element is integer and idx element is tuple, passed idx is wrong
232  static_assert(
233  !(!is_detected<is_tuple, tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
234  is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value),
235  "Wrong Idx for layout()");
236  // If shape element has the same type as idx element, then pass through
237  return make_pass_through_transform(shape.At(i));
238  }
239  },
240  Number<Tuple<ShapeDims...>::Size()>{});
241 
242  const auto lower_dims =
243  generate_tuple([&](auto i) { return GenerateLowerDim<Number<i>>(shape); },
244  Number<Tuple<ShapeDims...>::Size()>{});
245  const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; },
246  Number<Tuple<ShapeDims...>::Size()>{});
247 
248  return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
249  }
250 
// Member types derived by evaluating the helpers on default-constructed
// Shape / descriptor values.
251  using Descriptor1dType =
252  remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
253  using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
255 
256  public:
257  using LayoutShape = Shape;
258  using LayoutUnrolledDescriptorType = UnrolledDescriptorType;
259 
// Transform `naive_descriptor` to match the structure of `idxs`:
//  - an idx tuple of size 1 yields the fully merged 1D descriptor;
//  - otherwise the idx rank must equal the shape rank (static_assert); the
//    shape is aligned to the idx nesting, and nested shape dims not addressed
//    by a tuple idx are merged.
268  template <typename... ShapeDims, typename... IdxDims>
269  __host__ __device__ constexpr static auto
270  TransformDesc(const Tuple<ShapeDims...>& shape,
271  const Tuple<IdxDims...>& idxs,
272  const UnrolledDescriptorType& naive_descriptor)
273  {
274  if constexpr(Tuple<IdxDims...>::Size() == I1)
275  {
276  // 1d idx path
277  return MakeMerge1d(shape, naive_descriptor);
278  }
279  else
280  {
281  // Merge nested shape dims
282  // Example idx: (1, 1), 1, 1
283  // Example shape: (2, (2, 2)), 2, (2, 2)
284  // Merged shape: (2, 4), 2, 4
285  static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
286  "Idx rank and Shape rank must be the same (except 1d).");
287  // Unroll while IdxDims is nested
288  const auto aligned_shape = AlignShapeToIdx(shape, idxs);
289  // Transform correct form of shape
290  return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idxs), naive_descriptor);
291  }
292  }
293 
// Descriptor type produced for the default (all-zero, one entry per top-level
// dim) index tuple, i.e. with every nested dimension merged.
294  using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
295  Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))>;
296 
// Element space size as reported by the unrolled descriptor.
297  __host__ __device__ constexpr auto GetElementSpaceSize() const
298  {
299  return unrolled_descriptor_.GetElementSpaceSize();
300  }
301 
302  __host__ __device__ Layout() = delete;
303 
// Store the shape and the unrolled descriptor. For runtime descriptors the
// derived 1D and merged-nests descriptors are built here. For compile-time
// descriptors they stay default-constructed — presumably sufficient because
// such descriptors are fully determined by their type (cf. operator()<Idxs>(),
// which uses TransformedDesc{}) — confirm.
310  __host__ __device__ constexpr Layout(const Shape& shape,
311  const UnrolledDescriptorType& unnested_descriptor)
312  : unrolled_descriptor_(unnested_descriptor), shape_(shape)
313  {
314  // Construct if runtime mode
315  if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
316  {
317  descriptor_1d_ = MakeMerge1d(shape_, unrolled_descriptor_);
318  merged_nests_descriptor_ =
319  TransformDesc(shape_, DefaultIdxsTupleType{}, unrolled_descriptor_);
320  }
321  }
322 
// Compile-time offset for the index tuple type Idxs. Only valid when the
// descriptor is known at compile time (enforced by the static_assert below).
329  template <typename Idxs>
330  __host__ __device__ constexpr index_t operator()() const
331  {
332  static_assert(remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime(),
333  "Compiletime operator used on runtime layout.");
334  using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnrolledDescriptorType{}));
335  using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
336  return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
337  }
338 
// Runtime offset for the index tuple Idx:
//  - a flat index of size 1 uses the stored 1D descriptor;
//  - a flat index of size Shape::Size() uses the stored merged-nests
//    descriptor;
//  - any other index form builds a descriptor on the fly via TransformDesc.
345  template <typename... Ts>
346  __host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const
347  {
348  if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == 1)
349  {
350  // if 1d access
351  return descriptor_1d_.CalculateOffset(Idx);
352  }
353  else if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == Shape::Size())
354  {
355  // if Shape::Size() access (merged nested shapes)
356  return merged_nests_descriptor_.CalculateOffset(UnrollNestedTuple(Idx));
357  }
358  else
359  {
360  // Custom index, need to transform descriptor
361  const auto transformed_desc = TransformDesc(shape_, Idx, unrolled_descriptor_);
362  return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
363  }
364  }
365 
// Length of top-level dimension IDim: the product of its unrolled
// sub-lengths if it is a tuple, otherwise the element itself.
372  template <index_t IDim>
373  __host__ __device__ constexpr auto GetLength() const
374  {
375  const auto elem = shape_.At(Number<IDim>{});
376  if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
377  {
378  const auto unrolled_element = UnrollNestedTuple(elem);
379  return TupleReduce<I0.value, unrolled_element.Size()>(
380  [](auto x, auto y) { return x * y; }, unrolled_element);
381  }
382  else
383  {
384  return elem;
385  }
386  }
387 
// Total layout size: the product of all unrolled shape elements.
// NOTE(review): despite the plural name this returns a single product value,
// not a tuple of lengths (see GetDefaultLengthsTuple for per-dim lengths).
393  __host__ __device__ constexpr auto GetLengths() const
394  {
395  const auto unrolled_shape = UnrollNestedTuple(shape_);
396  return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
397  unrolled_shape);
398  }
399 
// Shape getter.
405  __host__ __device__ constexpr const Shape& GetShape() const { return shape_; }
406 
// Tuple of GetLength<i>() for every top-level dimension i of Shape.
412  __host__ __device__ constexpr auto GetDefaultLengthsTuple() const
413  {
414  return generate_tuple([&](auto i) { return GetLength<i>(); }, Number<Shape::Size()>{});
415  }
416 
// Default start index: one zero per top-level dimension (see
// GenerateDefaultIdxsTuple).
422  __host__ __device__ constexpr auto GetDefaultStartIdxs() const
423  {
424  return GenerateDefaultIdxsTuple(shape_);
425  }
426 
// Descriptor with every nested dimension merged.
// NOTE(review): original line 437 (the declarator) is missing from this
// listing; per the Doxygen index it is `GetMergedNestingDescriptor() const`.
436  __host__ __device__ constexpr const MergedNestsDescriptorType&
438  {
439  return merged_nests_descriptor_;
440  }
441 
// Descriptor with all dimensions merged into one (1D).
449  __host__ __device__ constexpr const Descriptor1dType& Get1DDescriptor() const
450  {
451  return descriptor_1d_;
452  }
453 
// Descriptor with all dimensions unrolled (no nesting).
461  __host__ __device__ constexpr const UnrolledDescriptorType& GetUnrolledDescriptor() const
462  {
463  return unrolled_descriptor_;
464  }
465 
466  // Disable from doxygen docs generation
468  private:
469  // All dimensions are unrolled
470  UnrolledDescriptorType unrolled_descriptor_;
471  // 1D descriptor
472  Descriptor1dType descriptor_1d_;
473  // All nested dimensions are merged
474  MergedNestsDescriptorType merged_nests_descriptor_;
475  // Example, shape: ((2, 2), 2)
476  // UnrolledDescriptorType lengths: (2, 2, 2)
477  // Descriptor1dType lengths: (8)
478  // MergedNestsDescriptorType lengths: (4, 2)
// NOTE(review): a const data member makes Layout non-copy-assignable; confirm
// this is intended.
479  const Shape shape_;
481 };
482 
483 } // namespace wrapper
484 } // namespace ck
__host__ constexpr __device__ const auto & shape(const LayoutType &layout)
Get Layout shape.
Definition: layout_utils.hpp:431
Definition: ck.hpp:267
__host__ constexpr __device__ auto TupleReduce(F &&f, const Tuple< Ts... > &tuple)
Definition: tuple_helper.hpp:161
__host__ constexpr __device__ auto IsNestedTuple(const Tuple< Ts... > &)
Definition: tuple_helper.hpp:180
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto make_merge_transform_v1_carry_check(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:66
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
__host__ constexpr __device__ auto TupleReverse(const Tuple< Ts... > &tuple)
Definition: tuple_helper.hpp:149
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto UnrollNestedTuple(const Tuple<> &element)
Definition: tuple_helper.hpp:120
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: is_detected.hpp:34
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
decltype(ck::declval< T & >().IsTuple()) is_tuple
Definition: tuple_helper.hpp:176
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
integral_constant< index_t, N > Number
Definition: number.hpp:12
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Layout wrapper that performs the tensor descriptor logic.
Definition: layout.hpp:24
__host__ __device__ Layout()=delete
Shape LayoutShape
Definition: layout.hpp:257
__host__ constexpr __device__ auto GetElementSpaceSize() const
Definition: layout.hpp:297
__host__ constexpr static __device__ auto TransformDesc(const Tuple< ShapeDims... > &shape, const Tuple< IdxDims... > &idxs, const UnrolledDescriptorType &naive_descriptor)
Transform descriptor to align to passed indexes.
Definition: layout.hpp:270
__host__ constexpr __device__ auto GetLength() const
Length getter (product if tuple).
Definition: layout.hpp:373
remove_cvref_t< decltype(TransformDesc(Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))> MergedNestsDescriptorType
Definition: layout.hpp:295
__host__ constexpr __device__ const Shape & GetShape() const
Shape getter.
Definition: layout.hpp:405
__host__ constexpr __device__ const Descriptor1dType & Get1DDescriptor() const
Get descriptor with all dimensions merged (1D). Example, shape: ((2, 2), 2) Descriptor lengths: (...
Definition: layout.hpp:449
__host__ constexpr __device__ auto GetLengths() const
Layout size getter (product of shape).
Definition: layout.hpp:393
UnrolledDescriptorType LayoutUnrolledDescriptorType
Definition: layout.hpp:258
__host__ constexpr __device__ auto GetDefaultStartIdxs() const
Get default start idx (tuple filled with 0s of the same size as Shape).
Definition: layout.hpp:422
__host__ constexpr __device__ const MergedNestsDescriptorType & GetMergedNestingDescriptor() const
Get descriptor with all nested dimensions merged. Example, shape: ((2, 2), 2) Descriptor lengths: (4,...
Definition: layout.hpp:437
__host__ constexpr __device__ const UnrolledDescriptorType & GetUnrolledDescriptor() const
Get unnested descriptor (with unrolled dims). Example, shape: ((2, 2), 2) Descriptor lengths: (2,...
Definition: layout.hpp:461
__host__ constexpr __device__ auto GetDefaultLengthsTuple() const
Get default lengths (tuple filled with Shape length elements).
Definition: layout.hpp:412
__host__ __device__ index_t operator()(const Tuple< Ts... > &Idx) const
Returns real offset to element at runtime.
Definition: layout.hpp:346
__host__ constexpr __device__ Layout(const Shape &shape, const UnrolledDescriptorType &unnested_descriptor)
Layout constructor.
Definition: layout.hpp:310
__host__ constexpr __device__ index_t operator()() const
Returns real offset to element at compile time.
Definition: layout.hpp:330
typename conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:271