dynamic_buffer.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "data_type.hpp"
#include "enable_if.hpp"
#include "c_style_pointer_cast.hpp"
#if __clang_major__ >= 20
#include "amd_buffer_addressing_builtins.hpp"
#else
#include "amd_buffer_addressing.hpp"
#endif
#include "generic_memory_space_atomic.hpp"

namespace ck {

// T may be scalar or vector
// X may be scalar or vector
// T and X have the same scalar type
// X contains multiple T
template <AddressSpaceEnum BufferAddressSpace,
          typename T,
          typename ElementSpaceSize,
          bool InvalidElementUseNumericalZeroValue,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence,
          typename IndexType = index_t>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;
    ElementSpaceSize element_space_size_;
    T invalid_element_value_ = T{0};

    // XXX: PackedSize semantics for pk_i4_t differ from the other packed types.
    // Objects of f4x2_pk_t and f6_pk_t are counted as 1 element, while
    // objects of pk_i4_t are counted as 2 elements. Therefore, element_space_size_ for pk_i4_t
    // must be divided by 2 to correctly represent the number of addressable elements.
    static constexpr index_t PackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();
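
    // Illustrative sketch of the PackedSize convention above (the numbers here are
    // examples, not code from this header): a buffer holding N pk_i4_t objects is
    // constructed with element_space_size_ == 2 * N, so element_space_size_ / PackedSize
    // recovers the N addressable objects passed to the buffer-addressing intrinsics;
    // for f4x2_pk_t / f6_pk_t, PackedSize == 1 and the size is passed through unchanged.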

    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
        : p_data_{p_data}, element_space_size_{element_space_size}
    {
    }

    __host__ __device__ constexpr DynamicBuffer(T* p_data,
                                                ElementSpaceSize element_space_size,
                                                T invalid_element_value)
        : p_data_{p_data},
          element_space_size_{element_space_size},
          invalid_element_value_{invalid_element_value}
    {
    }

    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace()
    {
        return BufferAddressSpace;
    }

    __host__ __device__ constexpr const T& operator[](IndexType i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(IndexType i) { return p_data_[i]; }

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value ||
                                     !is_native_type<X>(),
                                 bool>::type = false>
    __host__ __device__ constexpr auto Get(IndexType i, bool is_valid_element) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

#if CK_USE_AMD_BUFFER_LOAD
        bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t);
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
                                                                   t_per_x,
                                                                   coherence>(
                    p_data_, i, is_valid_element, element_space_size_ / PackedSize);
            }
            else
            {
                return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
                                                                               t_per_x,
                                                                               coherence>(
                    p_data_,
                    i,
                    is_valid_element,
                    element_space_size_ / PackedSize,
                    invalid_element_value_);
            }
        }
        else
        {
            if(is_valid_element)
            {
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp;

                __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

                return tmp;
#else
                return *c_style_pointer_cast<const X*>(&p_data_[i]);
#endif
            }
            else
            {
                if constexpr(InvalidElementUseNumericalZeroValue)
                {
                    return X{0};
                }
                else
                {
                    return X{invalid_element_value_};
                }
            }
        }
    }
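
    // Minimal usage sketch for Get (hedged: assumes a global float buffer p_float of
    // length n and CK's float4_t vector alias; this is not code from this header):
    //   auto buf = make_dynamic_buffer<AddressSpaceEnum::Global>(p_float, n);
    //   auto v4  = buf.template Get<float4_t>(i, i + 4 <= n);
    // An invalid access returns X{0}, or invalid_element_value_ when
    // InvalidElementUseNumericalZeroValue is false.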

    template <InMemoryDataOperationEnum Op,
              typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void Update(IndexType i, bool is_valid_element, const X& x)
    {
        if constexpr(Op == InMemoryDataOperationEnum::Set)
        {
            this->template Set<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum::AtomicAdd)
        {
            this->template AtomicAdd<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum::AtomicMax)
        {
            this->template AtomicMax<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum::Add)
        {
            auto tmp = this->template Get<X>(i, is_valid_element);
            using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
            // handle bfloat addition
            if constexpr(is_same_v<scalar_t, bhalf_t>)
            {
                if constexpr(is_scalar_type<X>::value)
                {
                    // Scalar type
                    auto result =
                        type_convert<X>(type_convert<float>(x) + type_convert<float>(tmp));
                    this->template Set<X>(i, is_valid_element, result);
                }
                else
                {
                    // Vector type
                    constexpr auto vector_size = scalar_type<remove_cvref_t<X>>::vector_size;
                    const vector_type<scalar_t, vector_size> a_vector{tmp};
                    const vector_type<scalar_t, vector_size> b_vector{x};
                    static_for<0, vector_size, 1>{}([&](auto idx) {
                        auto result = type_convert<scalar_t>(
                            type_convert<float>(a_vector.template AsType<scalar_t>()[idx]) +
                            type_convert<float>(b_vector.template AsType<scalar_t>()[idx]));
                        this->template Set<scalar_t>(i + idx, is_valid_element, result);
                    });
                }
            }
            else
            {
                this->template Set<X>(i, is_valid_element, x + tmp);
            }
        }
    }
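
    // Note on the Add path above: bhalf_t has no native add here, so each value is
    // upconverted to float, accumulated, and converted back. Hedged scalar equivalent
    // of one lane (using the same type_convert helpers as above):
    //   float acc = type_convert<float>(x) + type_convert<float>(tmp);
    //   this->template Set<bhalf_t>(i, is_valid_element, type_convert<bhalf_t>(acc));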

    template <typename DstBuffer, index_t NumElemsPerThread>
    __host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf,
                                             IndexType src_offset,
                                             IndexType dst_offset,
                                             bool is_valid_element) const
    {
        // Copy data from global to LDS memory using direct loads.
        static_assert(GetAddressSpace() == AddressSpaceEnum::Global,
                      "Source data must come from a global memory buffer.");
        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "Destination data must be stored in an LDS memory buffer.");

        amd_direct_load_global_to_lds<T, NumElemsPerThread>(p_data_,
                                                            src_offset,
                                                            dst_buf.p_data_,
                                                            dst_offset,
                                                            is_valid_element,
                                                            element_space_size_ / PackedSize);
    }
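
    // Hedged usage sketch (assumes src is a global-memory DynamicBuffer and dst an
    // LDS DynamicBuffer with the same T; offsets are per-thread element offsets):
    //   src.template DirectCopyToLds<decltype(dst), 4>(dst, src_off, dst_off, true);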

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value ||
                                     !is_native_type<X>(),
                                 bool>::type = false>
    __host__ __device__ void Set(IndexType i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

#if CK_USE_AMD_BUFFER_LOAD
        bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t);
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

#if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
        bool constexpr workaround_int8_ds_write_issue = true;
#else
        bool constexpr workaround_int8_ds_write_issue = false;
#endif

        if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
                x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
        }
        else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
                          is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
                          workaround_int8_ds_write_issue)
        {
            if(is_valid_element)
            {
                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
                // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
                // ds_write_b128
                // TODO: remove this after compiler fix
                static_assert((is_same<remove_cvref_t<T>, int8_t>::value &&
                               (is_same<remove_cvref_t<X>, int8_t>::value ||
                                is_same<remove_cvref_t<X>, int8x2_t>::value ||
                                is_same<remove_cvref_t<X>, int8x4_t>::value ||
                                is_same<remove_cvref_t<X>, int8x8_t>::value ||
                                is_same<remove_cvref_t<X>, int8x16_t>::value)) ||
                                  (is_same<remove_cvref_t<T>, int8x4_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x4_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8x8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x8_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8x16_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x16_t>::value),
                              "wrong! not implemented for this combination, please add "
                              "implementation");

                if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                             is_same<remove_cvref_t<X>, int8_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int8_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x2_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int16_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x4_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x8_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x2_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x16_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x4_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x4_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8x8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x8_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x2_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8x16_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x16_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x4_t*>(&x);
                }
            }
        }
        else
        {
            if(is_valid_element)
            {
#if 0
                X tmp = x;

                __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
            }
        }
    }
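
    // Why the LDS int8 branch above exists: reinterpreting byte vectors as 32-bit
    // vectors lets the compiler emit wide LDS stores, e.g. one int8x16_t Set becomes
    // a single int32x4_t store (ds_write_b128) instead of sixteen byte writes.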

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void AtomicAdd(IndexType i, bool is_valid_element, const X& x)
    {
        using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;

        // X contains multiple T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");

#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
        bool constexpr use_amd_buffer_addressing =
            is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
            is_same_v<remove_cvref_t<scalar_t>, float> ||
            (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
            (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
        bool constexpr use_amd_buffer_addressing =
            sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
        bool constexpr use_amd_buffer_addressing =
            sizeof(IndexType) <= sizeof(int32_t) &&
            (is_same_v<remove_cvref_t<scalar_t>, float> ||
             (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
             (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0));
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
        }
        else
        {
            if(is_valid_element)
            {
                atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
            }
        }
    }
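
    // Hedged usage sketch for AtomicAdd (assumes a global float accumulator buffer
    // acc_buf and CK's float2_t alias; packed f16/bf16 atomics require an even
    // scalar_per_x_vector, as encoded above):
    //   acc_buf.template AtomicAdd<float2_t>(i, is_valid, partial_sum);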

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void AtomicMax(IndexType i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr IndexType scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

        constexpr IndexType scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");

#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
        using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
        bool constexpr use_amd_buffer_addressing =
            sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, double>;
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
        }
        else if(is_valid_element)
        {
            atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
        }
    }
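
    // Hedged usage sketch for AtomicMax (assumes a global double buffer; the buffer
    // atomic path is taken only for double, and only when
    // CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 is set):
    //   max_buf.template AtomicMax<double>(i, is_valid, candidate);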

    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

template <AddressSpaceEnum BufferAddressSpace,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence,
          typename T,
          typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true, coherence>{
        p, element_space_size};
}

template <AddressSpaceEnum BufferAddressSpace,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence,
          typename T,
          typename ElementSpaceSize>
__host__ __device__ constexpr auto make_long_dynamic_buffer(T* p,
                                                            ElementSpaceSize element_space_size)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true, coherence, long_index_t>{
        p, element_space_size};
}

template <
    AddressSpaceEnum BufferAddressSpace,
    AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence,
    typename T,
    typename ElementSpaceSize,
    typename X,
    typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type = false>
__host__ __device__ constexpr auto
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false, coherence>{
        p, element_space_size, invalid_element_value};
}
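
// Hedged usage sketch for the factories above (p_in, p_out, and n are hypothetical):
//   auto out = make_dynamic_buffer<AddressSpaceEnum::Global>(p_out, n);        // zero for invalid elements
//   auto in  = make_dynamic_buffer<AddressSpaceEnum::Global>(p_in, n, T{-1});  // customized invalid value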

} // namespace ck