/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/container/thread_buffer.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/container/thread_buffer.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/container/thread_buffer.hpp Source File
thread_buffer.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
9 
10 namespace ck_tile {
11 
12 #if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
// Tuple-backed thread buffer: in this preprocessor branch
// (CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE)
// thread_buffer<T, N> is an alias of tuple_array<T, N>.
13 template <typename T, index_t N>
14 using thread_buffer = tuple_array<T, N>;
15 
// Builds a thread buffer from the given values by handing them to make_tuple
// and returning its result. Used in the branch where
// CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE.
// NOTE(review): ts are forwarding references but are passed on as lvalues
// (no std::forward) - confirm that this is intended.
16 template <typename... Ts>
17 CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
18 {
19  return make_tuple(ts...);
20 }
21 #else
22 
23 #if 0
// Compiled out by the enclosing `#if 0`: alternative definition that would
// make thread_buffer<T, N> an alias of array<T, N>.
24 template <typename T, index_t N>
25 using thread_buffer = array<T, N>;
26 
// Compiled out by the enclosing `#if 0`: array-backed counterpart of
// make_thread_buffer that passes its arguments to make_array.
27 template <typename... Ts>
28 CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
29 {
30  return make_array(ts...);
31 }
32 
33 #endif
34 
35 // clang-format off
// Fixed-size buffer of N_ elements stored in a plain C array.
//   T_ : element type; cv/ref qualifiers are stripped to form value_type.
//   N_ : number of elements.
// Besides element access, the struct offers get_as/set_as helpers that view
// the same storage as a thread_buffer of a different element type Tx.
36 template<typename T_, index_t N_>
37 struct thread_buffer {
38  using value_type = remove_cvref_t<T_>;
39  static constexpr index_t N = N_;
40 
// Element storage; public, and returned by reference from get().
41  value_type data[N];
42 
43  // TODO: this ctor can't ignore
// Default construction value-initializes every element (data{}).
44  CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
// Fill construction: every element is assigned a copy of o.
// The constructor is not explicit, so a value_type converts implicitly.
45  CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{} {
46  static_for<0, N, 1>{}(
47  [&](auto i) { data[i] = o; }
48  );
49  }
50 
// Number of elements (N).
51  CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
// get() returns the whole underlying array by reference; get(i) returns
// element i. No bounds checking is performed.
// NOTE(review): these four get overloads are not constexpr, although the
// constexpr operator[] / operator() / at accessors below call them -
// confirm whether they should be constexpr as well.
52  CK_TILE_HOST_DEVICE auto & get() {return data; }
53  CK_TILE_HOST_DEVICE const auto & get() const {return data; }
54  CK_TILE_HOST_DEVICE auto & get(index_t i) {return data[i]; }
55  CK_TILE_HOST_DEVICE const auto & get(index_t i) const {return data[i]; }
// Indexed access; all of the following forward to get(i).
56  CK_TILE_HOST_DEVICE constexpr const auto& operator[](index_t i) const { return get(i); }
57  CK_TILE_HOST_DEVICE constexpr auto& operator[](index_t i) { return get(i); }
// operator() exists only in a non-const form.
58  CK_TILE_HOST_DEVICE constexpr auto& operator()(index_t i) { return get(i); } // TODO: compatible
59  CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
60  CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
// Compile-time index forms: at<I>() and at(number<I>{}) both return get(I).
61  template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
62  template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
63  template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
64  template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
65 
// Whole-buffer conversion helper: copies all N elements into the sub_data
// member of a local union and returns the union's other member, a
// thread_buffer<X_, N / kSPerX>, by value. kSPerX is
// vector_traits<X>::vector_size and must divide N (static_assert).
// NOTE(review): the template parameter list is truncated in this listing
// (original line 67 is absent), so the remaining template parameter(s) are
// unknown here.
// NOTE(review): the result is read from a union member other than the one
// written to - presumably relies on compiler support for this form of type
// punning; confirm.
66  template <typename X_,
68  CK_TILE_HOST_DEVICE constexpr auto _get_as() const
69  {
70  using X = remove_cvref_t<X_>;
71 
72  constexpr index_t kSPerX = vector_traits<X>::vector_size;
73  static_assert(N % kSPerX == 0);
74 
75  union {
76  thread_buffer<X_, N / kSPerX> data {};
77  // tuple_array<value_type, kSPerX> sub_data;
78  value_type sub_data[N];
79  } vx;
80  static_for<0, N, 1>{}(
81  [&](auto j) { vx.sub_data[j] = data[j]; });
82  return vx.data;
83  }
84 
// Single-element conversion helper: copies kSPerX consecutive elements,
// starting at index is * (sizeof(X_) / sizeof(value_type)), into the
// sub_data member of a local union and returns the union's X_ member.
// NOTE(review): the template parameter list is truncated in this listing
// (original line 87 is absent).
85  template <typename X_,
86  index_t Is,
88  CK_TILE_HOST_DEVICE const constexpr remove_reference_t<X_> _get_as(number<Is> is) const
89  {
90  using X = remove_cvref_t<X_>;
91 
92  constexpr index_t kSPerX = vector_traits<X>::vector_size;
93 
94  union {
95  X_ data {};
96  tuple_array<value_type, kSPerX> sub_data;
97  } vx;
98  static_for<0, kSPerX, 1>{}(
99  [&](auto j) { vx.sub_data(j) = operator[]((is * number<sizeof(X_)/sizeof(value_type)>{}) + j); });
100  return vx.data;
101  }
102 
// Compiled out (#if 0): union-based writer that would scatter the kSPerX
// pieces of x into consecutive elements of this buffer.
// NOTE(review): its template parameter list is also truncated in this
// listing (original line 106 is absent).
103 #if 0
104  template <typename X_,
105  index_t Is,
107  CK_TILE_HOST_DEVICE constexpr void _set_as(number<Is> is, X_ x)
108  {
109  using X = remove_cvref_t<X_>;
110 
111  constexpr index_t kSPerX = vector_traits<X>::vector_size;
112 
113  union {
114  X_ data;
115  tuple_array<value_type, kSPerX> sub_data;
116  } vx {x};
117 
118  static_for<0, kSPerX, 1>{}(
119  [&](auto j) { operator()((is * number<sizeof(X_)/sizeof(value_type)>{}) + j) = vx.sub_data[j]; });
120  }
121 #endif
122 
123 
// Shared prologue for the get_as/set_as members below. Expects a type named
// Tx in scope; asserts that the buffer's total byte size is a multiple of
// sizeof(Tx) and defines vx as the number of Tx elements that fit in it.
// The macro is #undef'd before the end of the struct.
124 #define TB_COMMON_AS() \
125  static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
126  constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
127 
// Mutable view: reinterprets the storage in place as thread_buffer<Tx, vx>
// and returns a reference to it.
128  template<typename Tx>
129  CK_TILE_HOST_DEVICE auto & get_as() {TB_COMMON_AS();
130  return reinterpret_cast<thread_buffer<Tx, vx>&>(data);}
// Const view, returned by value (the return type is plain auto). When
// value_type is at most one byte wide the union-based _get_as path is taken;
// otherwise the storage is reinterpret_cast.
131  template<typename Tx>
132  CK_TILE_HOST_DEVICE constexpr auto get_as() const {TB_COMMON_AS();
133  if constexpr(sizeof(value_type) <= 1 )
134  return _get_as<Tx>(); // TODO: current compiler for 8bit data need use union to get data back, should fix in the future
135  else
136  return reinterpret_cast<const thread_buffer<Tx, vx>&>(data);}
// Mutable access to element I of the reinterpreted view.
// NOTE(review): number<I>{} is passed to get(index_t) - presumably number<I>
// converts implicitly to index_t; confirm.
137  template<typename Tx, index_t I>
138  CK_TILE_HOST_DEVICE auto & get_as(number<I>) {TB_COMMON_AS();
139  return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(number<I>{});}
// Const access to element I of the reinterpreted view, returned by value;
// uses the same one-byte special case as the whole-buffer const get_as.
140  template<typename Tx, index_t I>
141  CK_TILE_HOST_DEVICE constexpr auto get_as(number<I>) const {TB_COMMON_AS();
142  if constexpr(sizeof(value_type) <= 1 )
143  return _get_as<Tx>(number<I>{}); // TODO: current compiler for 8bit data need use union to get data back, should fix in the future
144  else
145  return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(number<I>{});}
146 
// Writes x into element i (run-time index) or element I (compile-time index)
// of the storage reinterpreted as thread_buffer<Tx, vx>.
147  template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
148  { TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(i) = x; }
149  template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
150  { TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(number<I>{}) = x; }
151 
152 #undef TB_COMMON_AS
153 };
154 // clang-format on
155 
// Forward declaration of the two-parameter vector_traits template, so that it
// can be specialized for thread_buffer below. The second parameter is used
// by those specializations as an enable_if slot.
156 template <typename, typename>
157 struct vector_traits;
158 
159 // specialization for thread_buffer<T, N> whose element type T is not a class type:
//     scalar_type is T itself and vector_size is the element count N.
160 template <typename T, index_t N>
161 struct vector_traits<thread_buffer<T, N>, std::enable_if_t<!std::is_class_v<T>>>
162 {
163  using scalar_type = T;
164  static constexpr index_t vector_size = N;
165 };
166 
// Specialization for thread_buffer<T, N> whose element type T is a class
// type: scalar_type is taken from T's nested `type` member, and vector_size
// is the element count N.
167 template <typename T, index_t N>
168 struct vector_traits<thread_buffer<T, N>, std::enable_if_t<std::is_class_v<T>>>
169 {
170  using scalar_type = typename T::type;
171  static constexpr index_t vector_size = N;
172 };
173 
174 #endif
175 
176 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
tuple_array< T, N > thread_buffer
Definition: thread_buffer.hpp:14
typename impl::tuple_array_impl< T, N >::type tuple_array
Definition: tuple.hpp:28
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE details::return_type< D, Ts... > make_array(Ts &&... ts)
Definition: array.hpp:242
constant< v > number
Definition: integral_constant.hpp:37
constexpr CK_TILE_HOST_DEVICE auto make_thread_buffer(Ts &&... ts)
Definition: thread_buffer.hpp:17
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: debug.hpp:67
static constexpr index_t vector_size
Definition: vector_type.hpp:92