Composable Kernel: include/ck/wrapper/tensor.hpp
tensor.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "utils/tensor_utils.hpp"
#include "utils/layout_utils.hpp"

// Disable from doxygen docs generation
namespace ck {
namespace wrapper {

// Disable from doxygen docs generation
namespace {
namespace detail {
// Check whether a single (non-tuple) index argument is a Slice.
template <typename T>
__host__ __device__ constexpr bool HasSlice(T&&)
{
    return is_detected<is_slice, T>::value;
}
template <typename... Ts>
__host__ __device__ constexpr bool HasSlice(Tuple<Ts...>&&)
{
    return (HasSlice(Ts{}) || ...);
}
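
// Illustrative example (comment only, not part of the original file), assuming
// the wrapper's slice() helper from tensor_utils.hpp:
//   HasSlice(make_tuple(Number<0>{}, slice(0, 2)));             // -> true
//   HasSlice(make_tuple(Number<0>{}, make_tuple(Number<1>{}))); // -> false
// The detection works at any nesting level via the fold over tuple elements.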

template <typename... Ts, typename SlicedShape>
__host__ __device__ constexpr auto GetSlicedShape(const Tuple<Ts...>& idxs,
                                                  const SlicedShape& shape)
{
    // Pack each value in tuple to remove empty tuples after generation
    auto new_shape = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                if constexpr(!detail::HasSlice(tuple_element_t<i.value, Tuple<Ts...>>{}))
                {
                    // if the tuple does not contain any slice, the dimension can be removed
                    return Tuple<>{};
                }
                else
                {
                    // if tuple, then recurse
                    return make_tuple(GetSlicedShape(idxs.At(num_i), shape.At(num_i)));
                }
            }
            else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                // calculate new dimension
                const auto& dim = size(shape.At(num_i));
                const auto val  = idxs.At(num_i).range(dim);
                return make_tuple(val);
            }
            else
            {
                // remove dimension for a plain (scalar) index
                return Tuple<>{};
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Remove empty tuples (deleted elements) and return
    return UnrollNestedTuple<0, 1>(new_shape);
}
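
// Illustrative example (comment only, values hypothetical): for a nested shape
// (4, (2, 8)) and the index tuple (slice(0, 4), (1, slice(0, 8))), the scalar
// index 1 drops its dimension, so the resulting sliced shape is (4, (8)).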

template <typename T, typename Shape>
__host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, const Shape& shape)
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return generate_tuple(
        [&](auto i) {
            // per-dimension index extracted from the linear idx
            const auto dim     = unrolled_shape.At(Number<i>{});
            const auto dim_idx = idx % dim;
            idx /= dim;
            return make_freeze_transform(dim_idx);
        },
        Number<decltype(unrolled_shape)::Size()>{});
}
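
// Illustrative example (comment only): freezing a nested shape (2, 3) at the
// linear index 4 fixes dim 0 to 4 % 2 == 0 and dim 1 to 4 / 2 == 2, producing
// one freeze transform per unrolled dimension.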

template <typename... Ts, typename Shape>
__host__ __device__ constexpr auto GenerateSliceTransforms(const Tuple<Ts...>& idx,
                                                           const Shape& shape)
{
    // Pack each value in tuple to remove empty tuples after generation
    auto transforms = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                return GenerateSliceTransforms(idx.At(num_i), shape.At(num_i));
            }
            else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                const auto from  = idx.At(num_i).from_;
                const auto dim   = size<num_i>(shape);
                const auto range = idx.At(num_i).range(dim);
                return make_slice_transform(range, from, from + range);
            }
            else
            {
                // freeze dimensions for a plain (scalar) index
                return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Remove empty tuples (deleted elements) and return
    return UnrollNestedTuple(transforms);
}
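
// Illustrative example (comment only): for shape (4, 6) and index tuple
// (2, slice(1, 5)) this yields a freeze transform fixing dim 0 at index 2 and
// a slice transform spanning indices 1..5 (range 4) of dim 1.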

template <index_t i, typename LowerIndex>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<LowerIndex>&)
{
    // There is no output for Freeze transform
    return Sequence<>{};
}

template <index_t i, typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Slice<LowLength, SliceBegin, SliceEnd>&)
{
    return Sequence<i>{};
}

template <index_t i>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&)
{
    return Tuple<>{};
}

template <index_t i, typename... Transforms>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>& transforms)
{
    constexpr auto num_transforms = Tuple<Transforms...>::Size();
    // Deduce the Sequence element for the specific transform
    const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
    if constexpr(is_same_v<decltype(current_elem), const Sequence<>>)
    {
        const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
    else
    {
        // Increase i if current_elem is a Slice transform
        const auto next_tuple = GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
}

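// Illustrative example (comment only): for the transform tuple
// (Freeze, Slice, Freeze, Slice), GenerateUpperDims<0> returns
// (Sequence<>, Sequence<0>, Sequence<>, Sequence<1>): freezes contribute no
// upper dimension, while slices are numbered consecutively.
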
template <typename... Ts, typename Shape, typename UnrolledDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
                                                            const Shape& shape,
                                                            const UnrolledDescriptor& flatten_desc)
{
    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();

    const auto transforms     = GenerateSliceTransforms(idx, shape);
    using TransformsTupleType = decltype(transforms);

    const auto lower_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
    const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
    return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
} // namespace detail
} // namespace

template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
struct Tensor
{
    public:
    using ElementSpaceSize = decltype(Layout<Shape, UnrolledDescriptorType>{
        Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
    using TensorElementType =
        std::conditional_t<is_scalar_type<ElementType>::value,
                           ElementType,
                           typename scalar_type<std::remove_const_t<ElementType>>::type>; // DataType

    static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
    static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum::Sgpr ||
                                              BufferAddressSpace == MemoryTypeEnum::Vgpr);

    __host__ __device__ Tensor() = delete;
    __host__ __device__ constexpr Tensor(ElementType* pointer,
                                         const Layout<Shape, UnrolledDescriptorType>& layout)
        : layout_(layout),
          buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize())),
          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
          base_offset_(0)
    {
        static_assert(IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

    __host__ __device__ constexpr Tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
        : layout_(layout),
          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
          base_offset_(0)
    {
        static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

    __host__ __device__ constexpr const Layout<Shape, UnrolledDescriptorType>& GetLayout() const
    {
        return layout_;
    }

    // Slicing operator: returns a new Tensor view described by the sliced layout.
    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator[](const Tuple<Ts...>& idx)
    {
        static_assert(IsDynamicBuffer, "Register slice is not supported");
        const auto& shape = layout_.GetShape();
        auto new_shape    = detail::GetSlicedShape(idx, shape);

        const auto& flatten_desc = layout_.GetUnrolledDescriptor();
        auto new_desc = detail::GenerateSlicedDescriptor(idx, shape, flatten_desc);
        const auto new_layout =
            Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
        // Update embed offset
        base_offset_ -= new_layout(make_tuple(Number<0>{}));
        return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
    }
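
    // Illustrative usage (comment only, assuming the slice() helper from
    // tensor_utils.hpp): given a 2D global-memory tensor `t` with shape (16, 32),
    //   auto block = t(slice(0, 8), slice(16, 32)); // 8 x 16 view of the buffer
    //   auto row   = t(Number<3>{}, slice(0, 32));  // dim 0 frozen at index 3
    // Scalar indices are frozen out of the view; sliced dimensions remain.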

    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ auto operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }

    // Const element access (no slice in the index).
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const TensorElementType& operator[](const Tuple<Ts...>& idx) const
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx) + base_offset_;
            return buffer_[offset];
        }
        else
        {
            constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
            // Calculate and apply the base offset at compile time
            constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
            return buffer_[Number<index_offset + base_offset>{}];
        }
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const TensorElementType& operator()(const Tuple<Ts...>& idx) const
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ const TensorElementType& operator()(Idxs... idxs) const
    {
        return this->operator[](make_tuple(idxs...));
    }

    // Mutable element access (no slice in the index).
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ TensorElementType& operator[](const Tuple<Ts...>& idx)
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx) + base_offset_;
            return buffer_(offset);
        }
        else
        {
            constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
            // Apply the embed offset (calculated at compile time)
            constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
            return buffer_(Number<index_offset + base_offset>{});
        }
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ TensorElementType& operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ TensorElementType& operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }
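
    // Illustrative usage (comment only): for a dynamic-buffer tensor the index
    // is resolved through the layout at run time, e.g. t(1, 2) == t(make_tuple(1, 2)).
    // For Sgpr/Vgpr tensors both the index offset and the base offset fold into
    // a single compile-time constant used to address the static buffer.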

    // Get the descriptor with all nested dimensions merged.
    __host__ __device__ constexpr auto GetMergedNestingDescriptor()
    {
        return layout_.GetMergedNestingDescriptor();
    }

    // Get the pointer to the data.
    __host__ __device__ TensorElementType* GetPointer() const { return buffer_.p_data_; }

    __host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
    __host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }

    // Get the multi index offset to the data.
    __host__ __device__ constexpr auto& GetMultiIdxOffsets() const { return multi_idx_offset_; }

    // Apply a multi index offset to the tensor.
    template <typename MultiIdxOffsets>
    __host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
    {
        multi_idx_offset_ = multi_idx_offset;
        base_offset_ += layout_(multi_idx_offset);
    }

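    // Illustrative usage (comment only): shifting a tile view by one row,
    //   t.SetMultiIdxOffset(make_multi_index(1, 0));
    // stores the multi-dimensional offset and folds its linearized form into
    // base_offset_, so later element accesses avoid recomputing it.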

    private:
    // Disable from doxygen docs generation
    using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
                                            ElementType,
                                            ElementSpaceSize,
                                            true /*InvalidElementUseNumericalZeroValue*/>;
    using StaticBufferType = std::conditional_t<
        is_scalar_type<ElementType>::value,
        StaticBuffer<BufferAddressSpace,
                     ElementType,
                     size(Shape{}),
                     true /*InvalidElementUseNumericalZeroValue*/>,
        StaticBufferTupleOfVector<BufferAddressSpace,
                                  TensorElementType,
                                  size(Shape{}) /
                                      scalar_type<std::remove_const_t<ElementType>>::vector_size,
                                  scalar_type<std::remove_const_t<ElementType>>::vector_size,
                                  true /*InvalidElementUseNumericalZeroValue*/>>;
    // If in register memory, use a static buffer; otherwise use a dynamic buffer
    using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;

    Layout<Shape, UnrolledDescriptorType> layout_;
    Buffer buffer_;
    // multi_idx_offset_ enables creating the descriptor at compile time for
    // partitions or tiles when the tile shape and thread layout are known at
    // compile time (the same descriptor can then be reused for each thread).
    // Additionally, a copy between a static and a dynamic buffer requires a
    // descriptor known at compile time, so data can be shifted using
    // multi_idx_offset_.
    MultiIndex<Shape::Size()> multi_idx_offset_;
    // base_offset_ and multi_idx_offset_ refer to exactly the same element in
    // the tensor (and in physical memory). The multi index offset is a
    // multi-dimensional index, whereas the base offset is computed through the
    // tensor descriptor (thus through all of its transforms) and is linear (1D).
    // base_offset_ is stored to avoid repeated recalculation.
    index_t base_offset_;
};

} // namespace wrapper
} // namespace ck
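
// Illustrative end-to-end sketch (comment only, assuming the wrapper's
// make_layout/make_tensor/slice helpers from layout_utils.hpp and
// tensor_utils.hpp):
//   const auto layout = make_layout(make_tuple(Number<16>{}, Number<32>{}),
//                                   make_tuple(Number<32>{}, Number<1>{}));
//   auto tensor = make_tensor<MemoryTypeEnum::Global>(ptr, layout);
//   auto tile   = tensor(slice(0, 8), slice(0, 16)); // 8 x 16 view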