/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/wrapper/utils/tensor_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/wrapper/utils/tensor_utils.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/wrapper/utils/tensor_utils.hpp Source File
tensor_utils.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/ck.hpp"
7 
9 #include "ck/utility/number.hpp"
10 #include "ck/utility/tuple.hpp"
15 
16 // Disable from doxygen docs generation
18 namespace ck {
19 namespace wrapper {
21 
31 
32 // Disable from doxygen docs generation
// forward declarations
// Layout: wrapper performing the tensor-descriptor (shape/stride) logic.
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;
// Tensor: wrapper combining a memory buffer (in the given address space)
// with a Layout describing how it is indexed.
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
struct Tensor;
42 
/**
 * \brief Interval descriptor used to select a sub-range of one tensor
 * dimension.
 *
 * `from_` is the inclusive start index. A non-negative `to_` is an exclusive
 * end index (length = to_ - from_); a negative `to_` counts back from the end
 * of the dimension (e.g. -1 selects up to and including the last element).
 */
template <typename FromType, typename ToType>
struct Slice
{
    __host__ __device__ constexpr Slice() : from_(), to_() {}
    __host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {}

    /**
     * \brief Compute the number of elements selected from a dimension of
     * size `dim`.
     *
     * \param dim Size of the dimension being sliced.
     * \return Slice length: a runtime `index_t` when any of the three types
     *         involved is `index_t`, otherwise a compile-time `Number`.
     */
    template <typename T>
    __host__ __device__ constexpr auto range(const T& dim) const
    {
        if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
                     is_same_v<std::remove_const_t<T>, index_t>)
        {
            if(to_ < 0)
            {
                // Negative end counts from the back: to_ == -1 yields dim - from_.
                return dim - from_ + to_ + 1;
            }
            else
            {
                // workaround if one end of the interval is index_t and the second one is Number
                return static_cast<index_t>(to_) - static_cast<index_t>(from_);
            }
        }
        else
        {
            // Fully compile-time path: validate the interval at compile time.
            static_assert(T{} >= ToType{} && FromType{} >= Number<0>{} &&
                          (ToType{} < 0 || ToType{} > FromType{}),
                          "Invalid range");
            if constexpr(ToType{} < 0)
            {
                return dim - from_ + to_ + Number<1>{};
            }
            else
            {
                return to_ - from_;
            }
        }
    }

    // Marker member: its presence is what the is_slice detection alias tests.
    __host__ __device__ static constexpr bool IsSlice() { return true; }

    const FromType from_; // inclusive start of the interval
    const ToType to_;     // end of the interval (negative -> relative to dimension end)
};
92 
// Detection alias: well-formed only when T exposes an IsSlice() member
// (i.e. T is a Slice); used with is_detected-style traits.
template <typename T>
using is_slice = decltype(std::declval<T&>().IsSlice());

// Detection alias: well-formed only when T exposes an IsTuple() member
// (i.e. T is a ck::Tuple).
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
99 
/**
 * \brief Make tensor function.
 *
 * Wraps an existing memory buffer in a Tensor with the given layout.
 *
 * \tparam MemoryType Address space of the buffer (e.g. Global, Lds, Sgpr, Vgpr).
 * \param pointer Pointer to the first element of the underlying memory.
 * \param layout Layout describing shape and indexing of the tensor.
 * \return Tensor viewing `pointer` through `layout`.
 */
template <MemoryTypeEnum MemoryType,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
constexpr auto make_tensor(ElementType* pointer,
                           const Layout<Shape, UnrolledDescriptorType>& layout)
{
    // NOTE(review): the parameter list and body were truncated in extraction;
    // reconstructed from the documented signature — confirm against upstream.
    return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(pointer, layout);
}
117 
/**
 * \brief Make SGPR or VGPR tensor function.
 *
 * Creates a register-backed (static buffer) tensor — no external pointer is
 * involved, storage lives in registers.
 *
 * \tparam MemoryType Address space of the buffer; expected Sgpr or Vgpr.
 * \param layout Layout describing shape and indexing of the tensor.
 * \return Register-backed Tensor with the given layout.
 */
template <MemoryTypeEnum MemoryType,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
constexpr auto make_register_tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
{
    // NOTE(review): declarator and body were lost in extraction; reconstructed
    // from the documented signature — confirm against upstream.
    return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(layout);
}
133 
/**
 * \brief Clear tensor. (Only for Vgpr/Sgpr.)
 *
 * Zeroes the tensor's register-backed static buffer via Buffer::Clear().
 *
 * \param tensor Register tensor to clear.
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
__host__ __device__ void
clear(Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
    // NOTE(review): the static_assert condition was lost in extraction; restored
    // to the documented Sgpr/Vgpr-only contract — confirm against upstream.
    static_assert(BufferAddressSpace == MemoryTypeEnum::Sgpr ||
                      BufferAddressSpace == MemoryTypeEnum::Vgpr,
                  "Clear is supported only for Sgpr/Vgpr tensors.");
    return tensor.GetBuffer().Clear();
}
150 
/**
 * \brief Get Tensor Layout.
 *
 * \param tensor Tensor to inspect.
 * \return Const reference to the tensor's Layout.
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
__host__ __device__ constexpr const auto&
layout(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
    // NOTE(review): declarator line restored from the documented signature.
    return tensor.GetLayout();
}
166 
/**
 * \brief Product of tensor shape dims.
 *
 * Forwards to size() on the tensor's Layout.
 *
 * \tparam Idxs Optional dimension indices to restrict the query to.
 * \param tensor Tensor to inspect.
 * \return Size (product of the selected shape dimensions).
 */
template <index_t... Idxs,
          MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
size(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
    // NOTE(review): declarator line restored from the documented signature.
    return size<Idxs...>(tensor.GetLayout());
}
184 
/**
 * \brief Rank of Shape tuple.
 *
 * Forwards to rank() on the tensor's Layout.
 *
 * \tparam Idxs Optional dimension indices to restrict the query to.
 * \param tensor Tensor to inspect.
 * \return Rank (number of top-level dimensions of the selected shape).
 */
template <index_t... Idxs,
          MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
rank(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
    // NOTE(review): declarator line restored from the documented signature.
    return rank<Idxs...>(tensor.GetLayout());
}
202 
/**
 * \brief Depth of Shape tuple.
 *
 * Forwards to depth() on the tensor's Layout.
 *
 * \tparam Idxs Optional dimension indices to restrict the query to.
 * \param tensor Tensor to inspect.
 * \return Nesting depth of the selected shape tuple.
 */
template <index_t... Idxs,
          MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
depth(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
    // NOTE(review): declarator line restored from the documented signature.
    return depth<Idxs...>(tensor.GetLayout());
}
220 
/**
 * \brief Get Tensor shape.
 *
 * Forwards to shape() on the tensor's Layout.
 *
 * \param tensor Tensor to inspect.
 * \return Const reference to the shape tuple.
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
__host__ __device__ constexpr const auto&
shape(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
    // NOTE(review): declarator line restored from the documented signature.
    return shape(tensor.GetLayout());
}
236 
/**
 * \brief Get dim slice.
 *
 * Builds a Slice selecting [from, to) of a dimension (a negative `to`
 * counts back from the dimension end — see Slice::range).
 *
 * \param from Inclusive start index.
 * \param to End index (exclusive when non-negative).
 * \return Slice object for use in tensor indexing.
 */
template <typename FromType, typename ToType>
constexpr auto slice(const FromType from, const ToType to)
{
    return Slice<FromType, ToType>(from, to);
}
249 
/**
 * \brief Build a slice starting at the beginning of a dimension.
 *
 * Equivalent to slice(0, to): the start index is 0, kept as a runtime
 * index_t when `to` is a runtime index, or as Number<0> when `to` is a
 * compile-time Number.
 *
 * \param to End index (exclusive when non-negative, from-the-end when negative).
 * \return Slice object covering [0, to).
 */
template <typename ToType>
constexpr auto slice(const ToType to)
{
    if constexpr(is_same_v<ToType, index_t>)
    {
        // Runtime end index -> runtime zero start, so both ends stay index_t.
        return Slice<index_t, ToType>(0, to);
    }
    else
    {
        // Compile-time end index -> compile-time zero start.
        constexpr auto from = Number<0>{};
        return Slice<Number<0>, ToType>(from, to);
    }
}
268 
/// Whole-dimension slice: from the first element (Number<0>) through the
/// last (Number<-1>, i.e. counted from the dimension end).
constexpr auto slice()
{
    return Slice<Number<0>, Number<-1>>(Number<0>{}, Number<-1>{});
}
275 
276 } // namespace wrapper
277 } // namespace ck
Definition: ck.hpp:267
AddressSpaceEnum
Definition: amd_address_space.hpp:15
constexpr bool is_same_v
Definition: type.hpp:283
int32_t index_t
Definition: ck.hpp:298
decltype(ck::declval< T & >().IsTuple()) is_tuple
Definition: tuple_helper.hpp:176
integral_constant< index_t, N > Number
Definition: number.hpp:12
const GenericPointer< typename T::ValueType > & pointer
Definition: pointer.h:1249
Layout wrapper that performs the tensor descriptor logic.
Definition: layout.hpp:24
Tensor wrapper that performs static and dynamic buffer logic. The tensor is based on a descriptor sto...
Definition: host_tensor.hpp:277
__host__ constexpr __device__ const Layout< Shape, UnrolledDescriptorType > & GetLayout() const
Definition: tensor.hpp:246
__host__ constexpr __device__ auto & GetBuffer()
Definition: tensor.hpp:380
__host__ constexpr __device__ const auto & layout(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Get Tensor Layout.
Definition: tensor_utils.hpp:162
constexpr auto slice(const FromType from, const ToType to)
Get dim slice.
Definition: tensor_utils.hpp:245
__host__ constexpr __device__ auto depth(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Depth of Shape tuple.
Definition: tensor_utils.hpp:216
__host__ __device__ void clear(Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Clear tensor. (Only for Vgpr/Sgpr)
Definition: tensor_utils.hpp:144
__host__ constexpr __device__ auto size(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Product of tensor shape dims.
Definition: tensor_utils.hpp:180
__host__ constexpr __device__ auto rank(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Rank of Shape tuple.
Definition: tensor_utils.hpp:198
__host__ constexpr __device__ const auto & shape(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Get Tensor shape.
Definition: tensor_utils.hpp:232
AddressSpaceEnum MemoryTypeEnum
Memory type, allowed members:
Definition: tensor_utils.hpp:30
constexpr auto make_register_tensor(const Layout< Shape, UnrolledDescriptorType > &layout)
Make SGPR or VGPR tensor function.
Definition: tensor_utils.hpp:129
constexpr auto make_tensor(ElementType *pointer, const Layout< Shape, UnrolledDescriptorType > &layout)
Make tensor function.
Definition: tensor_utils.hpp:112