include/ck_tile/core/container/thread_buffer.hpp Source File

include/ck_tile/core/container/thread_buffer.hpp Source File

Composable Kernel: include/ck_tile/core/container/thread_buffer.hpp Source File
thread_buffer.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
9 
10 namespace ck_tile {
11 
12 #if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
13 template <typename T, index_t N>
15 
16 template <typename... Ts>
17 CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
18 {
19  return make_tuple(ts...);
20 }
21 #else
22 
#if 0
// Disabled alternative kept for reference: back thread_buffer with ck_tile::array
// instead of the dedicated struct defined below. Nothing in this region is compiled.
// NOTE(review): because this region is disabled, the struct-backed configuration
// provides no make_thread_buffer; only the tuple-backed configuration defines one.
template <typename T, index_t N>
using thread_buffer = array<T, N>;

// Builds an array-backed thread buffer from the given elements.
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
{
    return make_array(ts...);
}
#endif
34 
35 // clang-format off
36 template<typename T_, index_t N_>
37 struct thread_buffer {
38  using value_type = remove_cvref_t<T_>;
39  static constexpr index_t N = N_;
40 
41  value_type data[N];
42 
43  // TODO: this ctor can't ignore
44  CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
45  CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{o} {}
46 
47  CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
48  CK_TILE_HOST_DEVICE auto & get() {return data; }
49  CK_TILE_HOST_DEVICE const auto & get() const {return data; }
50  CK_TILE_HOST_DEVICE auto & get(index_t i) {return data[i]; }
51  CK_TILE_HOST_DEVICE const auto & get(index_t i) const {return data[i]; }
52  CK_TILE_HOST_DEVICE constexpr const auto& operator[](index_t i) const { return get(i); }
53  CK_TILE_HOST_DEVICE constexpr auto& operator[](index_t i) { return get(i); }
54  CK_TILE_HOST_DEVICE constexpr auto& operator()(index_t i) { return get(i); } // TODO: compatible
55  CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
56  CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
57  template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
58  template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
59  template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
60  template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
61 
62  template <typename X_,
63  typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
64  CK_TILE_HOST_DEVICE constexpr auto _get_as() const
65  {
66  using X = remove_cvref_t<X_>;
67 
68  constexpr index_t kSPerX = vector_traits<X>::vector_size;
69  static_assert(N % kSPerX == 0);
70 
71  union {
72  thread_buffer<X_, N / kSPerX> data {};
73  // tuple_array<value_type, kSPerX> sub_data;
74  value_type sub_data[N];
75  } vx;
76  static_for<0, N, 1>{}(
77  [&](auto j) { vx.sub_data[j] = data[j]; });
78  return vx.data;
79  }
80 
81  template <typename X_,
82  index_t Is,
83  typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
84  CK_TILE_HOST_DEVICE const constexpr remove_reference_t<X_> _get_as(number<Is> is) const
85  {
86  using X = remove_cvref_t<X_>;
87 
88  constexpr index_t kSPerX = vector_traits<X>::vector_size;
89 
90  union {
91  X_ data {};
92  tuple_array<value_type, kSPerX> sub_data;
93  } vx;
94  static_for<0, kSPerX, 1>{}(
95  [&](auto j) { vx.sub_data(j) = operator[]((is * number<sizeof(X_)/sizeof(value_type)>{}) + j); });
96  return vx.data;
97  }
98 
99 #if 0
100  template <typename X_,
101  index_t Is,
102  typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
103  CK_TILE_HOST_DEVICE constexpr void _set_as(number<Is> is, X_ x)
104  {
105  using X = remove_cvref_t<X_>;
106 
107  constexpr index_t kSPerX = vector_traits<X>::vector_size;
108 
109  union {
110  X_ data;
111  tuple_array<value_type, kSPerX> sub_data;
112  } vx {x};
113 
114  static_for<0, kSPerX, 1>{}(
115  [&](auto j) { operator()((is * number<sizeof(X_)/sizeof(value_type)>{}) + j) = vx.sub_data[j]; });
116  }
117 #endif
118 
119 
120 #define TB_COMMON_AS() \
121  static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
122  constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
123 
124  template<typename Tx>
125  CK_TILE_HOST_DEVICE auto & get_as() {TB_COMMON_AS();
126  return reinterpret_cast<thread_buffer<Tx, vx>&>(data);}
127  template<typename Tx>
128  CK_TILE_HOST_DEVICE constexpr auto get_as() const {TB_COMMON_AS();
129  if constexpr(sizeof(value_type) <= 1 )
130  return _get_as<Tx>(); // TODO: current compiler for 8bit data need use union to get data back, should fix in the future
131  else
132  return reinterpret_cast<const thread_buffer<Tx, vx>&>(data);}
133  template<typename Tx, index_t I>
134  CK_TILE_HOST_DEVICE auto & get_as(number<I>) {TB_COMMON_AS();
135  return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(number<I>{});}
136  template<typename Tx, index_t I>
137  CK_TILE_HOST_DEVICE constexpr auto get_as(number<I>) const {TB_COMMON_AS();
138  if constexpr(sizeof(value_type) <= 1 )
139  return _get_as<Tx>(number<I>{}); // TODO: current compiler for 8bit data need use union to get data back, should fix in the future
140  else
141  return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(number<I>{});}
142 
143  template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
144  { TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(i) = x; }
145  template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
146  { TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(number<I>{}) = x; }
147 
148 #undef TB_COMMON_AS
149 };
150 // clang-format on
151 
152 template <typename>
153 struct vector_traits;
154 
155 // specialization for array
156 template <typename T, index_t N>
157 struct vector_traits<thread_buffer<T, N>>
158 {
159  using scalar_type = T;
160  static constexpr index_t vector_size = N;
161 };
162 
163 #endif
164 
165 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
tuple_array< T, N > thread_buffer
Definition: thread_buffer.hpp:14
typename impl::tuple_array_impl< T, N >::type tuple_array
Definition: tuple.hpp:28
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE details::return_type< D, Ts... > make_array(Ts &&... ts)
Definition: array.hpp:197
constant< v > number
Definition: integral_constant.hpp:33
constexpr CK_TILE_HOST_DEVICE auto make_thread_buffer(Ts &&... ts)
Definition: thread_buffer.hpp:17
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
static constexpr index_t vector_size
Definition: vector_type.hpp:62