include/ck_tile/core/tensor/tile_elementwise.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck_tile {

// TODO: support tensors with different distribution
template <typename InOutElementFunc,
          typename... InOutDstrTensors,
          typename = std::enable_if_t<std::conjunction_v<
              std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, null_tensor>>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
                                           InOutDstrTensors&... inout_dstr_tensors)
{
    // TODO: make sure all distributed tensors have same lengths and distribution
    // static_assert(xxx);

    constexpr index_t thread_buffer_size =
        __type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();

    static_for<0, thread_buffer_size, 1>{}(
        [&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); });
}
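
// Usage sketch (illustrative, not part of this header): apply an in-place element-wise
// update across two distributed tiles with tile_elementwise_inout. The tile names
// `a_tile`/`b_tile` and the element type are hypothetical; any distributed tensors with
// matching distributions are used the same way.
//
//   // a_tile, b_tile: static distributed tensors with the same distribution
//   tile_elementwise_inout(
//       [](auto& a, const auto& b) { a = a * 2.f + b; },
//       a_tile,
//       b_tile);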

template <typename InElementFunc,
          typename... InTensor,
          typename = std::enable_if_t<
              std::conjunction_v<std::negation<std::is_same<InTensor, null_tensor>>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
                                        const InTensor&... in_dstr_tensors)
{
    using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));

    // TODO: make sure all distributed tensors have same lengths and distribution
    // static_assert(xxx);
    constexpr auto in_tile_dstr = __type_pack_element<0, InTensor...>::get_tile_distribution();

    constexpr index_t thread_buffer_size =
        __type_pack_element<0, InTensor...>::get_thread_buffer_size();

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
        out_dstr_tensor.get_thread_buffer()(i) =
            in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
    });

    return out_dstr_tensor;
}
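
// Usage sketch (illustrative, not part of this header): produce a new distributed tile
// from input tiles element-wise. `x_tile` and `y_tile` are hypothetical inputs with the
// same distribution; the output DataType is deduced from the lambda's return type.
//
//   auto sum_tile = tile_elementwise_in(
//       [](const auto& x, const auto& y) { return x + y; },
//       x_tile,
//       y_tile);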

template <typename DstrTensors, typename T>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
{
    tile_elementwise_inout(
        [&value](auto& x) {
            x = type_convert<typename DstrTensors::DataType, remove_cvref_t<T>>(value);
        },
        dstr_tensor);
}

template <typename T>
CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
{
}

// TODO: prefer to use a per-dword value to set a tensor, in case the compiler does not handle
// sub-dword tensors well...
template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
CK_TILE_DEVICE void
set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
{
    using elem_type             = typename DstrTensors::DataType;
    constexpr index_t elem_size = sizeof(elem_type);

    constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;

    // # bytes per write = 4
    if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
    {
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
        auto& buffer = dstr_tensor.get_thread_buffer();

        static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
            if constexpr(elem_size == 1)
            {
                // # elements per write = 4
                constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};

                buffer[i_write * 4 + 0] = values.x;
                buffer[i_write * 4 + 1] = values.y;
                buffer[i_write * 4 + 2] = values.z;
                buffer[i_write * 4 + 3] = values.w;
            }
            else if constexpr(elem_size == 2)
            {
                // # elements per write = 2
                constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};

                buffer[i_write * 2 + 0] = values.x;
                buffer[i_write * 2 + 1] = values.y;
            }
            else if constexpr(elem_size == 4)
            {
                // # elements per write = 1
                constexpr elem_type value = 0;

                buffer[i_write] = value;
            }
            else
            {
                static_assert(false, "type not supported");
            }
        });
#else
        using dvec_t = array<index_t, tensor_bytes / 4>;
        auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
        for(auto i = 0; i < tensor.size(); i++)
            tensor.get(i) = v;
#endif
    }
    else
    {
        tile_elementwise_inout([](auto& x) { x = type_convert<elem_type, index_t>(v); },
                               dstr_tensor);
    }
}
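
// Usage sketch (illustrative, not part of this header): filling with a compile-time zero
// enables the dword-wide fast path above when the per-thread buffer size is a multiple of
// 4 bytes; passing bool_constant<true>{} skips that optimization and falls back to the
// generic element-wise fill. `acc_tile` is a hypothetical distributed tensor.
//
//   set_tile(acc_tile, number<0>{});                        // dword-wide zero fill
//   set_tile(acc_tile, number<0>{}, bool_constant<true>{}); // force element-wise path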

template <index_t v>
CK_TILE_DEVICE void set_tile(null_tensor&, number<v>)
{
}

template <typename DstrTensors>
CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
{
    set_tile(dstr_tensor, 0);
}
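
// Usage sketch (illustrative, not part of this header): clear_tile is shorthand for
// set_tile(tile, 0), which type_convert's the runtime value to the tile's DataType per
// element. `acc_tile` is a hypothetical accumulator tile.
//
//   clear_tile(acc_tile);     // same as set_tile(acc_tile, 0)
//   set_tile(acc_tile, 1.0f); // fill with a runtime value, converted per element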

namespace impl {
// TODO: this is ugly
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
{
#if defined(__gfx94__)
    // This API is designed to use the _pk_ series of conversion builtins
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
    static_assert(thread_buffer_size % 4 == 0);
    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4;

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
    // __builtin_amdgcn_cvt_pk_fp8_f32() requires the old value as an operand and would
    // otherwise emit a v_mov_b32 of [old] before the cvt, producing unwanted ISA, so we
    // purposely pass an uninitialized variable and turn the warning off
    int dummy_old;
    static_for<0, thread_buffer_size_pk, 1>{}([&](auto i) {
        uint32_t x = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 0>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 1>{}],
            dummy_old,
            false); // false -> WORD0

        uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 2>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 3>{}],
            dummy_old,
            false); // false -> WORD0

        constexpr int32_t m0 = 0x05040100;
        using vec_t          = array<OutDataType, 4>;

        vec_t d = bit_cast<vec_t>(__builtin_amdgcn_perm(y, x, m0));
        out_dstr_tensor.get_thread_buffer().template set_as<vec_t>(number<i>{}, d);
    });
#pragma clang diagnostic pop

    return out_dstr_tensor;
#else
    // fallback
    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
                               in_dstr_tensors);
#endif
}

template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
    // This API is designed to use the _pk_ series of conversion builtins
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
    static_assert(thread_buffer_size % 2 == 0);
    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    // TODO: this is an RTZ (round-toward-zero) conversion; be careful
    for(index_t i = 0; i < thread_buffer_size_pk; i++)
    {
        auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
                                            in_dstr_tensors.get_thread_buffer()[2 * i + 1]);

        out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
        out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
    }

    return out_dstr_tensor;
#else
    // fallback
    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
                               in_dstr_tensors);
#endif
}

#if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assumes that either the src or dst (or both) data type is smaller than 1 dword;
// we pack sub-dword values into 1 dword to avoid the compiler's default sub-dword behavior
// (which is buggy)
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
{
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    using i_type = remove_cvref_t<typename InTensor::DataType>;
    using o_type = remove_cvref_t<OutDataType>;
    constexpr index_t i_elem_bytes = sizeof(i_type);
    constexpr index_t o_elem_bytes = sizeof(o_type);
    static_assert(i_elem_bytes < 4 || o_elem_bytes < 4);

    constexpr index_t bulk_size =
        (i_elem_bytes >= o_elem_bytes) ? (4 / o_elem_bytes) : (4 / i_elem_bytes);
    static_assert(bulk_size != 0);

    using o_bulk_type =
        std::conditional_t<i_elem_bytes >= o_elem_bytes, float, array<o_type, bulk_size>>;

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();

    constexpr index_t iters = thread_buffer_size / bulk_size;
    constexpr index_t rems  = thread_buffer_size % bulk_size;

    // cast the sequence per-bulk
    static_for<0, iters, 1>{}([&](auto i) {
        union bulk_wrapper
        {
            o_bulk_type bulk{};
            o_type data[bulk_size];
        } o_bulk;

        // TODO: should use the function below, but somehow it results in spills (same as a C for-loop)
        static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib) {
            o_bulk.data[ib.value] = static_cast<o_type>(
                in_dstr_tensors.get_thread_buffer()
                    .template get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
        });

        // TODO: fixme, should use the above!
        // static_assert(sizeof(i_type) / sizeof(o_type) == 2);
        // o_bulk.data[0] = static_cast<o_type>(
        //     in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 0>{}]);
        // o_bulk.data[1] = static_cast<o_type>(
        //     in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 1>{}]);

        out_dstr_tensor.get_thread_buffer().template set_as<o_bulk_type>(i, o_bulk.bulk);
    });

    static_for<0, rems, 1>{}([&](auto r) {
        // TODO: introduce a local scratch pad?
        auto idx = number<iters * bulk_size + r>{};
        out_dstr_tensor.get_thread_buffer().at(idx) =
            static_cast<o_type>(in_dstr_tensors.get_thread_buffer().at(idx));
    });

    return out_dstr_tensor;
}
#endif
} // namespace impl

template <typename DstType, typename SrcTensor>
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
{
    if constexpr((std::is_same_v<DstType, fp8_t> || std::is_same_v<DstType, bf8_t>) &&
                 std::is_same_v<typename SrcTensor::DataType, float> &&
                 (SrcTensor::get_thread_buffer_size() % 4 == 0))
    {
        return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
    }
#if CK_TILE_USE_PK_FP16_TILE_CAST
    else if constexpr(std::is_same_v<DstType, fp16_t> &&
                      std::is_same_v<typename SrcTensor::DataType, float> &&
                      (SrcTensor::get_thread_buffer_size() % 2 == 0))
    {
        return impl::cast_tile_pk_fp16_fp32<DstType, SrcTensor>(src_tensor);
    }
#endif
#if CK_TILE_USE_SUBDWORD_TILE_CAST
    else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
    {
        return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
    }
#endif
    else
        return tile_elementwise_in(type_convert<DstType, typename SrcTensor::DataType>, src_tensor);
}
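
// Usage sketch (illustrative, not part of this header): cast_tile selects a packed or
// sub-dword-optimized conversion when the destination/source types and buffer sizes allow
// it, and otherwise falls back to a plain element-wise type_convert. `c_tile_fp32` is a
// hypothetical float accumulator tile.
//
//   auto c_tile_fp16 = cast_tile<fp16_t>(c_tile_fp32); // fp32 -> fp16, packed when enabled
//   auto c_tile_fp8  = cast_tile<fp8_t>(c_tile_fp32);  // fp32 -> fp8 via the _pk_ path on gfx94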

// no-op function for null_tensor arguments
template <typename InOutElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...)
{
}

// no-op function for null_tensor arguments
template <typename InElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...)
{
    return null_tensor{};
}

} // namespace ck_tile