dynamic_buffer.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "c_style_pointer_cast.hpp"
#include "enable_if.hpp"

#if __clang_major__ >= 20
#include "amd_buffer_addressing_builtins.hpp"
#else
#include "amd_buffer_addressing.hpp"
#endif
#include "amd_transpose_load.hpp"
namespace ck {

// T may be scalar or vector.
// X may be scalar or vector.
// T and X have the same scalar type.
// X contains multiple T.
template <AddressSpaceEnum BufferAddressSpace,
          typename T,
          typename ElementSpaceSize,
          bool InvalidElementUseNumericalZeroValue,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence,
          typename IndexType = index_t>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;
    ElementSpaceSize element_space_size_;
    T invalid_element_value_;
    // XXX: PackedSize semantics for pk_i4_t is different from the other packed types.
    // Objects of f4x2_pk_t and f6_pk_t are counted as 1 element, while
    // objects of pk_i4_t are counted as 2 elements. Therefore, element_space_size_ for pk_i4_t
    // must be divided by 2 to correctly represent the number of addressable elements.
    static constexpr index_t PackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();
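    // Illustrative example (not in the original source): for T = pk_i4_t,
    // PackedSize == 2, so a buffer declared over 1024 logical int4 elements
    // spans 512 addressable pk_i4_t objects; this is why the buffer intrinsics
    // below receive element_space_size_ / PackedSize. For f4x2_pk_t and
    // f6_pk_t the packed object itself counts as one element, so PackedSize
    // stays 1 and the extent is passed through unchanged.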

    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
        : p_data_{p_data}, element_space_size_{element_space_size}
    {
    }

    __host__ __device__ constexpr DynamicBuffer(T* p_data,
                                                ElementSpaceSize element_space_size,
                                                T invalid_element_value)
        : p_data_{p_data},
          element_space_size_{element_space_size},
          invalid_element_value_{invalid_element_value}
    {
    }

    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace()
    {
        return BufferAddressSpace;
    }

    __host__ __device__ constexpr const T& operator[](IndexType i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(IndexType i) { return p_data_[i]; }

    template <typename X,
              bool DoTranspose = false,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<T>>::type,
                                         typename scalar_type<remove_cvref_t<X>>::type>::value ||
                                     !is_native_type<X>(),
                                 bool>::type = false>
    __host__ __device__ constexpr auto Get(IndexType i, bool is_valid_element) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

#if CK_USE_AMD_BUFFER_LOAD
        bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t);
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing &&
                     !DoTranspose)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
                                                                   t_per_x,
                                                                   coherence>(
                    p_data_, i, is_valid_element, element_space_size_ / PackedSize);
            }
            else
            {
                return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
                                                                               t_per_x,
                                                                               coherence>(
                    p_data_,
                    i,
                    is_valid_element,
                    element_space_size_ / PackedSize,
                    invalid_element_value_);
            }
        }
        else if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && DoTranspose)
        {
#ifdef __gfx12__
            return amd_global_load_transpose_to_vgpr(p_data_ + i);
#else
            static_assert(!DoTranspose, "load-with-transpose only supported on gfx12+");
#endif
        }
        else
        {
            if(is_valid_element)
            {
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp;

                __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

                return tmp;
#else
                return *c_style_pointer_cast<const X*>(&p_data_[i]);
#endif
            }
            else
            {
                if constexpr(InvalidElementUseNumericalZeroValue)
                {
                    return X{0};
                }
                else
                {
                    return X{invalid_element_value_};
                }
            }
        }
    }
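    // Usage sketch (illustrative, not in the original source): a guarded
    // 4-wide load through Get, assuming float4_t = vector_type<float, 4>::type
    // and a float buffer of n elements.
    //
    //     auto buf = make_dynamic_buffer<AddressSpaceEnum::Global>(p, n);
    //     auto v   = buf.template Get<float4_t>(i, i + 4 <= n);
    //     // when the guard is false, the call returns X{0}
    //     // (or X{invalid_element_value_}).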

    template <InMemoryDataOperationEnum Op,
              typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<T>>::type,
                                         typename scalar_type<remove_cvref_t<X>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void Update(IndexType i, bool is_valid_element, const X& x)
    {
        if constexpr(Op == InMemoryDataOperationEnum::Set)
        {
            this->template Set<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum::AtomicAdd)
        {
            this->template AtomicAdd<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum::AtomicMax)
        {
            this->template AtomicMax<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum::Add)
        {
            auto tmp = this->template Get<X>(i, is_valid_element);
            using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
            // Handle bhalf_t addition: widen to float, add, then convert back.
            if constexpr(is_same_v<scalar_t, bhalf_t>)
            {
                if constexpr(is_scalar_type<X>::value)
                {
                    // Scalar type
                    auto result =
                        type_convert<X>(type_convert<float>(x) + type_convert<float>(tmp));
                    this->template Set<X>(i, is_valid_element, result);
                }
                else
                {
                    // Vector type: convert and add element by element.
                    constexpr auto vector_size = scalar_type<remove_cvref_t<X>>::vector_size;
                    const vector_type<scalar_t, vector_size> a_vector{tmp};
                    const vector_type<scalar_t, vector_size> b_vector{x};
                    static_for<0, vector_size, 1>{}([&](auto idx) {
                        auto result = type_convert<scalar_t>(
                            type_convert<float>(a_vector.template AsType<scalar_t>()[idx]) +
                            type_convert<float>(b_vector.template AsType<scalar_t>()[idx]));
                        this->template Set<scalar_t>(i + idx, is_valid_element, result);
                    });
                }
            }
            else
            {
                this->template Set<X>(i, is_valid_element, x + tmp);
            }
        }
    }
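    // Illustrative note (not in the original source): Update is a thin
    // compile-time switch over the memory operation, so e.g.
    //
    //     buf.template Update<InMemoryDataOperationEnum::AtomicAdd>(i, ok, x);
    //
    // forwards to AtomicAdd<X>(i, ok, x). The bhalf_t branch above exists
    // because bf16 has no native add: each element is widened to float, added,
    // and converted back before the store.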

    template <typename DstBuffer, index_t NumElemsPerThread>
    __host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf,
                                             IndexType src_offset,
                                             IndexType dst_offset,
                                             bool is_valid_element) const
    {
        // Copy data from global to LDS memory using direct loads.
        static_assert(GetAddressSpace() == AddressSpaceEnum::Global,
                      "Source data must come from a global memory buffer.");
        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "Destination data must be stored in an LDS memory buffer.");

        amd_direct_load_global_to_lds<T, NumElemsPerThread>(p_data_,
                                                            src_offset,
                                                            dst_buf.p_data_,
                                                            dst_offset,
                                                            is_valid_element,
                                                            element_space_size_ / PackedSize);
    }
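    // Usage sketch (illustrative, not in the original source): copying 4
    // elements per thread straight from global memory into LDS; both offsets
    // are element offsets into their respective buffers.
    //
    //     global_buf.template DirectCopyToLds<decltype(lds_buf), 4>(
    //         lds_buf, src_off, dst_off, is_valid);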

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<T>>::type,
                                         typename scalar_type<remove_cvref_t<X>>::type>::value ||
                                     !is_native_type<X>(),
                                 bool>::type = false>
    __host__ __device__ void Set(IndexType i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

#if CK_USE_AMD_BUFFER_LOAD
        bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t);
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

#if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
        bool constexpr workaround_int8_ds_write_issue = true;
#else
        bool constexpr workaround_int8_ds_write_issue = false;
#endif

        if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
                x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
        }
        else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
                          is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
                          workaround_int8_ds_write_issue)
        {
            if(is_valid_element)
            {
                // HACK: the compiler would lower the IR "store<i8, 16> address_space(3)" into
                // inefficient ISA, so we try to let the compiler emit the IR "store<i32, 4>"
                // instead, which is lowered to ds_write_b128.
                // TODO: remove this after the compiler fix.
                static_assert((is_same<remove_cvref_t<T>, int8_t>::value &&
                               (is_same<remove_cvref_t<X>, int8_t>::value ||
                                is_same<remove_cvref_t<X>, int8x2_t>::value ||
                                is_same<remove_cvref_t<X>, int8x4_t>::value ||
                                is_same<remove_cvref_t<X>, int8x8_t>::value ||
                                is_same<remove_cvref_t<X>, int8x16_t>::value)) ||
                                  (is_same<remove_cvref_t<T>, int8x4_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x4_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8x8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x8_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8x16_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x16_t>::value),
                              "wrong! not implemented for this combination, please add "
                              "implementation");

                // HACK: casting the pointer of x is bad.
                // TODO: remove this after the compiler fix.
                if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                             is_same<remove_cvref_t<X>, int8_t>::value)
                {
                    *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int8_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x2_t>::value)
                {
                    // 2 x int8 is stored through an int16_t alias.
                    *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int16_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x4_t>::value)
                {
                    // 4 x int8 is stored through an int32_t alias.
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x8_t>::value)
                {
                    // 8 x int8 is stored through an int32x2_t alias.
                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x2_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x16_t>::value)
                {
                    // 16 x int8 is stored through an int32x4_t alias (ds_write_b128).
                    *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x4_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x4_t>::value)
                {
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8x8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x8_t>::value)
                {
                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x2_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8x16_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x16_t>::value)
                {
                    *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x4_t*>(&x);
                }
            }
        }
        else
        {
            if(is_valid_element)
            {
#if 0
                X tmp = x;

                __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
            }
        }
    }
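    // Illustrative note (not in the original source): the LDS branch above
    // widens int8 stores so the backend can emit a single wide ds_write, e.g.
    // a T = int8_t, X = int8x16_t store goes out as one int32x4_t
    // (ds_write_b128) instead of sixteen scalarized byte writes.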

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<T>>::type,
                                         typename scalar_type<remove_cvref_t<X>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void AtomicAdd(IndexType i, bool is_valid_element, const X& x)
    {
        using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;

        // X contains multiple T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");

#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
        bool constexpr use_amd_buffer_addressing =
            is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
            is_same_v<remove_cvref_t<scalar_t>, float> ||
            (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
            (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
        bool constexpr use_amd_buffer_addressing =
            sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
        bool constexpr use_amd_buffer_addressing =
            sizeof(IndexType) <= sizeof(int32_t) &&
            (is_same_v<remove_cvref_t<scalar_t>, float> ||
             (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
             (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0));
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
        }
        else
        {
            if(is_valid_element)
            {
                atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
            }
        }
    }
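    // Usage sketch (illustrative, not in the original source): guarded atomic
    // accumulation of a 2-wide half vector, assuming half2_t =
    // vector_type<half_t, 2>::type. Packed f16/bf16 buffer atomics require an
    // even vector width (see the conditions above); otherwise the generic
    // atomic_add fallback is used.
    //
    //     buf.template AtomicAdd<half2_t>(i, i + 2 <= n, partial);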

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<T>>::type,
                                         typename scalar_type<remove_cvref_t<X>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void AtomicMax(IndexType i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr IndexType scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

        constexpr IndexType scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");

#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
        using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
        bool constexpr use_amd_buffer_addressing =
            sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, double>;
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
        }
        else if(is_valid_element)
        {
            atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
        }
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

template <AddressSpaceEnum BufferAddressSpace,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence,
          typename T,
          typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true, coherence>{
        p, element_space_size};
}

template <AddressSpaceEnum BufferAddressSpace,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence,
          typename T,
          typename ElementSpaceSize>
__host__ __device__ constexpr auto make_long_dynamic_buffer(T* p,
                                                            ElementSpaceSize element_space_size)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true, coherence, long_index_t>{
        p, element_space_size};
}

template <
    AddressSpaceEnum BufferAddressSpace,
    AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence,
    typename T,
    typename ElementSpaceSize,
    typename X,
    typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type = false>
__host__ __device__ constexpr auto
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false, coherence>{
        p, element_space_size, invalid_element_value};
}

} // namespace ck
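
// Usage sketch (illustrative, not part of the original header): a device-side
// guarded vector round trip through a global DynamicBuffer. float4_t is
// assumed to be ck's vector_type<float, 4>::type.
//
//     __device__ void scale_four(float* p, ck::index_t n, ck::index_t i)
//     {
//         auto buf = ck::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(p, n);
//         const bool ok = (i + 4 <= n);
//         auto v        = buf.Get<ck::float4_t>(i, ok); // X{0} when !ok
//         buf.Update<ck::InMemoryDataOperationEnum::Set>(i, ok, v);
//     }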