/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/buffer_view.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/buffer_view.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/buffer_view.hpp Source File
buffer_view.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 #if __clang_major__ >= 20
10 #else
12 #endif
22 
23 namespace ck_tile {
24 
25 // T may be scalar or vector
26 // X may be scalar or vector
27 // T and X have same scalar type
28 // X contains multiple T
29 // FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
30 // transforms of tensor_view/Tensor
31 // FIXME: amd_buffer_coherence_enum is only meaningful for buffer addressing. Need to split
32 // buffer_view definition for different memory address space (Global/GenericLds/Vgpr)
33 template <address_space_enum BufferAddressSpace,
34  typename T,
35  typename BufferSizeType,
36  bool InvalidElementUseNumericalZeroValue,
38 struct buffer_view;
39 
40 // Address Space: generic
41 // T may be scalar or vector
42 // X may be scalar or vector
43 // T and X have same scalar type
44 // X contains multiple T
45 // FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
46 // transforms of tensor_view/Tensor
47 template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
48 struct buffer_view<address_space_enum::generic,
49  T,
50  BufferSizeType,
51  InvalidElementUseNumericalZeroValue,
53 {
54  using type = T;
55 
56  T* p_data_ = nullptr;
57  BufferSizeType buffer_size_;
58  remove_cvref_t<T> invalid_element_value_ = T{0};
59 
61  : p_data_{}, buffer_size_{}, invalid_element_value_{}
62  {
63  }
64 
65  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
66  : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
67  {
68  }
69 
70  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
71  BufferSizeType buffer_size,
72  T invalid_element_value)
73  : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
74  {
75  }
76 
78 
79  CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
80  {
81  return address_space_enum::generic;
82  }
83 
84  // i is offset of T
85  // FIXME: doesn't do is_valid check
86  CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
87 
88  // i is offset of T
89  // FIXME: doesn't do is_valid check
90  CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
91 
92  // i is offset of T, not X. i should be aligned to X
93  template <typename X,
94  bool oob_conditional_check = true,
95  typename std::enable_if<
96  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
97  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
98  bool>::type = false>
99  CK_TILE_DEVICE constexpr auto get(index_t i,
100  index_t linear_offset,
101  bool is_valid_element,
103  {
104  // X contains multiple T
105  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
106 
107  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
108 
109  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
110  "wrong! X should contain multiple T");
111 
112  if(is_valid_element)
113  {
114 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
115  X tmp;
116 
117  __builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
118 
119  return tmp;
120 #else
121  return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
122 #endif
123  }
124  else
125  {
126  if constexpr(InvalidElementUseNumericalZeroValue)
127  {
128  return X{numeric<remove_cvref_t<T>>::zero()};
129  }
130  else
131  {
132  return X{invalid_element_value_};
133  }
134  }
135  }
136 
137  /*
138  In the generic address space, we do not support the transpose instruction in the buffer view.
139  Any attempt to use it triggers a compile-time error.
140  */
141  template <typename X,
142  bool oob_conditional_check = true,
143  typename std::enable_if<
144  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
145  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
146  bool>::type = false>
148  index_t linear_offset,
149  bool is_valid_element,
151  {
152  static_assert(false, "Error: transpose load not supported in global memory space.");
153  ignore = i;
154  ignore = linear_offset;
155  ignore = is_valid_element;
156  return;
157  }
158 
159  // i is offset of T, not X. i should be aligned to X
160  template <memory_operation_enum Op,
161  typename X,
162  typename std::enable_if<
163  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
164  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
165  bool>::type = false>
166  CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
167  {
168  if constexpr(Op == memory_operation_enum::set)
169  {
170  this->template set<X>(i, linear_offset, is_valid_element, x);
171  }
172  // FIXME: remove memory_operation_enum::add
173  else if constexpr(Op == memory_operation_enum::add)
174  {
175  auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
176  this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
177  }
178  }
179 
180  // i is offset of T, not X. i should be aligned to X
181  template <typename X,
182  typename std::enable_if<
183  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
184  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
185  bool>::type = false>
186  CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
187  {
188  // X contains multiple T
189  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
190 
191  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
192 
193  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
194  "wrong! X should contain multiple T");
195 
196  if(is_valid_element)
197  {
198 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
199  X tmp = x;
200 
201  __builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
202 #else
203  *c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
204 #endif
205  }
206  }
207 
208  // FIXME: remove
209  CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
210 
211  // FIXME: remove
212  CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
213 };
214 
215 // Address Space: Global
216 // T may be scalar or vector
217 // X may be scalar or vector
218 // T and X have same scalar type
219 // X contains multiple T
220 // FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
221 // transforms of tensor_view/Tensor
222 template <typename T,
223  typename BufferSizeType,
224  bool InvalidElementUseNumericalZeroValue,
225  amd_buffer_coherence_enum Coherence>
226 struct buffer_view<address_space_enum::global,
227  T,
228  BufferSizeType,
229  InvalidElementUseNumericalZeroValue,
230  Coherence>
231 {
232  using type = T;
233 
234  T* p_data_ = nullptr;
235  BufferSizeType buffer_size_;
237  remove_cvref_t<T> invalid_element_value_ = T{0};
238 
239  static constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
240 
242  : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
243  {
244  }
245 
246  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
247  : p_data_{p_data},
248  buffer_size_{buffer_size / PackedSize},
249  cached_buf_res_{0},
250  invalid_element_value_{0}
251  {
252  }
253 
254  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
255  BufferSizeType buffer_size,
256  T invalid_element_value)
257  : p_data_{p_data},
258  buffer_size_{buffer_size / PackedSize},
259  cached_buf_res_{0},
260  invalid_element_value_{invalid_element_value}
261  {
262  }
263 
264  // This is intentionally non-constexpr (it calls an intrinsic internally).
265  // It must be called for buffers that use the *_raw load/store paths.
267  {
268  cached_buf_res_ = make_wave_buffer_resource(p_data_, (buffer_size_) * sizeof(type));
269  }
270 
271  CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
272  {
273  return address_space_enum::global;
274  }
275 
276  // i is offset of T
277  // FIXME: doesn't do is_valid check
278  CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
279 
280  // i is offset of T
281  // FIXME: doesn't do is_valid check
282  CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
283 
284  // i is offset of T, not X. i should be aligned to X
285  template <typename X,
286  bool oob_conditional_check = true,
287  typename std::enable_if<
288  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
289  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
290  bool>::type = false>
291  CK_TILE_DEVICE constexpr auto get(index_t i,
292  index_t linear_offset,
293  bool is_valid_element,
295  {
296  // X contains multiple T
297  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
298 
299  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
300 
301  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
302  "wrong! X should contain multiple T");
303 
304 #if CK_TILE_USE_AMD_BUFFER_LOAD
305  bool constexpr use_amd_buffer_addressing = true;
306 #else
307  bool constexpr use_amd_buffer_addressing = false;
308 #endif
309 
310  if constexpr(use_amd_buffer_addressing)
311  {
312  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
313 
314  if constexpr(InvalidElementUseNumericalZeroValue)
315  {
316  return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
317  t_per_x,
318  Coherence,
319  oob_conditional_check>(
320  p_data_, i + linear_offset, is_valid_element, buffer_size_);
321  }
322  else
323  {
325  remove_cvref_t<T>,
326  t_per_x,
327  Coherence,
328  oob_conditional_check>(p_data_,
329  i + linear_offset,
330  is_valid_element,
331  buffer_size_,
332  invalid_element_value_);
333  }
334  }
335  else
336  {
337  if(is_valid_element)
338  {
339 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
340  X tmp;
341 
342  __builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
343 
344  return tmp;
345 #else
346  return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
347 #endif
348  }
349  else
350  {
351  if constexpr(InvalidElementUseNumericalZeroValue)
352  {
353  return X{numeric<remove_cvref_t<T>>::zero()};
354  }
355  else
356  {
357  return X{invalid_element_value_};
358  }
359  }
360  }
361  }
362 
363  /*
364  In the global memory address space, we do not support the transpose instruction in the buffer
365  view; any attempt to use it triggers a compile-time error.
366  */
367  template <typename X,
368  bool oob_conditional_check = true,
369  typename std::enable_if<
370  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
371  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
372  bool>::type = false>
374  index_t linear_offset,
375  bool is_valid_element,
377  {
378  static_assert(false, "Error: transpose load not supported in global memory space.");
379  ignore = i;
380  ignore = linear_offset;
381  ignore = is_valid_element;
382  return;
383  }
384 
385  // i is offset of T, not X. i should be aligned to X
386  template <typename X,
387  bool oob_conditional_check = true,
388  bool pre_nop = false,
389  typename std::enable_if<
390  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
391  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
392  bool>::type = false>
394  index_t v_offset,
395  index_t i_offset,
396  bool is_valid_element,
397  bool_constant<pre_nop> = {}) const
398  {
399  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
400 
401  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
402 
403  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
404  "wrong! X should contain multiple T");
405 
406  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
407 
408  amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
409  dst, cached_buf_res_, v_offset, i_offset, is_valid_element, bool_constant<pre_nop>{});
410  }
411 
412  // i is offset of T, not X. i should be aligned to X
413  template <typename X,
414  bool oob_conditional_check = true,
415  typename std::enable_if<
416  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
417  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
418  bool>::type = false>
420  index_t i,
421  index_t linear_offset,
422  bool is_valid_element,
424  {
425  // X is vector of T
426  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
427  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
428 
429  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
430  "wrong! X should contain multiple T");
431 
432  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
433  const int32x4_t src_wave_buffer_resource =
434  make_wave_buffer_resource(p_data_, (buffer_size_) * sizeof(type));
435 
436  amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
437  smem,
438  src_wave_buffer_resource,
439  i,
440  linear_offset,
441  is_valid_element,
443  }
444 
445  // i is offset of T, not X. i should be aligned to X
446  template <typename X,
447  bool pre_nop = false,
448  typename std::enable_if<
449  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
450  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
451  bool>::type = false>
453  index_t i,
454  index_t linear_offset,
455  bool /*is_valid_element*/,
456  bool_constant<pre_nop> = {}) const
457  {
458  // X is vector of T
459  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
460  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
461 
462  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
463  "wrong! X should contain multiple T");
464 
465  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
466 
467  amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
468  smem, cached_buf_res_, i, linear_offset, bool_constant<pre_nop>{});
469  }
470 
471  // i is offset of T, not X. i should be aligned to X
472  template <memory_operation_enum Op,
473  typename X,
474  bool oob_conditional_check = true,
475  typename std::enable_if<
476  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
477  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
478  bool>::type = false>
480  index_t linear_offset,
481  bool is_valid_element,
482  const X& x,
484  {
485  if constexpr(Op == memory_operation_enum::set)
486  {
487  this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
488  }
489  else if constexpr(Op == memory_operation_enum::atomic_add)
490  {
491  this->template atomic_add<X, oob_conditional_check>(
492  i, linear_offset, is_valid_element, x);
493  }
494  else if constexpr(Op == memory_operation_enum::atomic_max)
495  {
496  this->template atomic_max<X, oob_conditional_check>(
497  i, linear_offset, is_valid_element, x);
498  }
499  // FIXME: remove memory_operation_enum::add
500  else if constexpr(Op == memory_operation_enum::add)
501  {
502  auto tmp =
503  this->template get<X, oob_conditional_check>(i, linear_offset, is_valid_element);
504  this->template set<X, oob_conditional_check>(
505  i, linear_offset, is_valid_element, x + tmp);
506  // tmp += x;
507  // this->template set<X>(i, is_valid_element, tmp);
508  }
509  }
510 
511  // i is offset of T, not X. i should be aligned to X
512  template <memory_operation_enum Op,
513  typename X,
514  bool oob_conditional_check = true,
515  bool pre_nop = false,
516  typename std::enable_if<
517  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
518  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
519  bool>::type = false>
521  index_t linear_offset,
522  bool is_valid_element,
523  const X& x,
526  {
527  if constexpr(Op == memory_operation_enum::set)
528  {
529  this->template set_raw<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
530  }
531  else if constexpr(Op == memory_operation_enum::atomic_add)
532  {
533  this->template atomic_add_raw<X, oob_conditional_check, pre_nop>(
534  i, linear_offset, is_valid_element, x);
535  }
536  else if constexpr(Op == memory_operation_enum::atomic_max)
537  {
538  // this->template atomic_max_raw<X>(i, linear_offset, is_valid_element, x);
539  }
540  }
541 
542  // i is offset of T, not X. i should be aligned to X
543  template <typename X,
544  bool oob_conditional_check = true,
545  typename std::enable_if<
546  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
547  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
548  bool>::type = false>
549  CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
550  {
551  // X contains multiple T
552  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
553 
554  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
555 
556  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
557  "wrong! X should contain multiple T");
558 
559 #if CK_TILE_USE_AMD_BUFFER_STORE
560  bool constexpr use_amd_buffer_addressing = true;
561 #else
562  bool constexpr use_amd_buffer_addressing = false;
563 #endif
564 
565  if constexpr(use_amd_buffer_addressing)
566  {
567  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
568 
569  amd_buffer_store<remove_cvref_t<T>, t_per_x, Coherence>(
570  x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
571  }
572  else
573  {
574  if(is_valid_element)
575  {
576 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
577  X tmp = x;
578 
579  __builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
580 #else
581  *c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
582 #endif
583  }
584  }
585  }
586 
587  // i is offset of T, not X. i should be aligned to X
588  template <typename X,
589  bool oob_conditional_check = true,
590  typename std::enable_if<
591  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
592  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
593  bool>::type = false>
594  CK_TILE_DEVICE void set_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
595  {
596  // X contains multiple T
597  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
598 
599  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
600 
601  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
602  "wrong! X should contain multiple T");
603 
604  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
605  amd_buffer_store_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
606  x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
607  }
608 
609  template <typename X,
610  bool oob_conditional_check = true,
611  typename std::enable_if<
612  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
613  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
614  bool>::type = false>
615  CK_TILE_DEVICE void
616  atomic_add(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
617  {
618  using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
619 
620  // X contains multiple T
621  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
622 
623  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
624 
625  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
626  "wrong! X should contain multiple T");
627 
628  static_assert(get_address_space() == address_space_enum::global, "only support global mem");
629 
630 #if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
631  bool constexpr use_amd_buffer_addressing =
632  std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
633  std::is_same_v<remove_cvref_t<scalar_t>, float> ||
634  (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
635 #elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
636  bool constexpr use_amd_buffer_addressing =
637  std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
638 #elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
639  bool constexpr use_amd_buffer_addressing =
640  std::is_same_v<remove_cvref_t<scalar_t>, float> ||
641  (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
642 #else
643  bool constexpr use_amd_buffer_addressing = false;
644 #endif
645 
646  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
647 
648  if constexpr(use_amd_buffer_addressing)
649  {
650  amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
651  x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
652  }
653  else
654  {
655  if(is_valid_element)
656  {
657  atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
658  }
659  }
660  }
661 
662  template <typename X,
663  bool oob_conditional_check = true,
664  bool pre_nop = true,
665  typename std::enable_if<
666  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
667  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
668  bool>::type = false>
669  CK_TILE_DEVICE void
670  atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
671  {
672  // using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
673 
674  // X contains multiple T
675  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
676 
677  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
678 
679  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
680  "wrong! X should contain multiple T");
681 
682  static_assert(get_address_space() == address_space_enum::global, "only support global mem");
683 
684  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
685 
686  amd_buffer_atomic_add_raw<remove_cvref_t<T>,
687  t_per_x,
688  Coherence,
689  oob_conditional_check,
690  pre_nop>(
691  x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
692  }
693 
694  template <typename X,
695  bool oob_conditional_check = true,
696  typename std::enable_if<
697  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
698  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
699  bool>::type = false>
700  CK_TILE_DEVICE void
701  atomic_max(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
702  {
703  // X contains multiple T
704  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
705 
706  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
707 
708  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
709  "wrong! X should contain multiple T");
710 
711  static_assert(get_address_space() == address_space_enum::global, "only support global mem");
712 
713 #if CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
714  using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
715  bool constexpr use_amd_buffer_addressing = std::is_same_v<remove_cvref_t<scalar_t>, double>;
716 #else
717  bool constexpr use_amd_buffer_addressing = false;
718 #endif
719 
720  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
721 
722  if constexpr(use_amd_buffer_addressing)
723  {
724  amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
725  x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
726  }
727  else if(is_valid_element)
728  {
729  atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
730  }
731  }
732 
733  // FIXME: remove
734  CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
735 
736  // FIXME: remove
737  CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
738 };
739 
740 // Address Space: LDS
741 // T may be scalar or vector
742 // X may be scalar or vector
743 // T and X have same scalar type
744 // X contains multiple T
745 // FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
746 // transforms of tensor_view/Tensor
747 template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
748 struct buffer_view<address_space_enum::lds,
749  T,
750  BufferSizeType,
751  InvalidElementUseNumericalZeroValue,
753 {
754  using type = T;
755 
756  T* p_data_ = nullptr;
757  BufferSizeType buffer_size_;
758  remove_cvref_t<T> invalid_element_value_ = T{0};
759 
761  : p_data_{}, buffer_size_{}, invalid_element_value_{}
762  {
763  }
764 
765  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
766  : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
767  {
768  }
769 
770  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
771  BufferSizeType buffer_size,
772  T invalid_element_value)
773  : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
774  {
775  }
776 
778 
779  CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
780  {
781  return address_space_enum::lds;
782  }
783 
784  // i is offset of T
785  // FIXME: doesn't do is_valid check
786  CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
787 
788  // i is offset of T
789  // FIXME: doesn't do is_valid check
790  CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
791 
792  // i is offset of T, not X. i should be aligned to X
793  template <typename X,
794  bool oob_conditional_check = true,
795  typename std::enable_if<
796  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
797  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
798  bool>::type = false>
799  CK_TILE_DEVICE constexpr auto get(index_t i,
800  index_t linear_offset,
801  bool is_valid_element,
803  {
804  // X contains multiple T
805  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
806 
807  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
808 
809  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
810  "wrong! X should contain multiple T");
811 
812  if(is_valid_element)
813  {
814 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
815  X tmp;
816 
817  __builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
818 
819  return tmp;
820 #else
821  using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
822  scalar_per_t_vector * scalar_per_x_vector>;
823  // using buf_t = ushort __attribute__((ext_vector_type(8)));
824  auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i + linear_offset]);
825  return bit_cast<X>(rtn);
826 #endif
827  }
828  else
829  {
830  if constexpr(InvalidElementUseNumericalZeroValue)
831  {
832  return X{numeric<remove_cvref_t<T>>::zero()};
833  }
834  else
835  {
836  return X{invalid_element_value_};
837  }
838  }
839  }
840 
841  // i is offset of T, not X. i should be aligned to X
842  template <typename X,
843  bool oob_conditional_check = true,
844  bool pre_nop = false,
845  typename std::enable_if<
846  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
847  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
848  bool>::type = false>
850  index_t v_offset,
851  index_t i_offset,
852  bool /*is_valid_element*/,
853  bool_constant<pre_nop> = {}) const
854  {
855  smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
856  }
857 
858  template <typename X,
859  typename std::enable_if<
860  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
861  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
862  bool>::type = false>
863  CK_TILE_DEVICE constexpr auto transpose_get([[maybe_unused]] index_t i,
864  [[maybe_unused]] index_t linear_offset,
865  bool is_valid_element) const
866  {
867  // X contains multiple T
868  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
869 
870  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
871 
872  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
873  "wrong! X should contain multiple T");
874 
875  if(is_valid_element)
876  {
877 #if defined(__gfx950__)
878  constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
879  constexpr address_space_enum addr_space = get_address_space();
880  return amd_transpose_load_to_vgpr<remove_cvref_t<T>, t_per_x, addr_space>(
881  p_data_ + i + linear_offset);
882 #else
883  return X{numeric<remove_cvref_t<T>>::zero()};
884 #endif
885  }
886  else
887  {
888  if constexpr(InvalidElementUseNumericalZeroValue)
889  {
890  return X{numeric<remove_cvref_t<T>>::zero()};
891  }
892  else
893  {
894  return X{invalid_element_value_};
895  }
896  }
897  }
898 
899  // i is offset of T, not X. i should be aligned to X
900  template <memory_operation_enum Op,
901  typename X,
902  typename std::enable_if<
903  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
904  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
905  bool>::type = false>
906  CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
907  {
908  if constexpr(Op == memory_operation_enum::set)
909  {
910  this->template set<X>(i, linear_offset, is_valid_element, x);
911  }
912  // FIXME: remove memory_operation_enum::add
913  else if constexpr(Op == memory_operation_enum::add)
914  {
915  auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
916  this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
917  }
918  }
919 
920  // i is offset of T, not X. i should be aligned to X
// Stores the X-sized value x at element offset i + linear_offset of p_data_. Nothing is
// written when is_valid_element is false: both store paths below are guarded by
// `if(is_valid_element)`.
// Two paths:
//  - scalar type int8_t with CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE enabled:
//    x is re-typed to an int8/int16/int32/int32x2/int32x4 store so the compiler emits a wide
//    ds_write instead of an i8-vector store (see the HACK comment inside).
//  - otherwise: a single vector store through buf_t, or __builtin_memcpy of sizeof(X) bytes
//    when CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS is set.
// NOTE(review): several lines of the static_assert and of each if-constexpr condition in the
// int8 chain are not rendered in this listing, so the exact branch-selection rules cannot be
// read here.
921  template <typename X,
922  typename std::enable_if<
923  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
924  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
925  bool>::type = false>
926  CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
927  {
928  // X contains multiple T
929  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
930 
931  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
932 
933  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
934  "wrong! X should contain multiple T");
935 
936 #if CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
937  bool constexpr workaround_int8_ds_write_issue = true;
938 #else
939  bool constexpr workaround_int8_ds_write_issue = false;
940 #endif
941 
// Fold linear_offset into i so every store below can index p_data_[i] directly.
942  i += linear_offset; // simplicity
943  if constexpr(std::is_same_v<typename vector_traits<remove_cvref_t<T>>::scalar_type,
944  int8_t> &&
945  workaround_int8_ds_write_issue)
946  {
947  if(is_valid_element)
948  {
949  // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
950  // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
951  // ds_write_b128
952  // TODO: remove this after compiler fix
953  // clang-format off
954  static_assert(
963  // int8 on thread buffer
969  // ext_vector_type for pk_int4 must use int8_t as type
978  "wrong! not implemented for this combination, please add "
979  "implementation");
980  // clang-format on
981 
982  if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
988  {
989  // HACK: cast pointer of x is bad
990  // TODO: remove this after compiler fix
991  *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
992  *c_style_pointer_cast<const int8_t*>(&x);
993  }
994  else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
1000  {
1001  // HACK: cast pointer of x is bad
1002  // TODO: remove this after compiler fix
1003  *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
1004  *c_style_pointer_cast<const int16_t*>(&x);
1005  }
1006  else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
1012  {
1013  // HACK: cast pointer of x is bad
1014  // TODO: remove this after compiler fix
1015  *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
1016  *c_style_pointer_cast<const int32_t*>(&x);
1017  }
1018  else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
1024  {
1025  // HACK: cast pointer of x is bad
1026  // TODO: remove this after compiler fix
1027  *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
1028  *c_style_pointer_cast<const int32x2_t*>(&x);
1029  }
1030  else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
1036  {
1037  // HACK: cast pointer of x is bad
1038  // TODO: remove this after compiler fix
1039  *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
1040  *c_style_pointer_cast<const int32x4_t*>(&x);
1041  }
1042  else if constexpr((std::is_same_v<remove_cvref_t<T>, int8x4_t> &&
1046  {
1047  // HACK: cast pointer of x is bad
1048  // TODO: remove this after compiler fix
1049  *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
1050  *c_style_pointer_cast<const int32_t*>(&x);
1051  }
1052  else if constexpr((std::is_same_v<remove_cvref_t<T>, int8x8_t> &&
1056  {
1057  // HACK: cast pointer of x is bad
1058  // TODO: remove this after compiler fix
1059  *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
1060  *c_style_pointer_cast<const int32x2_t*>(&x);
1061  }
1062  else if constexpr((std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
1066  {
1067  // HACK: cast pointer of x is bad
1068  // TODO: remove this after compiler fix
1069  *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
1070  *c_style_pointer_cast<const int32x4_t*>(&x);
1071  }
1072  }
1073  }
1074  else
1075  {
1076  if(is_valid_element)
1077  {
1078 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
1079  X tmp = x;
1080 
1081  __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
1082 #else
// NOTE(review): buf_t holds scalar_per_t_vector * scalar_per_x_vector scalars, while X holds
// only scalar_per_x_vector scalars. The two sizes agree only when T is a scalar type
// (scalar_per_t_vector == 1). For a vector T this appears to read past x and write past the
// X-sized slot, whereas the memcpy branch above copies exactly sizeof(X) bytes. Confirm
// whether a vector T can reach this path, or whether the element count should be
// scalar_per_x_vector.
1083  using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
1084  scalar_per_t_vector * scalar_per_x_vector>;
1085 
1086  *c_style_pointer_cast<buf_t*>(&p_data_[i]) = reinterpret_cast<const buf_t&>(x);
1087 #endif
1088  }
1089  }
1090  }
1091 
1092  // FIXME: remove
1093  CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
1094 
1095  // FIXME: remove
1096  CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
1097 };
1098 
1099 // Address Space: Vgpr
1100 // T may be scalar or vector
1101 // X may be scalar or vector
1102 // T and X have same scalar type
1103 // X contains multiple T
1104 // FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
1105 // transforms of tensor_view/Tensor
1106 template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
1107 struct buffer_view<address_space_enum::vgpr,
1108  T,
1109  BufferSizeType,
1110  InvalidElementUseNumericalZeroValue,
1112 {
1113  using type = T;
1114 
1115  T* p_data_ = nullptr;
1116  BufferSizeType buffer_size_;
1117  remove_cvref_t<T> invalid_element_value_ = T{0};
1118 
// Default constructor (its signature line is not rendered in this listing).
// Value-initializes every member: p_data_ becomes nullptr, and buffer_size_ and
// invalid_element_value_ are value-initialized.
1120  : p_data_{}, buffer_size_{}, invalid_element_value_{}
1121  {
1122  }
1123 
1124  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
1125  : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
1126  {
1127  }
1128 
1129  CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
1130  BufferSizeType buffer_size,
1131  T invalid_element_value)
1132  : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
1133  {
1134  }
1135 
1137 
1138  CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
1139  {
1140  return address_space_enum::vgpr;
1141  }
1142 
1143  // i is offset of T
1144  // FIXME: doesn't do is_valid check
1145  CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
1146 
1147  // i is offset of T
1148  // FIXME: doesn't do is_valid check
1149  CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
1150 
1151  // i is offset of T, not X. i should be aligned to X
1152  template <typename X,
1153  bool oob_conditional_check = true,
1154  typename std::enable_if<
1155  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
1156  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
1157  bool>::type = false>
1158  CK_TILE_DEVICE constexpr auto get(index_t i,
1159  index_t /*linear_offset*/,
1160  bool is_valid_element,
1162  {
1163  // X contains multiple T
1164  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
1165 
1166  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
1167 
1168  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
1169  "wrong! X should contain multiple T");
1170 
1171  if(is_valid_element)
1172  {
1173 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
1174  X tmp;
1175 
1176  __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
1177 
1178  return tmp;
1179 #else
1180  return *c_style_pointer_cast<const X*>(&p_data_[i]);
1181 #endif
1182  }
1183  else
1184  {
1185  if constexpr(InvalidElementUseNumericalZeroValue)
1186  {
1187  return X{numeric<remove_cvref_t<T>>::zero()};
1188  }
1189  else
1190  {
1191  return X{invalid_element_value_};
1192  }
1193  }
1194  }
1195 
1196  // i is offset of T, not X. i should be aligned to X
1197  template <memory_operation_enum Op,
1198  typename X,
1199  typename std::enable_if<
1200  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
1201  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
1202  bool>::type = false>
1203  CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
1204  {
1205  if constexpr(Op == memory_operation_enum::set)
1206  {
1207  this->template set<X>(i, linear_offset, is_valid_element, x);
1208  }
1209  // FIXME: remove memory_operation_enum::add
1210  else if constexpr(Op == memory_operation_enum::add)
1211  {
1212  auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
1213  this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
1214  }
1215  }
1216 
1217  // i is offset of T, not X. i should be aligned to X
1218  template <typename X,
1219  typename std::enable_if<
1220  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
1221  typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
1222  bool>::type = false>
1223  CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
1224  {
1225  // X contains multiple T
1226  constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
1227 
1228  constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
1229 
1230  static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
1231  "wrong! X should contain multiple T");
1232 
1233  if(is_valid_element)
1234  {
1235 #if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
1236  X tmp = x;
1237 
1238  __builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
1239 #else
1240  *c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
1241 #endif
1242  }
1243  }
1244 
1245  // FIXME: remove
1246  CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
1247 
1248  // FIXME: remove
1249  CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
1250 };
1251 
// Factory taking a data pointer and an element count for address space BufferAddressSpace.
// The name suggests it returns a buffer_view built from them.
// NOTE(review): one template-parameter line (1253) and the whole return statement (1258) are
// not rendered in this listing. The exact view type it returns, including the
// InvalidElementUseNumericalZeroValue and coherence arguments, cannot be confirmed here.
1252 template <address_space_enum BufferAddressSpace,
1254  typename T,
1255  typename BufferSizeType>
1256 CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size)
1257 {
1259 }
1260 
// Overload that also takes the value to use for invalid elements.
// It participates in overload resolution only when remove_cvref_t<T> and remove_cvref_t<X>
// are the same type (enable_if below).
// The visible tail shows the result is brace-initialized from
// (p, buffer_size, invalid_element_value).
// NOTE(review): the line naming the constructed type (1271) and one template-parameter line
// (1262) are not rendered in this listing.
1261 template <address_space_enum BufferAddressSpace,
1263  typename T,
1264  typename BufferSizeType,
1265  typename X,
1266  typename std::enable_if<std::is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value,
1267  bool>::type = false>
1268 CK_TILE_HOST_DEVICE constexpr auto
1269 make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size, X invalid_element_value)
1270 {
1272  p, buffer_size, invalid_element_value};
1273 }
1274 
1275 // Generalized print function for all buffer_view variants
1276 template <address_space_enum BufferAddressSpace,
1277  typename T,
1278  typename BufferSizeType,
1279  bool InvalidElementUseNumericalZeroValue,
1280  amd_buffer_coherence_enum Coherence>
1281 CK_TILE_HOST_DEVICE void print(const buffer_view<BufferAddressSpace,
1282  T,
1283  BufferSizeType,
1284  InvalidElementUseNumericalZeroValue,
1285  Coherence>& bv)
1286 {
1287  printf("buffer_view{AddressSpace: %s, p_data_: %p, buffer_size_: ",
1288  address_space_to_string(BufferAddressSpace),
1289  static_cast<void*>(const_cast<remove_cvref_t<T>*>(bv.p_data_)));
1290  print(bv.buffer_size_);
1291  printf(", invalid_element_value_: ");
1292  print(bv.invalid_element_value_);
1293  printf("}");
1294 }
1295 
1296 } // namespace ck_tile
constexpr CK_TILE_HOST_DEVICE const char * address_space_to_string(address_space_enum addr_space)
Helper function to convert address space enum to string.
Definition: arch.hpp:246
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
int8_t int8x16_t
Definition: vector_type.hpp:182
int8_t int8x4_t
Definition: vector_type.hpp:180
int8_t int8x8_t
Definition: vector_type.hpp:181
CK_TILE_DEVICE thread_buffer< T, N > amd_buffer_load_invalid_element_return_customized_value(const T *p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, index_t src_element_space_size, T customized_value)
Definition: amd_buffer_addressing.hpp:2471
int8_t int8_t
Definition: int8.hpp:20
amd_buffer_coherence_enum
Definition: amd_buffer_addressing.hpp:1332
CK_TILE_HOST_DEVICE T add(const T &a, const T &b)
Definition: generic_memory_space_atomic.hpp:16
constexpr CK_TILE_HOST_DEVICE auto make_buffer_view(T *__restrict__ p, BufferSizeType buffer_size)
Definition: buffer_view.hpp:1256
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d< BlockSize, YPerTile, XPerTile, VecSize, DistributionPattern, NumWaveGroups > &)
Definition: static_encoding_pattern.hpp:342
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int8_t pk_int4x4_t
Definition: vector_type.hpp:236
int8_t pk_int4x16_t
Definition: vector_type.hpp:238
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:83
int32_t int32_t
Definition: integer.hpp:10
int8_t int8x2_t
Definition: pk_int4.hpp:103
int32_t int32x4_t
Definition: vector_type.hpp:144
int8_t pk_int4x8_t
Definition: vector_type.hpp:237
_Float16 half_t
Definition: half.hpp:111
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr, uint32_t size=0xffffffff)
Definition: amd_buffer_addressing.hpp:40
__device__ X atomic_max(X *p_dst, const X &x)
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
constexpr bool is_same_v
Definition: type.hpp:283
__device__ X atomic_add(X *p_dst, const X &x)
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:186
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size)
Definition: buffer_view.hpp:65
constexpr CK_TILE_DEVICE auto transpose_get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition: buffer_view.hpp:147
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size, T invalid_element_value)
Definition: buffer_view.hpp:70
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:166
constexpr CK_TILE_DEVICE auto get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition: buffer_view.hpp:99
constexpr CK_TILE_DEVICE auto transpose_get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition: buffer_view.hpp:373
constexpr CK_TILE_DEVICE auto async_get_raw(remove_cvref_t< T > *smem, index_t i, index_t linear_offset, bool, bool_constant< pre_nop >={}) const
Definition: buffer_view.hpp:452
constexpr CK_TILE_DEVICE const T & operator[](index_t i) const
Definition: buffer_view.hpp:278
constexpr CK_TILE_DEVICE auto get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition: buffer_view.hpp:291
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X &x, bool_constant< oob_conditional_check >={})
Definition: buffer_view.hpp:479
static constexpr CK_TILE_DEVICE address_space_enum get_address_space()
Definition: buffer_view.hpp:271
CK_TILE_DEVICE void update_raw(index_t i, index_t linear_offset, bool is_valid_element, const X &x, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: buffer_view.hpp:520
constexpr CK_TILE_DEVICE auto async_get(CK_TILE_LDS_ADDR remove_cvref_t< T > *smem, index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition: buffer_view.hpp:419
constexpr CK_TILE_DEVICE auto get_raw(remove_cvref_t< X > &dst, index_t v_offset, index_t i_offset, bool is_valid_element, bool_constant< pre_nop >={}) const
Definition: buffer_view.hpp:393
CK_TILE_DEVICE void atomic_add(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:616
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size)
Definition: buffer_view.hpp:246
CK_TILE_DEVICE void set_raw(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:594
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:549
CK_TILE_DEVICE void atomic_max(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:701
CK_TILE_DEVICE void atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:670
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size, T invalid_element_value)
Definition: buffer_view.hpp:254
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:926
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size, T invalid_element_value)
Definition: buffer_view.hpp:770
constexpr CK_TILE_DEVICE auto get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition: buffer_view.hpp:799
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:906
constexpr CK_TILE_DEVICE auto transpose_get([[maybe_unused]] index_t i, [[maybe_unused]] index_t linear_offset, bool is_valid_element) const
Definition: buffer_view.hpp:863
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size)
Definition: buffer_view.hpp:765
constexpr CK_TILE_DEVICE auto get_raw(remove_cvref_t< X > &dst, index_t v_offset, index_t i_offset, bool, bool_constant< pre_nop >={}) const
Definition: buffer_view.hpp:849
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size)
Definition: buffer_view.hpp:1124
constexpr CK_TILE_HOST_DEVICE buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size, T invalid_element_value)
Definition: buffer_view.hpp:1129
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:1223
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition: buffer_view.hpp:1203
constexpr CK_TILE_DEVICE auto get(index_t i, index_t, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition: buffer_view.hpp:1158
Definition: buffer_view.hpp:38
Definition: integral_constant.hpp:13
Definition: numeric.hpp:81
Definition: numeric.hpp:18
Definition: pk_int4.hpp:21
Definition: amd_buffer_addressing.hpp:832
Definition: debug.hpp:67
Definition: vector_type.hpp:89