/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/buffer_view.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/buffer_view.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/buffer_view.hpp Source File
buffer_view.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
19 
20 namespace ck_tile {
21 
22 // T may be scalar or vector
23 // X may be scalar or vector
24 // T and X have same scalar type
25 // X contains multiple T
26 // FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
27 // transforms of tensor_view/Tensor
28 // FIXME: amd_buffer_coherence_enum is only meaningful for buffer addressing. Need to split
29 // buffer_view definition for different memory address space (Global/GenericLds/Vgpr)
30 template <address_space_enum BufferAddressSpace,
31  typename T,
32  typename BufferSizeType,
33  bool InvalidElementUseNumericalZeroValue,
35 struct buffer_view;
36 
37 // Address Space: generic
38 // T may be scalar or vector
39 // X may be scalar or vector
40 // T and X have same scalar type
41 // X contains multiple T
42 // FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
43 // transforms of tensor_view/Tensor
44 template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
45 struct buffer_view<address_space_enum::generic,
46  T,
47  BufferSizeType,
48  InvalidElementUseNumericalZeroValue,
50 {
51  using type = T;
52 
53  T* p_data_ = nullptr;
54  BufferSizeType buffer_size_;
55  remove_cvref_t<T> invalid_element_value_ = T{0};
56 
58  : p_data_{}, buffer_size_{}, invalid_element_value_{}
59  {
60  }
61 
// Wrap p_data as a view of buffer_size elements of T (stored as-is, no scaling).
// invalid_element_value_ is set to 0; get() only returns it for invalid elements when
// InvalidElementUseNumericalZeroValue is false.
62  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
63  : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
64  {
65  }
66 
// Same as above, but with a caller-chosen value for invalid (masked) elements.
// That value is returned by get() for invalid elements when
// InvalidElementUseNumericalZeroValue is false.
67  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
68  BufferSizeType buffer_size,
69  T invalid_element_value)
70  : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
71  {
72  }
73 
75 
76  CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
77  {
78  return address_space_enum::generic;
79  }
80 
81  // i is offset of T
82  // FIXME: doesn't do is_valid check
83  CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
84 
85  // i is offset of T
86  // FIXME: doesn't do is_valid check
87  CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
88 
89  // i is offset of T, not X. i should be aligned to X
90  template <typename X,
91  bool oob_conditional_check = true,
92  typename std::enable_if<
93  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
94  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
95  bool>::type = false>
96  CK_TILE_DEVICE constexpr auto get(index_t i,
97  index_t linear_offset,
98  bool is_valid_element,
100  {
101  // X contains multiple T
102  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
103 
104  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
105 
106  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
107  "wrong! X should contain multiple T");
108 
109  if(is_valid_element)
110  {
111 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
112  X tmp;
113 
114  __builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
115 
116  return tmp;
117 #else
118  return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
119 #endif
120  }
121  else
122  {
123  if constexpr(InvalidElementUseNumericalZeroValue)
124  {
125  return X{numeric<remove_cvref_t<T>>::zero()};
126  }
127  else
128  {
129  return X{invalid_element_value_};
130  }
131  }
132  }
133 
134  /*
135  In the generic address space, we do not support the transpose instruction in the buffer view.
136  Will report compilation error when developer wants to use it.
137  */
138  template <typename X,
139  bool oob_conditional_check = true,
140  typename std::enable_if<
141  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
142  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
143  bool>::type = false>
145  index_t linear_offset,
146  bool is_valid_element,
148  {
149  static_assert(false, "Error: transpose load not supported in global memory space.");
150  ignore = i;
151  ignore = linear_offset;
152  ignore = is_valid_element;
153  return;
154  }
155 
156  // i is offset of T, not X. i should be aligned to X
157  template <memory_operation_enum Op,
158  typename X,
159  typename std::enable_if<
160  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
161  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
162  bool>::type = false>
163  CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
164  {
165  if constexpr(Op == memory_operation_enum::set)
166  {
167  this->template set<X>(i, linear_offset, is_valid_element, x);
168  }
169  // FIXME: remove memory_operation_enum::add
170  else if constexpr(Op == memory_operation_enum::add)
171  {
172  auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
173  this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
174  }
175  }
176 
177  // i is offset of T, not X. i should be aligned to X
178  template <typename X,
179  typename std::enable_if<
180  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
181  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
182  bool>::type = false>
183  CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
184  {
185  // X contains multiple T
186  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
187 
188  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
189 
190  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
191  "wrong! X should contain multiple T");
192 
193  if(is_valid_element)
194  {
195 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
196  X tmp = x;
197 
198  __builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
199 #else
200  *c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
201 #endif
202  }
203  }
204 
205  // FIXME: remove
// Legacy query: unconditionally returns false for this specialization.
206  CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
207 
208  // FIXME: remove
// Legacy query: unconditionally returns true for this specialization.
209  CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
210 };
211 
212 // Address Space: Global
213 // T may be scalar or vector
214 // X may be scalar or vector
215 // T and X have same scalar type
216 // X contains multiple T
217 // FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
218 // transforms of tensor_view/Tensor
219 template <typename T,
220  typename BufferSizeType,
221  bool InvalidElementUseNumericalZeroValue,
222  amd_buffer_coherence_enum Coherence>
223 struct buffer_view<address_space_enum::global,
224  T,
225  BufferSizeType,
226  InvalidElementUseNumericalZeroValue,
227  Coherence>
228 {
229  using type = T;
230 
231  T* p_data_ = nullptr;
232  BufferSizeType buffer_size_;
234  remove_cvref_t<T> invalid_element_value_ = T{0};
235 
236  static constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
237 
239  : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
240  {
241  }
242 
// Wrap p_data for global-memory access.
// buffer_size_ stores buffer_size / PackedSize (integer division; any remainder is
// dropped). PackedSize comes from numeric_traits<T>. NOTE(review): presumably
// buffer_size is counted in logical (unpacked) elements and buffer_size_ in stored
// T elements; confirm.
// cached_buf_res_ starts zeroed; per the note on the raw-resource init method, it must
// be initialized before *_raw load/store are used.
// invalid_element_value_ is value-initialized here, unlike the generic and LDS
// specializations, which use {0}.
243  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
244  : p_data_{p_data},
245  buffer_size_{buffer_size / PackedSize},
246  cached_buf_res_{0},
247  invalid_element_value_{}
248  {
249  }
250 
// Same as above, but with a caller-chosen value for invalid (masked) elements. It is
// passed to the load path when InvalidElementUseNumericalZeroValue is false.
251  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
252  BufferSizeType buffer_size,
253  T invalid_element_value)
254  : p_data_{p_data},
255  buffer_size_{buffer_size / PackedSize},
256  cached_buf_res_{0},
257  invalid_element_value_{invalid_element_value}
258  {
259  }
260 
261  // this is non constexpr intentially (will call some intrinsic internally)
262  // Must call for buffers that need *_raw load/store
264  {
265  cached_buf_res_ = make_wave_buffer_resource(p_data_, (buffer_size_) * sizeof(type));
266  }
267 
268  CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
269  {
270  return address_space_enum::global;
271  }
272 
273  // i is offset of T
274  // FIXME: doesn't do is_valid check
275  CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
276 
277  // i is offset of T
278  // FIXME: doesn't do is_valid check
279  CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
280 
281  // i is offset of T, not X. i should be aligned to X
282  template <typename X,
283  bool oob_conditional_check = true,
284  typename std::enable_if<
285  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
286  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
287  bool>::type = false>
288  CK_TILE_DEVICE constexpr auto get(index_t i,
289  index_t linear_offset,
290  bool is_valid_element,
292  {
293  // X contains multiple T
294  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
295 
296  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
297 
298  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
299  "wrong! X should contain multiple T");
300 
301 #if CK_TILE_USE_AMD_BUFFER_LOAD
302  bool constexpr use_amd_buffer_addressing = true;
303 #else
304  bool constexpr use_amd_buffer_addressing = false;
305 #endif
306 
307  if constexpr(use_amd_buffer_addressing)
308  {
309  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
310 
311  if constexpr(InvalidElementUseNumericalZeroValue)
312  {
313  return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
314  t_per_x,
315  Coherence,
316  oob_conditional_check>(
317  p_data_, i + linear_offset, is_valid_element, buffer_size_);
318  }
319  else
320  {
322  remove_cvref_t<T>,
323  t_per_x,
324  Coherence,
325  oob_conditional_check>(p_data_,
326  i + linear_offset,
327  is_valid_element,
328  buffer_size_,
329  invalid_element_value_);
330  }
331  }
332  else
333  {
334  if(is_valid_element)
335  {
336 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
337  X tmp;
338 
339  __builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
340 
341  return tmp;
342 #else
343  return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
344 #endif
345  }
346  else
347  {
348  if constexpr(InvalidElementUseNumericalZeroValue)
349  {
350  return X{numeric<remove_cvref_t<T>>::zero()};
351  }
352  else
353  {
354  return X{invalid_element_value_};
355  }
356  }
357  }
358  }
359 
360  /*
361  In the global memory address space, we do not support the transpose instruction in the buffer
362  view. Will report compilation error when developer wants to use it.
363  */
364  template <typename X,
365  bool oob_conditional_check = true,
366  typename std::enable_if<
367  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
368  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
369  bool>::type = false>
371  index_t linear_offset,
372  bool is_valid_element,
374  {
375  static_assert(false, "Error: transpose load not supported in global memory space.");
376  ignore = i;
377  ignore = linear_offset;
378  ignore = is_valid_element;
379  return;
380  }
381 
382  // i is offset of T, not X. i should be aligned to X
383  template <typename X,
384  bool oob_conditional_check = true,
385  bool pre_nop = false,
386  typename std::enable_if<
387  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
388  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
389  bool>::type = false>
391  index_t v_offset,
392  index_t i_offset,
393  bool is_valid_element,
394  bool_constant<pre_nop> = {}) const
395  {
396  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
397 
398  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
399 
400  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
401  "wrong! X should contain multiple T");
402 
403  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
404 
405  amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
406  dst, cached_buf_res_, v_offset, i_offset, is_valid_element, bool_constant<pre_nop>{});
407  }
408 
409  // i is offset of T, not X. i should be aligned to X
410  template <typename X,
411  bool oob_conditional_check = true,
412  typename std::enable_if<
413  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
414  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
415  bool>::type = false>
417  index_t i,
418  index_t linear_offset,
419  bool is_valid_element,
421  {
422  // X is vector of T
423  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
424  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
425 
426  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
427  "wrong! X should contain multiple T");
428 
429  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
430  const int32x4_t src_wave_buffer_resource =
431  make_wave_buffer_resource(p_data_, (buffer_size_) * sizeof(type));
432 
433  amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
434  smem,
435  src_wave_buffer_resource,
436  i,
437  linear_offset,
438  is_valid_element,
440  }
441 
442  // i is offset of T, not X. i should be aligned to X
443  template <typename X,
444  bool pre_nop = false,
445  typename std::enable_if<
446  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
447  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
448  bool>::type = false>
450  index_t i,
451  index_t linear_offset,
452  bool /*is_valid_element*/,
453  bool_constant<pre_nop> = {}) const
454  {
455  // X is vector of T
456  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
457  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
458 
459  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
460  "wrong! X should contain multiple T");
461 
462  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
463 
464  amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
465  smem, cached_buf_res_, i, linear_offset, bool_constant<pre_nop>{});
466  }
467 
468  // i is offset of T, not X. i should be aligned to X
469  template <memory_operation_enum Op,
470  typename X,
471  bool oob_conditional_check = true,
472  typename std::enable_if<
473  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
474  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
475  bool>::type = false>
477  index_t linear_offset,
478  bool is_valid_element,
479  const X& x,
481  {
482  if constexpr(Op == memory_operation_enum::set)
483  {
484  this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
485  }
486  else if constexpr(Op == memory_operation_enum::atomic_add)
487  {
488  this->template atomic_add<X, oob_conditional_check>(
489  i, linear_offset, is_valid_element, x);
490  }
491  else if constexpr(Op == memory_operation_enum::atomic_max)
492  {
493  this->template atomic_max<X, oob_conditional_check>(
494  i, linear_offset, is_valid_element, x);
495  }
496  // FIXME: remove memory_operation_enum::add
497  else if constexpr(Op == memory_operation_enum::add)
498  {
499  auto tmp =
500  this->template get<X, oob_conditional_check>(i, linear_offset, is_valid_element);
501  this->template set<X, oob_conditional_check>(
502  i, linear_offset, is_valid_element, x + tmp);
503  // tmp += x;
504  // this->template set<X>(i, is_valid_element, tmp);
505  }
506  }
507 
508  // i is offset of T, not X. i should be aligned to X
509  template <memory_operation_enum Op,
510  typename X,
511  bool oob_conditional_check = true,
512  bool pre_nop = false,
513  typename std::enable_if<
514  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
515  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
516  bool>::type = false>
518  index_t linear_offset,
519  bool is_valid_element,
520  const X& x,
523  {
524  if constexpr(Op == memory_operation_enum::set)
525  {
526  this->template set_raw<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
527  }
528  else if constexpr(Op == memory_operation_enum::atomic_add)
529  {
530  this->template atomic_add_raw<X, oob_conditional_check, pre_nop>(
531  i, linear_offset, is_valid_element, x);
532  }
533  else if constexpr(Op == memory_operation_enum::atomic_max)
534  {
535  // this->template atomic_max_raw<X>(i, linear_offset, is_valid_element, x);
536  }
537  }
538 
539  // i is offset of T, not X. i should be aligned to X
540  template <typename X,
541  bool oob_conditional_check = true,
542  typename std::enable_if<
543  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
544  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
545  bool>::type = false>
546  CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
547  {
548  // X contains multiple T
549  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
550 
551  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
552 
553  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
554  "wrong! X should contain multiple T");
555 
556 #if CK_TILE_USE_AMD_BUFFER_STORE
557  bool constexpr use_amd_buffer_addressing = true;
558 #else
559  bool constexpr use_amd_buffer_addressing = false;
560 #endif
561 
562  if constexpr(use_amd_buffer_addressing)
563  {
564  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
565 
566  amd_buffer_store<remove_cvref_t<T>, t_per_x, Coherence>(
567  x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
568  }
569  else
570  {
571  if(is_valid_element)
572  {
573 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
574  X tmp = x;
575 
576  __builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
577 #else
578  *c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
579 #endif
580  }
581  }
582  }
583 
584  // i is offset of T, not X. i should be aligned to X
585  template <typename X,
586  bool oob_conditional_check = true,
587  typename std::enable_if<
588  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
589  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
590  bool>::type = false>
591  CK_TILE_DEVICE void set_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
592  {
593  // X contains multiple T
594  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
595 
596  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
597 
598  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
599  "wrong! X should contain multiple T");
600 
601  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
602  amd_buffer_store_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
603  x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
604  }
605 
606  template <typename X,
607  bool oob_conditional_check = true,
608  typename std::enable_if<
609  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
610  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
611  bool>::type = false>
612  CK_TILE_DEVICE void
613  atomic_add(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
614  {
615  using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
616 
617  // X contains multiple T
618  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
619 
620  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
621 
622  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
623  "wrong! X should contain multiple T");
624 
625  static_assert(get_address_space() == address_space_enum::global, "only support global mem");
626 
627 #if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
628  bool constexpr use_amd_buffer_addressing =
629  std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
630  std::is_same_v<remove_cvref_t<scalar_t>, float> ||
631  (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
632 #if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
633  ||
634  (std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
635 #endif
636  ;
637 #elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
638  bool constexpr use_amd_buffer_addressing =
639  std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
640 #elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
641  bool constexpr use_amd_buffer_addressing =
642  std::is_same_v<remove_cvref_t<scalar_t>, float> ||
643  (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
644 #if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
645  ||
646  (std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
647 #endif
648  ;
649 #else
650  bool constexpr use_amd_buffer_addressing = false;
651 #endif
652 
653  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
654 
655  if constexpr(use_amd_buffer_addressing)
656  {
657  amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
658  x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
659  }
660  else
661  {
662  if(is_valid_element)
663  {
664  atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
665  }
666  }
667  }
668 
669  template <typename X,
670  bool oob_conditional_check = true,
671  bool pre_nop = true,
672  typename std::enable_if<
673  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
674  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
675  bool>::type = false>
676  CK_TILE_DEVICE void
677  atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
678  {
679  // using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
680 
681  // X contains multiple T
682  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
683 
684  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
685 
686  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
687  "wrong! X should contain multiple T");
688 
689  static_assert(get_address_space() == address_space_enum::global, "only support global mem");
690 
691  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
692 
693  amd_buffer_atomic_add_raw<remove_cvref_t<T>,
694  t_per_x,
695  Coherence,
696  oob_conditional_check,
697  pre_nop>(
698  x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
699  }
700 
701  template <typename X,
702  bool oob_conditional_check = true,
703  typename std::enable_if<
704  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
705  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
706  bool>::type = false>
707  CK_TILE_DEVICE void
708  atomic_max(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
709  {
710  // X contains multiple T
711  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
712 
713  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
714 
715  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
716  "wrong! X should contain multiple T");
717 
718  static_assert(get_address_space() == address_space_enum::global, "only support global mem");
719 
720 #if CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
721  using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
722  bool constexpr use_amd_buffer_addressing = std::is_same_v<remove_cvref_t<scalar_t>, double>;
723 #else
724  bool constexpr use_amd_buffer_addressing = false;
725 #endif
726 
727  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
728 
729  if constexpr(use_amd_buffer_addressing)
730  {
731  amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
732  x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
733  }
734  else if(is_valid_element)
735  {
736  atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
737  }
738  }
739 
740  // FIXME: remove
// Legacy query: unconditionally returns false for this specialization.
741  CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
742 
743  // FIXME: remove
// Legacy query: unconditionally returns true for this specialization.
744  CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
745 };
746 
747 // Address Space: LDS
748 // T may be scalar or vector
749 // X may be scalar or vector
750 // T and X have same scalar type
751 // X contains multiple T
752 // FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
753 // transforms of tensor_view/Tensor
754 template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
755 struct buffer_view<address_space_enum::lds,
756  T,
757  BufferSizeType,
758  InvalidElementUseNumericalZeroValue,
760 {
761  using type = T;
762 
763  T* p_data_ = nullptr;
764  BufferSizeType buffer_size_;
765  remove_cvref_t<T> invalid_element_value_ = T{0};
766 
768  : p_data_{}, buffer_size_{}, invalid_element_value_{}
769  {
770  }
771 
772  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
773  : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
774  {
775  }
776 
777  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
778  BufferSizeType buffer_size,
779  T invalid_element_value)
780  : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
781  {
782  }
783 
785 
786  CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
787  {
788  return address_space_enum::lds;
789  }
790 
791  // i is offset of T
792  // FIXME: doesn't do is_valid check
793  CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
794 
795  // i is offset of T
796  // FIXME: doesn't do is_valid check
797  CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
798 
799  // i is offset of T, not X. i should be aligned to X
800  template <typename X,
801  bool oob_conditional_check = true,
802  typename std::enable_if<
803  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
804  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
805  bool>::type = false>
806  CK_TILE_DEVICE constexpr auto get(index_t i,
807  index_t linear_offset,
808  bool is_valid_element,
810  {
811  // X contains multiple T
812  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
813 
814  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
815 
816  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
817  "wrong! X should contain multiple T");
818 
819  if(is_valid_element)
820  {
821 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
822  X tmp;
823 
824  __builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
825 
826  return tmp;
827 #else
828  using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
829  scalar_per_t_vector * scalar_per_x_vector>;
830  // using buf_t = ushort __attribute__((ext_vector_type(8)));
831  auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i + linear_offset]);
832  return bit_cast<X>(rtn);
833 #endif
834  }
835  else
836  {
837  if constexpr(InvalidElementUseNumericalZeroValue)
838  {
839  return X{numeric<remove_cvref_t<T>>::zero()};
840  }
841  else
842  {
843  return X{invalid_element_value_};
844  }
845  }
846  }
847 
848  // i is offset of T, not X. i should be aligned to X
849  template <typename X,
850  bool oob_conditional_check = true,
851  bool pre_nop = false,
852  typename std::enable_if<
853  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
854  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
855  bool>::type = false>
857  index_t v_offset,
858  index_t i_offset,
859  bool /*is_valid_element*/,
860  bool_constant<pre_nop> = {}) const
861  {
862  smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
863  }
864 
865  template <typename X,
866  typename std::enable_if<
867  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
868  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
869  bool>::type = false>
870  CK_TILE_DEVICE constexpr auto transpose_get([[maybe_unused]] index_t i,
871  [[maybe_unused]] index_t linear_offset,
872  bool is_valid_element) const
873  {
874  // X contains multiple T
875  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
876 
877  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
878 
879  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
880  "wrong! X should contain multiple T");
881 
882  if(is_valid_element)
883  {
884 #if defined(__gfx950__)
885  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
886  return amd_transpose_load_to_vgpr<remove_cvref_t<T>, t_per_x>(p_data_ + i +
887  linear_offset);
888 #else
889  return X{numeric<remove_cvref_t<T>>::zero()};
890 #endif
891  }
892  else
893  {
894  if constexpr(InvalidElementUseNumericalZeroValue)
895  {
896  return X{numeric<remove_cvref_t<T>>::zero()};
897  }
898  else
899  {
900  return X{invalid_element_value_};
901  }
902  }
903  }
904 
905  // i is offset of T, not X. i should be aligned to X
906  template <memory_operation_enum Op,
907  typename X,
908  typename std::enable_if<
909  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
910  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
911  bool>::type = false>
912  CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
913  {
914  if constexpr(Op == memory_operation_enum::set)
915  {
916  this->template set<X>(i, linear_offset, is_valid_element, x);
917  }
918  // FIXME: remove memory_operation_enum::add
919  else if constexpr(Op == memory_operation_enum::add)
920  {
921  auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
922  this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
923  }
924  }
925 
926  // i is offset of T, not X. i should be aligned to X
927  template <typename X,
928  typename std::enable_if<
929  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
930  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
931  bool>::type = false>
932  CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
933  {
934  // X contains multiple T
935  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
936 
937  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
938 
939  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
940  "wrong! X should contain multiple T");
941 
942 #if CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
943  bool constexpr workaround_int8_ds_write_issue = true;
944 #else
945  bool constexpr workaround_int8_ds_write_issue = false;
946 #endif
947 
948  i += linear_offset; // simplicity
949  if constexpr(std::is_same_v<typename vector_traits<remove_cvref_t<T>>::scalar_type,
950  int8_t> &&
951  workaround_int8_ds_write_issue)
952  {
953  if(is_valid_element)
954  {
955  // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
956  // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
957  // ds_write_b128
958  // TODO: remove this after compiler fix
959  // clang-format off
960  static_assert(
969  // int8 on thread buffer
975  // ext_vector_type for pk_int4 must use int8_t as type
984  "wrong! not implemented for this combination, please add "
985  "implementation");
986  // clang-format on
987 
988  if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
994  {
995  // HACK: cast pointer of x is bad
996  // TODO: remove this after compiler fix
997  *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
998  *c_style_pointer_cast<const int8_t*>(&x);
999  }
1000  else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
1006  {
1007  // HACK: cast pointer of x is bad
1008  // TODO: remove this after compiler fix
1009  *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
1010  *c_style_pointer_cast<const int16_t*>(&x);
1011  }
1012  else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
1018  {
1019  // HACK: cast pointer of x is bad
1020  // TODO: remove this after compiler fix
1021  *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
1022  *c_style_pointer_cast<const int32_t*>(&x);
1023  }
1024  else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
1030  {
1031  // HACK: cast pointer of x is bad
1032  // TODO: remove this after compiler fix
1033  *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
1034  *c_style_pointer_cast<const int32x2_t*>(&x);
1035  }
1036  else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
1042  {
1043  // HACK: cast pointer of x is bad
1044  // TODO: remove this after compiler fix
1045  *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
1046  *c_style_pointer_cast<const int32x4_t*>(&x);
1047  }
1048  else if constexpr((std::is_same_v<remove_cvref_t<T>, int8x4_t> &&
1052  {
1053  // HACK: cast pointer of x is bad
1054  // TODO: remove this after compiler fix
1055  *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
1056  *c_style_pointer_cast<const int32_t*>(&x);
1057  }
1058  else if constexpr((std::is_same_v<remove_cvref_t<T>, int8x8_t> &&
1062  {
1063  // HACK: cast pointer of x is bad
1064  // TODO: remove this after compiler fix
1065  *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
1066  *c_style_pointer_cast<const int32x2_t*>(&x);
1067  }
1068  else if constexpr((std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
1072  {
1073  // HACK: cast pointer of x is bad
1074  // TODO: remove this after compiler fix
1075  *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
1076  *c_style_pointer_cast<const int32x4_t*>(&x);
1077  }
1078  }
1079  }
1080  else
1081  {
1082  if(is_valid_element)
1083  {
1084 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
1085  X tmp = x;
1086 
1087  __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
1088 #else
1089  using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
1090  scalar_per_t_vector * scalar_per_x_vector>;
1091 
1092  *c_style_pointer_cast<buf_t*>(&p_data_[i]) = reinterpret_cast<const buf_t&>(x);
1093 #endif
1094  }
1095  }
1096  }
1097 
1098  // FIXME: remove
1099  CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
1100 
1101  // FIXME: remove
1102  CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
1103 };
1104 
1105 // Address Space: Vgpr
1106 // T may be scalar or vector
1107 // X may be scalar or vector
1108 // T and X have same scalar type
1109 // X contains multiple T
1110 // FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
1111 // transforms of tensor_view/Tensor
1112 template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
1113 struct buffer_view<address_space_enum::vgpr,
1114  T,
1115  BufferSizeType,
1116  InvalidElementUseNumericalZeroValue,
1118 {
1119  using type = T;
1120 
1121  T* p_data_ = nullptr;
1122  BufferSizeType buffer_size_;
1123  remove_cvref_t<T> invalid_element_value_ = T{0};
1124 
1126  : p_data_{}, buffer_size_{}, invalid_element_value_{}
1127  {
1128  }
1129 
1130  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
1131  : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
1132  {
1133  }
1134 
1135  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
1136  BufferSizeType buffer_size,
1137  T invalid_element_value)
1138  : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
1139  {
1140  }
1141 
1143 
1144  CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
1145  {
1146  return address_space_enum::vgpr;
1147  }
1148 
1149  // i is offset of T
1150  // FIXME: doesn't do is_valid check
1151  CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
1152 
1153  // i is offset of T
1154  // FIXME: doesn't do is_valid check
1155  CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
1156 
1157  // i is offset of T, not X. i should be aligned to X
1158  template <typename X,
1159  bool oob_conditional_check = true,
1160  typename std::enable_if<
1161  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
1162  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
1163  bool>::type = false>
1164  CK_TILE_DEVICE constexpr auto get(index_t i,
1165  index_t /*linear_offset*/,
1166  bool is_valid_element,
1168  {
1169  // X contains multiple T
1170  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
1171 
1172  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
1173 
1174  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
1175  "wrong! X should contain multiple T");
1176 
1177  if(is_valid_element)
1178  {
1179 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
1180  X tmp;
1181 
1182  __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
1183 
1184  return tmp;
1185 #else
1186  return *c_style_pointer_cast<const X*>(&p_data_[i]);
1187 #endif
1188  }
1189  else
1190  {
1191  if constexpr(InvalidElementUseNumericalZeroValue)
1192  {
1193  return X{numeric<remove_cvref_t<T>>::zero()};
1194  }
1195  else
1196  {
1197  return X{invalid_element_value_};
1198  }
1199  }
1200  }
1201 
1202  // i is offset of T, not X. i should be aligned to X
1203  template <memory_operation_enum Op,
1204  typename X,
1205  typename std::enable_if<
1206  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
1207  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
1208  bool>::type = false>
1209  CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
1210  {
1211  if constexpr(Op == memory_operation_enum::set)
1212  {
1213  this->template set<X>(i, linear_offset, is_valid_element, x);
1214  }
1215  // FIXME: remove memory_operation_enum::add
1216  else if constexpr(Op == memory_operation_enum::add)
1217  {
1218  auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
1219  this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
1220  }
1221  }
1222 
1223  // i is offset of T, not X. i should be aligned to X
1224  template <typename X,
1225  typename std::enable_if<
1226  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
1227  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
1228  bool>::type = false>
1229  CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
1230  {
1231  // X contains multiple T
1232  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
1233 
1234  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
1235 
1236  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
1237  "wrong! X should contain multiple T");
1238 
1239  if(is_valid_element)
1240  {
1241 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
1242  X tmp = x;
1243 
1244  __builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
1245 #else
1246  *c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
1247 #endif
1248  }
1249  }
1250 
1251  // FIXME: remove
1252  CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
1253 
1254  // FIXME: remove
1255  CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
1256 };
1257 
1258 template <address_space_enum BufferAddressSpace,
1260  typename T,
1261  typename BufferSizeType>
1262 CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size)
1263 {
1265 }
1266 
1267 template <address_space_enum BufferAddressSpace,
1269  typename T,
1270  typename BufferSizeType,
1271  typename X,
1272  typename std::enable_if<std::is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value,
1273  bool>::type = false>
1274 CK_TILE_HOST_DEVICE constexpr auto
1275 make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size, X invalid_element_value)
1276 {
1278  p, buffer_size, invalid_element_value};
1279 }
1280 
1281 // Generalized print function for all buffer_view variants
1282 template <address_space_enum BufferAddressSpace,
1283  typename T,
1284  typename BufferSizeType,
1285  bool InvalidElementUseNumericalZeroValue,
1286  amd_buffer_coherence_enum Coherence>
1287 CK_TILE_HOST_DEVICE void print(const buffer_view<BufferAddressSpace,
1288  T,
1289  BufferSizeType,
1290  InvalidElementUseNumericalZeroValue,
1291  Coherence>& bv)
1292 {
1293  printf("buffer_view{AddressSpace: %s, p_data_: %p, buffer_size_: ",
1294  address_space_to_string(BufferAddressSpace),
1295  static_cast<void*>(const_cast<remove_cvref_t<T>*>(bv.p_data_)));
1296  print(bv.buffer_size_);
1297  printf(", invalid_element_value_: ");
1298  print(bv.invalid_element_value_);
1299  printf("}");
1300 }
1301 
1302 } // namespace ck_tile
constexpr CK_TILE_HOST_DEVICE const char * address_space_to_string(address_space_enum addr_space)
Helper function to convert address space enum to string.
Definition: arch.hpp:301
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
ushort bfloat16_t
Definition: bfloat16.hpp:111
int8_t int8x16_t
Definition: vector_type.hpp:193
int8_t int8x4_t
Definition: vector_type.hpp:191
int8_t int8x8_t
Definition: vector_type.hpp:192
CK_TILE_DEVICE thread_buffer< T, N > amd_buffer_load_invalid_element_return_customized_value(const T *p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, index_t src_element_space_size, T customized_value)
Definition: amd_buffer_addressing.hpp:2580
int8_t int8_t
Definition: int8.hpp:20
amd_buffer_coherence_enum
Definition: amd_buffer_addressing.hpp:1404
CK_TILE_HOST_DEVICE T add(const T &a, const T &b)
Definition: generic_memory_space_atomic.hpp:16
constexpr CK_TILE_HOST_DEVICE auto make_buffer_view(T *__restrict__ p, BufferSizeType buffer_size)
Definition: buffer_view.hpp:1262
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d< BlockSize, YPerTile, XPerTile, VecSize, DistributionPattern, NumWaveGroups > &)
Definition: static_encoding_pattern.hpp:341
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int8_t pk_int4x4_t
Definition: vector_type.hpp:247
int8_t pk_int4x16_t
Definition: vector_type.hpp:249
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:84
int32_t int32_t
Definition: integer.hpp:10
int8_t int8x2_t
Definition: pk_int4.hpp:103
int32_t int32x4_t
Definition: vector_type.hpp:155
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr, uint32_t size=0xffffffff, ForceSGPR={})
Definition: amd_buffer_addressing.hpp:97
int8_t pk_int4x8_t
Definition: vector_type.hpp:248
_Float16 half_t
Definition: half.hpp:111
__device__ X atomic_max(X *p_dst, const X &x)
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
constexpr bool is_same_v
Definition: type.hpp:283
__device__ X atomic_add(X *p_dst, const X &x)
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1697
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:183
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size)
Definition: buffer_view.hpp:62
constexpr CK_TILE_DEVICE auto transpose_get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition: buffer_view.hpp:144
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size, T invalid_element_value)
Definition: buffer_view.hpp:67
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:163
constexpr CK_TILE_DEVICE auto get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition: buffer_view.hpp:96
constexpr CK_TILE_DEVICE auto transpose_get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition: buffer_view.hpp:370
constexpr CK_TILE_DEVICE auto async_get_raw(remove_cvref_t< T > *smem, index_t i, index_t linear_offset, bool, bool_constant< pre_nop >={}) const
Definition: buffer_view.hpp:449
constexpr CK_TILE_DEVICE const T & operator[](index_t i) const
Definition: buffer_view.hpp:275
constexpr CK_TILE_DEVICE auto get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition: buffer_view.hpp:288
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X &x, bool_constant< oob_conditional_check >={})
Definition: buffer_view.hpp:476
static constexpr CK_TILE_DEVICE address_space_enum get_address_space()
Definition: buffer_view.hpp:268
CK_TILE_DEVICE void update_raw(index_t i, index_t linear_offset, bool is_valid_element, const X &x, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: buffer_view.hpp:517
constexpr CK_TILE_DEVICE auto async_get(CK_TILE_LDS_ADDR remove_cvref_t< T > *smem, index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition: buffer_view.hpp:416
constexpr CK_TILE_DEVICE auto get_raw(remove_cvref_t< X > &dst, index_t v_offset, index_t i_offset, bool is_valid_element, bool_constant< pre_nop >={}) const
Definition: buffer_view.hpp:390
CK_TILE_DEVICE void atomic_add(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:613
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size)
Definition: buffer_view.hpp:243
CK_TILE_DEVICE void set_raw(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:591
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:546
CK_TILE_DEVICE void atomic_max(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:708
CK_TILE_DEVICE void atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:677
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size, T invalid_element_value)
Definition: buffer_view.hpp:251
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:932
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size, T invalid_element_value)
Definition: buffer_view.hpp:777
constexpr CK_TILE_DEVICE auto get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition: buffer_view.hpp:806
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:912
constexpr CK_TILE_DEVICE auto transpose_get([[maybe_unused]] index_t i, [[maybe_unused]] index_t linear_offset, bool is_valid_element) const
Definition: buffer_view.hpp:870
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size)
Definition: buffer_view.hpp:772
constexpr CK_TILE_DEVICE auto get_raw(remove_cvref_t< X > &dst, index_t v_offset, index_t i_offset, bool, bool_constant< pre_nop >={}) const
Definition: buffer_view.hpp:856
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size)
Definition: buffer_view.hpp:1130
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size, T invalid_element_value)
Definition: buffer_view.hpp:1135
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:1229
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:1209
constexpr CK_TILE_DEVICE auto get(index_t i, index_t, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition: buffer_view.hpp:1164
Definition: buffer_view.hpp:35
Definition: integral_constant.hpp:13
Definition: numeric.hpp:81
Definition: numeric.hpp:18
Definition: pk_int4.hpp:21
Definition: amd_buffer_addressing.hpp:895
Definition: debug.hpp:67
Definition: vector_type.hpp:90