/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/static_buffer.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/static_buffer.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/static_buffer.hpp Source File
static_buffer.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck {
9 
10 // static buffer for scalar
11 template <AddressSpaceEnum AddressSpace,
12  typename T,
13  index_t N,
14  bool InvalidElementUseNumericalZeroValue> // TODO remove this bool, no longer needed
15 struct StaticBuffer : public StaticallyIndexedArray<T, N>
16 {
17  using type = T;
19 
20  __host__ __device__ constexpr StaticBuffer() : base{} {}
21 
22  template <typename... Ys>
23  __host__ __device__ constexpr StaticBuffer& operator=(const Tuple<Ys...>& y)
24  {
25  static_assert(base::Size() == sizeof...(Ys), "wrong! size not the same");
26  StaticBuffer& x = *this;
27  static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y[i]; });
28  return x;
29  }
30 
31  __host__ __device__ constexpr StaticBuffer& operator=(const T& y)
32  {
33  StaticBuffer& x = *this;
34  static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y; });
35  return x;
36  }
37 
38  __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
39 
40  __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
41 
42  __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
43 
44  // read access
45  template <index_t I>
46  __host__ __device__ constexpr const T& operator[](Number<I> i) const
47  {
48  return base::operator[](i);
49  }
50 
51  // write access
52  template <index_t I>
53  __host__ __device__ constexpr T& operator()(Number<I> i)
54  {
55  return base::operator()(i);
56  }
57 
58  __host__ __device__ void Set(T x)
59  {
60  static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{x}; });
61  }
62 
63  __host__ __device__ void Clear() { Set(T{0}); }
64 };
65 
66 // static buffer for vector
67 template <AddressSpaceEnum AddressSpace,
68  typename S,
69  index_t NumOfVector,
70  index_t ScalarPerVector,
71  bool InvalidElementUseNumericalZeroValue, // TODO remove this bool, no longer needed,
72  typename enable_if<is_scalar_type<S>::value, bool>::type = false>
74  : public StaticallyIndexedArray<vector_type<S, ScalarPerVector>, NumOfVector>
75 {
78 
79  static constexpr auto s_per_v = Number<ScalarPerVector>{};
80  static constexpr auto num_of_v_ = Number<NumOfVector>{};
81  static constexpr auto s_per_buf = s_per_v * num_of_v_;
82 
83  __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}
84 
85  __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
86 
87  __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
88 
89  __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
90 
91  __host__ __device__ static constexpr index_t Size() { return s_per_buf; };
92 
93  // Get S
94  // i is offset of S
95  template <index_t I>
96  __host__ __device__ constexpr const S& operator[](Number<I> i) const
97  {
98  constexpr auto i_v = i / s_per_v;
99  constexpr auto i_s = i % s_per_v;
100 
101  return base::operator[](i_v).template AsType<S>()[i_s];
102  }
103 
104  // Set S
105  // i is offset of S
106  template <index_t I>
107  __host__ __device__ constexpr S& operator()(Number<I> i)
108  {
109  constexpr auto i_v = i / s_per_v;
110  constexpr auto i_s = i % s_per_v;
111 
112  return base::operator()(i_v).template AsType<S>()(i_s);
113  }
114 
115  // Get X
116  // i is offset of S, not X. i should be aligned to X
117  template <typename X,
118  index_t I,
119  typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
120  bool>::type = false>
121  __host__ __device__ constexpr auto GetAsType(Number<I> i) const
122  {
123  constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
124 
125  static_assert(s_per_v % s_per_x == 0, "wrong! V must one or multiple X");
126  static_assert(i % s_per_x == 0, "wrong!");
127 
128  constexpr auto i_v = i / s_per_v;
129  constexpr auto i_x = (i % s_per_v) / s_per_x;
130 
131  return base::operator[](i_v).template AsType<X>()[i_x];
132  }
133 
134  // Set X
135  // i is offset of S, not X. i should be aligned to X
136  template <typename X,
137  index_t I,
138  typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
139  bool>::type = false>
140  __host__ __device__ constexpr void SetAsType(Number<I> i, X x)
141  {
142  constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
143 
144  static_assert(s_per_v % s_per_x == 0, "wrong! V must contain one or multiple X");
145  static_assert(i % s_per_x == 0, "wrong!");
146 
147  constexpr auto i_v = i / s_per_v;
148  constexpr auto i_x = (i % s_per_v) / s_per_x;
149 
150  base::operator()(i_v).template AsType<X>()(i_x) = x;
151  }
152 
153  // Get read access to vector_type V
154  // i is offset of S, not V. i should be aligned to V
155  template <index_t I>
156  __host__ __device__ constexpr const auto& GetVectorTypeReference(Number<I> i) const
157  {
158  static_assert(i % s_per_v == 0, "wrong!");
159 
160  constexpr auto i_v = i / s_per_v;
161 
162  return base::operator[](i_v);
163  }
164 
165  // Get write access to vector_type V
166  // i is offset of S, not V. i should be aligned to V
167  template <index_t I>
168  __host__ __device__ constexpr auto& GetVectorTypeReference(Number<I> i)
169  {
170  static_assert(i % s_per_v == 0, "wrong!");
171 
172  constexpr auto i_v = i / s_per_v;
173 
174  return base::operator()(i_v);
175  }
176 
177  __host__ __device__ void Clear()
178  {
179  constexpr index_t NumScalars = NumOfVector * ScalarPerVector;
180 
181  static_for<0, NumScalars, 1>{}([&](auto i) { SetAsType(i, S{0}); });
182  }
183 };
184 
185 template <AddressSpaceEnum AddressSpace, typename T, index_t N>
186 __host__ __device__ constexpr auto make_static_buffer(Number<N>)
187 {
189 }
190 
191 template <AddressSpaceEnum AddressSpace, typename T, long_index_t N>
192 __host__ __device__ constexpr auto make_static_buffer(LongNumber<N>)
193 {
195 }
196 
197 } // namespace ck
Definition: ck.hpp:267
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
AddressSpaceEnum
Definition: amd_address_space.hpp:15
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
__host__ constexpr __device__ auto make_static_buffer(Number< N >)
Definition: static_buffer.hpp:186
int32_t index_t
Definition: ck.hpp:298
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: static_buffer.hpp:16
T type
Definition: static_buffer.hpp:17
__host__ constexpr __device__ StaticBuffer & operator=(const Tuple< Ys... > &y)
Definition: static_buffer.hpp:23
__host__ __device__ void Clear()
Definition: static_buffer.hpp:63
__host__ constexpr __device__ T & operator()(Number< I > i)
Definition: static_buffer.hpp:53
__host__ static constexpr __device__ bool IsDynamicBuffer()
Definition: static_buffer.hpp:42
__host__ static constexpr __device__ bool IsStaticBuffer()
Definition: static_buffer.hpp:40
__host__ constexpr __device__ StaticBuffer & operator=(const T &y)
Definition: static_buffer.hpp:31
StaticallyIndexedArray< T, N > base
Definition: static_buffer.hpp:18
__host__ constexpr __device__ StaticBuffer()
Definition: static_buffer.hpp:20
__host__ static constexpr __device__ AddressSpaceEnum GetAddressSpace()
Definition: static_buffer.hpp:38
__host__ __device__ void Set(T x)
Definition: static_buffer.hpp:58
__host__ constexpr __device__ const T & operator[](Number< I > i) const
Definition: static_buffer.hpp:46
Definition: static_buffer.hpp:75
__host__ constexpr __device__ StaticBufferTupleOfVector()
Definition: static_buffer.hpp:83
__host__ constexpr __device__ void SetAsType(Number< I > i, X x)
Definition: static_buffer.hpp:140
__host__ __device__ void Clear()
Definition: static_buffer.hpp:177
__host__ static constexpr __device__ bool IsDynamicBuffer()
Definition: static_buffer.hpp:89
__host__ constexpr __device__ const S & operator[](Number< I > i) const
Definition: static_buffer.hpp:96
StaticallyIndexedArray< vector_type< S, ScalarPerVector >, NumOfVector > base
Definition: static_buffer.hpp:77
__host__ constexpr __device__ auto GetAsType(Number< I > i) const
Definition: static_buffer.hpp:121
__host__ constexpr __device__ const auto & GetVectorTypeReference(Number< I > i) const
Definition: static_buffer.hpp:156
__host__ static constexpr __device__ AddressSpaceEnum GetAddressSpace()
Definition: static_buffer.hpp:85
__host__ constexpr __device__ auto & GetVectorTypeReference(Number< I > i)
Definition: static_buffer.hpp:168
__host__ static constexpr __device__ index_t Size()
Definition: static_buffer.hpp:91
typename vector_type< S, ScalarPerVector >::type V
Definition: static_buffer.hpp:76
__host__ static constexpr __device__ bool IsStaticBuffer()
Definition: static_buffer.hpp:87
static constexpr auto s_per_buf
Definition: static_buffer.hpp:81
__host__ constexpr __device__ S & operator()(Number< I > i)
Definition: static_buffer.hpp:107
static constexpr auto s_per_v
Definition: static_buffer.hpp:79
static constexpr auto num_of_v_
Definition: static_buffer.hpp:80
Definition: tuple.hpp:117
Definition: integral_constant.hpp:20
Definition: functional2.hpp:71
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:10