/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_buffer_addressing_builtins.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_buffer_addressing_builtins.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_buffer_addressing_builtins.hpp Source File
amd_buffer_addressing_builtins.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 #include "data_type.hpp"
7 
8 namespace ck {
9 
// Union describing the buffer resource consumed by buffer instructions,
// parameterised on the pointee type T.
// NOTE(review): the member declarations (listing lines 17-20) are missing from this
// page; the resource factory functions in this file access members named
// `content`, `address`, `range` and `config`.
10 template <typename T>
11 union BufferResource
12 {
// Value-initialises the `content` member.
13  __device__ constexpr BufferResource() : content{} {}
14 
15  // 128 bit SGPRs to supply buffer resource in buffer instructions
16  // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
21 };
22 
23 template <typename T>
24 __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
25 {
26  BufferResource<T> wave_buffer_resource;
27 
28  // wavewise base address (64 bit)
29  wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
30  // wavewise range (32 bit)
31  wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
32  // wavewise setting (32 bit)
33  wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
34 
35  return wave_buffer_resource.content;
36 }
37 
// Variant of make_wave_buffer_resource that takes no element count: the range
// word is set to 0xffffffff instead of a size computed from the caller's input.
// NOTE(review): the signature line (listing line 39) is missing from this page;
// the body only references a parameter named `p_wave`.
38 template <typename T>
40 {
41  BufferResource<T> wave_buffer_resource;
42 
43  // wavewise base address (64 bit)
44  wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
45  // wavewise range (32 bit)
46  wave_buffer_resource.range(Number<2>{}) = 0xffffffff; // max possible range
47  // wavewise setting (32 bit)
48  wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
49 
50  return wave_buffer_resource.content;
51 }
52 
53 template <typename T>
54 __device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_new(T* p_wave,
55  index_t element_space_size)
56 {
57  // wavewise base address (64 bit)
58  auto p = const_cast<remove_cv_t<T>*>(p_wave);
59  int32_t stride = 0;
60  int32_t num = element_space_size * sizeof(T);
61  auto flags = CK_BUFFER_RESOURCE_3RD_DWORD;
62 
63  return __builtin_amdgcn_make_buffer_rsrc(p, stride, num, flags);
64 }
65 
66 template <typename T>
67 __device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_with_default_range_new(T* p_wave)
68 {
69  // wavewise base address (64 bit)
70  auto p = const_cast<remove_cv_t<T>*>(p_wave);
71  int32_t stride = 0;
72  int32_t num = 0xffffffff;
73  auto flags = CK_BUFFER_RESOURCE_3RD_DWORD;
74 
75  return __builtin_amdgcn_make_buffer_rsrc(p, stride, num, flags);
76 }
77 
// Declaration bound via the __asm label to the LLVM intrinsic
// llvm.amdgcn.raw.buffer.atomic.fadd.v2f16; takes a packed pair of halves.
// NOTE(review): the first line of this declaration (return type and function
// name, listing line 79) is missing from this page.
78 // buffer atomic-add fp16
80  half2_t vdata,
81  int32x4_t rsrc,
82  index_t voffset,
83  index_t soffset,
84  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
85 
// Declaration bound via the __asm label to the LLVM intrinsic
// llvm.amdgcn.raw.buffer.atomic.add.i32.
// NOTE(review): the first line of this declaration (return type and function
// name, listing line 87) is missing from this page; amd_buffer_atomic_add_impl
// calls llvm_amdgcn_raw_buffer_atomic_add_i32 with this argument order.
86 // buffer atomic-add i32
88  int32_t vdata,
89  int32x4_t rsrc,
90  index_t voffset,
91  index_t soffset,
92  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
93 
// Declaration bound via the __asm label to the LLVM intrinsic
// llvm.amdgcn.raw.buffer.atomic.fadd.f32.
// NOTE(review): the first line of this declaration (return type and function
// name, listing line 95) is missing from this page; amd_buffer_atomic_add_impl
// calls llvm_amdgcn_raw_buffer_atomic_add_fp32 with this argument order.
94 // buffer atomic-add fp32
96  float vdata,
97  int32x4_t rsrc,
98  index_t voffset,
99  index_t soffset,
100  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
101 
// Declaration bound via the __asm label to the LLVM intrinsic
// llvm.amdgcn.raw.buffer.atomic.fmax.f64; returns double.
// NOTE(review): the line carrying the function name and first parameter
// (listing line 104) is missing from this page; amd_buffer_atomic_max_impl
// calls llvm_amdgcn_raw_buffer_atomic_max_fp64 with this argument order.
102 // buffer atomic-max fp64
103 __device__ double
105  int32x4_t rsrc, // dst_wave_buffer_resource
106  int voffset, // dst_thread_addr_offset
107  int soffset, // dst_wave_addr_offset
108  int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
109 
// Loads N raw bytes through a buffer resource and returns them as a vector of
// N int8_t.
//   N                        - byte count; must be 1, 2, 4, 8, 16, 32 or 64 (static_assert)
//   coherence                - cast to index_t and passed as the last builtin argument
//   src_wave_buffer_resource - buffer resource handed to every builtin call
//   src_thread_addr_offset   - passed unchanged as the builtin's second argument
//   src_wave_addr_offset     - passed as the builtin's third argument; advanced by
//                              16 bytes per extra load when N is 32 or 64
// For N <= 16 one builtin load of the matching width is issued and the result is
// bit_cast to the int8 vector. N == 32 and N == 64 are assembled from two and
// four b128 loads respectively.
// NOTE(review): the declarations of `tmp` in the N == 32 and N == 64 branches
// (listing lines 173 and 202) are missing from this page.
110 template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
111 __device__ typename vector_type<int8_t, N>::type
112 amd_buffer_load_impl_raw(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
113  index_t src_thread_addr_offset,
114  index_t src_wave_addr_offset)
115 {
116  static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
117  "wrong! not implemented");
118 
119  if constexpr(N == 1)
120  {
121  return __builtin_amdgcn_raw_buffer_load_b8(src_wave_buffer_resource,
122  src_thread_addr_offset,
123  src_wave_addr_offset,
124  static_cast<index_t>(coherence));
125  }
126  else if constexpr(N == 2)
127  {
128 
129  int16_t tmp = __builtin_amdgcn_raw_buffer_load_b16(src_wave_buffer_resource,
130  src_thread_addr_offset,
131  src_wave_addr_offset,
132  static_cast<index_t>(coherence));
133 
134  return bit_cast<int8x2_t>(tmp);
135  }
136  else if constexpr(N == 4)
137  {
138  int32_t tmp = __builtin_amdgcn_raw_buffer_load_b32(src_wave_buffer_resource,
139  src_thread_addr_offset,
140  src_wave_addr_offset,
141  static_cast<index_t>(coherence));
142 
143  return bit_cast<int8x4_t>(tmp);
144  }
145  else if constexpr(N == 8)
146  {
147  int32x2_t tmp = __builtin_amdgcn_raw_buffer_load_b64(src_wave_buffer_resource,
148  src_thread_addr_offset,
149  src_wave_addr_offset,
150  static_cast<index_t>(coherence));
151 
152  return bit_cast<int8x8_t>(tmp);
153  }
154  else if constexpr(N == 16)
155  {
156  int32x4_t tmp = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
157  src_thread_addr_offset,
158  src_wave_addr_offset,
159  static_cast<index_t>(coherence));
160  return bit_cast<int8x16_t>(tmp);
161  }
162  else if constexpr(N == 32)
163  {
// Two 16-byte loads; the second one starts 4 * sizeof(int32_t) bytes further on.
164  int32x4_t tmp0 = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
165  src_thread_addr_offset,
166  src_wave_addr_offset,
167  static_cast<index_t>(coherence));
168  int32x4_t tmp1 =
169  __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
170  src_thread_addr_offset,
171  src_wave_addr_offset + 4 * sizeof(int32_t),
172  static_cast<index_t>(coherence));
174 
175  tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
176  tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;
177 
178  return bit_cast<int8x32_t>(tmp);
179  }
180  else if constexpr(N == 64)
181  {
// Four 16-byte loads at wave offsets +0, +16, +32 and +48 bytes.
182  int32x4_t tmp0 = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
183  src_thread_addr_offset,
184  src_wave_addr_offset,
185  static_cast<index_t>(coherence));
186  int32x4_t tmp1 =
187  __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
188  src_thread_addr_offset,
189  src_wave_addr_offset + 4 * sizeof(int32_t),
190  static_cast<index_t>(coherence));
191  int32x4_t tmp2 =
192  __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
193  src_thread_addr_offset,
194  src_wave_addr_offset + 8 * sizeof(int32_t),
195  static_cast<index_t>(coherence));
196  int32x4_t tmp3 =
197  __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
198  src_thread_addr_offset,
199  src_wave_addr_offset + 12 * sizeof(int32_t),
200  static_cast<index_t>(coherence));
201 
203 
204  tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
205  tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;
206  tmp.AsType<int32x4_t>()(Number<2>{}) = tmp2;
207  tmp.AsType<int32x4_t>()(Number<3>{}) = tmp3;
208 
209  return bit_cast<int8x64_t>(tmp);
210  }
211 }
212 
// Typed buffer load: checks the (T, N) pair against the supported list, loads
// sizeof(T) * N raw bytes with amd_buffer_load_impl_raw and bit_casts the bytes
// to vector_type<T, N>::type.
// Supported: double with N in {1,2,4,8}; float, half_t, bhalf_t, int32_t, f8_t,
// bf8_t, int8_t, uint8_t and pk_i4_t with N in {1,2,4,8,16}.
// NOTE(review): the third template-parameter line (listing line 215) is missing
// from this page; the body uses a template argument named `coherence`.
213 template <typename T,
214  index_t N,
216 __device__ typename vector_type<T, N>::type
217 amd_buffer_load_impl(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
218  index_t src_thread_addr_offset,
219  index_t src_wave_addr_offset)
220 {
221  static_assert(
222  (is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
223  (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
224  (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
225  (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
226  (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
227  (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
228  (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
229  (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
230  (is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
231  (is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
232  "wrong! not implemented");
233 
234  using r_t = typename vector_type<T, N>::type;
// The raw loader is instantiated on the total byte count, not the element count.
235  auto raw_data = amd_buffer_load_impl_raw<sizeof(T) * N, coherence>(
236  src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset);
237  return bit_cast<r_t>(raw_data);
238 }
239 
// Stores N raw bytes through a buffer resource.
//   N                        - byte count; must be 1, 2, 4, 8, 16, 32 or 64 (static_assert)
//   coherence                - cast to index_t and passed as the last builtin argument
//   dst_wave_buffer_resource - buffer resource handed to every builtin call
//   dst_thread_addr_offset   - passed unchanged to every builtin call
//   dst_wave_addr_offset     - advanced by 16 bytes per extra store when N is 32 or 64
// For N <= 16 the data is bit_cast to the matching integer type and written with
// one builtin store. N == 32 and N == 64 are split into two and four b128 stores.
// NOTE(review): the line carrying the function name and its first parameter
// (listing line 242) is missing from this page; the body refers to that
// parameter as `src_thread_data`.
240 template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
241 __device__ void
243  __amdgpu_buffer_rsrc_t dst_wave_buffer_resource,
244  index_t dst_thread_addr_offset,
245  index_t dst_wave_addr_offset)
246 {
247  static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
248  "wrong! not implemented");
249 
250  if constexpr(N == 1)
251  {
252  __builtin_amdgcn_raw_buffer_store_b8(src_thread_data,
253  dst_wave_buffer_resource,
254  dst_thread_addr_offset,
255  dst_wave_addr_offset,
256  static_cast<index_t>(coherence));
257  }
258  else if constexpr(N == 2)
259  {
260 
261  __builtin_amdgcn_raw_buffer_store_b16(bit_cast<int16_t>(src_thread_data),
262  dst_wave_buffer_resource,
263  dst_thread_addr_offset,
264  dst_wave_addr_offset,
265  static_cast<index_t>(coherence));
266  }
267  else if constexpr(N == 4)
268  {
269  __builtin_amdgcn_raw_buffer_store_b32(bit_cast<int32_t>(src_thread_data),
270  dst_wave_buffer_resource,
271  dst_thread_addr_offset,
272  dst_wave_addr_offset,
273  static_cast<index_t>(coherence));
274  }
275  else if constexpr(N == 8)
276  {
277  __builtin_amdgcn_raw_buffer_store_b64(bit_cast<int32x2_t>(src_thread_data),
278  dst_wave_buffer_resource,
279  dst_thread_addr_offset,
280  dst_wave_addr_offset,
281  static_cast<index_t>(coherence));
282  }
283  else if constexpr(N == 16)
284  {
285  __builtin_amdgcn_raw_buffer_store_b128(bit_cast<int32x4_t>(src_thread_data),
286  dst_wave_buffer_resource,
287  dst_thread_addr_offset,
288  dst_wave_addr_offset,
289  static_cast<index_t>(coherence));
290  }
291  else if constexpr(N == 32)
292  {
// Reinterpret the 32 bytes as eight int32 and write them as two 16-byte stores.
293  vector_type<int32_t, 8> tmp{bit_cast<int32x8_t>(src_thread_data)};
294 
295  __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<0>{}],
296  dst_wave_buffer_resource,
297  dst_thread_addr_offset,
298  dst_wave_addr_offset,
299  static_cast<index_t>(coherence));
300 
301  __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<1>{}],
302  dst_wave_buffer_resource,
303  dst_thread_addr_offset,
304  dst_wave_addr_offset + sizeof(int32_t) * 4,
305  static_cast<index_t>(coherence));
306  }
307  else if constexpr(N == 64)
308  {
// Reinterpret the 64 bytes as sixteen int32 and write them as four 16-byte stores.
309  vector_type<int32_t, 16> tmp{bit_cast<int32x16_t>(src_thread_data)};
310 
311  __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<0>{}],
312  dst_wave_buffer_resource,
313  dst_thread_addr_offset,
314  dst_wave_addr_offset,
315  static_cast<index_t>(coherence));
316 
317  __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<1>{}],
318  dst_wave_buffer_resource,
319  dst_thread_addr_offset,
320  dst_wave_addr_offset + sizeof(int32_t) * 4,
321  static_cast<index_t>(coherence));
322 
323  __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<2>{}],
324  dst_wave_buffer_resource,
325  dst_thread_addr_offset,
326  dst_wave_addr_offset + sizeof(int32_t) * 8,
327  static_cast<index_t>(coherence));
328 
329  __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<3>{}],
330  dst_wave_buffer_resource,
331  dst_thread_addr_offset,
332  dst_wave_addr_offset + sizeof(int32_t) * 12,
333  static_cast<index_t>(coherence));
334  }
335 }
336 
// Typed buffer store: checks the (T, N) pair against the supported list,
// bit_casts the vector to sizeof(T) * N int8_t and forwards it to
// amd_buffer_store_impl_raw.
// NOTE(review): two lines are missing from this page: the third
// template-parameter line (listing line 339; the body uses `coherence`) and the
// first half of one static_assert clause (listing line 353), so the set of
// supported types visible here is incomplete.
// NOTE(review): the visible list names f8_fnuz_t / bf8_fnuz_t, whereas
// amd_buffer_load_impl names f8_t / bf8_t; whether these are the same types is
// not visible here - confirm the two lists are meant to agree.
337 template <typename T,
338  index_t N,
340 __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
341  __amdgpu_buffer_rsrc_t dst_wave_buffer_resource,
342  index_t dst_thread_addr_offset,
343  index_t dst_wave_addr_offset)
344 {
345  static_assert(
346  (is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
347  (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
348  (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
349  (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
350  (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
351  (is_same<T, f8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
352  (is_same<T, bf8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
354  (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
355  (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
356  "wrong! not implemented");
357 
// Byte-vector type with the same total size as the typed vector.
358  using r_t = typename vector_type<int8_t, sizeof(T) * N>::type;
359 
360  amd_buffer_store_impl_raw<sizeof(T) * N, coherence>(bit_cast<r_t>(src_thread_data),
361  dst_wave_buffer_resource,
362  dst_thread_addr_offset,
363  dst_wave_addr_offset);
364 }
365 
366 template <typename T, index_t N>
367 __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
368  T* addr)
369 {
370  static_assert((is_same<T, bhalf_t>::value && (N == 2 || N == 4 || N == 8)) ||
371  (is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)),
372  "wrong! not implemented");
373 
374  if constexpr(is_same<T, half_t>::value)
375  {
376  vector_type<half_t, N> tmp{src_thread_data};
377  static_for<0, N / 2, 1>{}([&](auto i) {
378  __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i,
379  tmp.template AsType<half2_t>()[i]);
380  });
381  }
382 #if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__)
383  else if constexpr(is_same<T, bhalf_t>::value)
384  {
385  vector_type<bhalf_t, N> tmp{src_thread_data};
386  static_for<0, N / 2, 1>{}([&](auto i) {
387  __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i,
388  tmp.template AsType<bhalf2_t>()[i]);
389  });
390  }
391 #endif
392 }
393 
// Issues buffer atomic-adds for an N-element vector of T.
// Supported (static_assert): float with N in {1,2,4}, half_t with N in {2,4,8},
// int32_t with N in {1,2,4}.
//   src_thread_data          - values to add
//   dst_wave_buffer_resource - buffer resource passed to every intrinsic call
//   dst_thread_addr_offset   - passed unchanged to every intrinsic call
//   dst_wave_addr_offset     - advanced by the size of one atomic per extra call
// float and int32_t vectors are issued one scalar atomic per element; half_t
// vectors are issued one atomic per pair, stepping by sizeof(half2_t).
// Every call passes 0 as its final argument.
// NOTE(review): the first line of several calls is missing from this page
// (listing lines 409, 464, 475, 487 and 499), i.e. the callee name and the data
// argument of the float N == 1, all half_t, and the int32_t N == 1 branches.
394 template <typename T, index_t N>
395 __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
396  int32x4_t dst_wave_buffer_resource,
397  index_t dst_thread_addr_offset,
398  index_t dst_wave_addr_offset)
399 {
400  static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
401  (is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)) ||
402  (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
403  "wrong! not implemented");
404 
405  if constexpr(is_same<T, float>::value)
406  {
407  if constexpr(N == 1)
408  {
410  dst_wave_buffer_resource,
411  dst_thread_addr_offset,
412  dst_wave_addr_offset,
413  0);
414  }
415  else if constexpr(N == 2)
416  {
417  vector_type<float, 2> tmp{src_thread_data};
418 
419  llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<0>{}],
420  dst_wave_buffer_resource,
421  dst_thread_addr_offset,
422  dst_wave_addr_offset,
423  0);
424 
425  llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<1>{}],
426  dst_wave_buffer_resource,
427  dst_thread_addr_offset,
428  dst_wave_addr_offset + sizeof(float),
429  0);
430  }
431  else if constexpr(N == 4)
432  {
433  vector_type<float, 4> tmp{src_thread_data};
434 
435  llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<0>{}],
436  dst_wave_buffer_resource,
437  dst_thread_addr_offset,
438  dst_wave_addr_offset,
439  0);
440 
441  llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<1>{}],
442  dst_wave_buffer_resource,
443  dst_thread_addr_offset,
444  dst_wave_addr_offset + sizeof(float),
445  0);
446 
447  llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<2>{}],
448  dst_wave_buffer_resource,
449  dst_thread_addr_offset,
450  dst_wave_addr_offset + 2 * sizeof(float),
451  0);
452 
453  llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<3>{}],
454  dst_wave_buffer_resource,
455  dst_thread_addr_offset,
456  dst_wave_addr_offset + 3 * sizeof(float),
457  0);
458  }
459  }
460  else if constexpr(is_same<T, half_t>::value)
461  {
462  if constexpr(N == 2)
463  {
465  dst_wave_buffer_resource,
466  dst_thread_addr_offset,
467  dst_wave_addr_offset,
468  0);
469  }
470  else if constexpr(N == 4)
471  {
472  vector_type<half_t, 4> tmp{src_thread_data};
473 
// One atomic per pair of halves: 2 iterations.
474  static_for<0, 2, 1>{}([&](auto i) {
476  dst_wave_buffer_resource,
477  dst_thread_addr_offset,
478  dst_wave_addr_offset + i * sizeof(half2_t),
479  0);
480  });
481  }
482  else if constexpr(N == 8)
483  {
484  vector_type<half_t, 8> tmp{src_thread_data};
485 
// One atomic per pair of halves: 4 iterations.
486  static_for<0, 4, 1>{}([&](auto i) {
488  dst_wave_buffer_resource,
489  dst_thread_addr_offset,
490  dst_wave_addr_offset + i * sizeof(half2_t),
491  0);
492  });
493  }
494  }
495  else if constexpr(is_same<T, int32_t>::value)
496  {
497  if constexpr(N == 1)
498  {
500  dst_wave_buffer_resource,
501  dst_thread_addr_offset,
502  dst_wave_addr_offset,
503  0);
504  }
505  else if constexpr(N == 2)
506  {
507  vector_type<int32_t, 2> tmp{src_thread_data};
508 
509  llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<0>{}],
510  dst_wave_buffer_resource,
511  dst_thread_addr_offset,
512  dst_wave_addr_offset,
513  0);
514 
515  llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<1>{}],
516  dst_wave_buffer_resource,
517  dst_thread_addr_offset,
518  dst_wave_addr_offset + sizeof(int32_t),
519  0);
520  }
521  else if constexpr(N == 4)
522  {
523  vector_type<int32_t, 4> tmp{src_thread_data};
524 
525  llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<0>{}],
526  dst_wave_buffer_resource,
527  dst_thread_addr_offset,
528  dst_wave_addr_offset,
529  0);
530 
531  llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<1>{}],
532  dst_wave_buffer_resource,
533  dst_thread_addr_offset,
534  dst_wave_addr_offset + sizeof(int32_t),
535  0);
536 
537  llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<2>{}],
538  dst_wave_buffer_resource,
539  dst_thread_addr_offset,
540  dst_wave_addr_offset + 2 * sizeof(int32_t),
541  0);
542 
543  llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<3>{}],
544  dst_wave_buffer_resource,
545  dst_thread_addr_offset,
546  dst_wave_addr_offset + 3 * sizeof(int32_t),
547  0);
548  }
549  }
550 }
551 
// Issues buffer atomic-max operations for an N-element vector of double.
// Supported (static_assert): T == double with N in {1,2,4}.
//   src_thread_data          - values to combine
//   dst_wave_buffer_resource - buffer resource passed to every intrinsic call
//   dst_thread_addr_offset   - passed unchanged to every intrinsic call
//   dst_wave_addr_offset     - advanced by sizeof(double) per extra element
// Vectors are issued one scalar atomic per element; every call passes 0 as its
// final argument.
// NOTE(review): the first line of the N == 1 call (listing line 564) is missing
// from this page.
552 template <typename T, index_t N>
553 __device__ void amd_buffer_atomic_max_impl(const typename vector_type<T, N>::type src_thread_data,
554  int32x4_t dst_wave_buffer_resource,
555  index_t dst_thread_addr_offset,
556  index_t dst_wave_addr_offset)
557 {
558  static_assert((is_same<T, double>::value && (N == 1 || N == 2 || N == 4)),
559  "wrong! not implemented");
560  if constexpr(is_same<T, double>::value)
561  {
562  if constexpr(N == 1)
563  {
565  dst_wave_buffer_resource,
566  dst_thread_addr_offset,
567  dst_wave_addr_offset,
568  0);
569  }
570  else if constexpr(N == 2)
571  {
572  vector_type<double, 2> tmp{src_thread_data};
573 
574  llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<0>{}],
575  dst_wave_buffer_resource,
576  dst_thread_addr_offset,
577  dst_wave_addr_offset,
578  0);
579 
580  llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<1>{}],
581  dst_wave_buffer_resource,
582  dst_thread_addr_offset,
583  dst_wave_addr_offset + sizeof(double),
584  0);
585  }
586  else if constexpr(N == 4)
587  {
588  vector_type<double, 4> tmp{src_thread_data};
589 
590  llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<0>{}],
591  dst_wave_buffer_resource,
592  dst_thread_addr_offset,
593  dst_wave_addr_offset,
594  0);
595 
596  llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<1>{}],
597  dst_wave_buffer_resource,
598  dst_thread_addr_offset,
599  dst_wave_addr_offset + sizeof(double),
600  0);
601 
602  llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<2>{}],
603  dst_wave_buffer_resource,
604  dst_thread_addr_offset,
605  dst_wave_addr_offset + 2 * sizeof(double),
606  0);
607 
608  llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<3>{}],
609  dst_wave_buffer_resource,
610  dst_thread_addr_offset,
611  dst_wave_addr_offset + 3 * sizeof(double),
612  0);
613  }
614  }
615 }
616 
// Loads a vector of N elements of T through a freshly built buffer resource and
// yields zero for elements flagged as invalid.
//   src_thread_element_offset - offset in elements; multiplied by sizeof(T) to get bytes
//   src_thread_element_valid  - false selects the "invalid element" handling below
//   src_element_space_size    - element count used to build the buffer resource
// With CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK enabled, an invalid
// element has 0x80000000 added to its byte offset and the load result is returned
// as-is (presumably the shifted offset falls outside the resource range so the
// load yields zero - confirm against the ISA documentation). Otherwise the load is
// always issued and vector_t(0) is returned for an invalid element.
// NOTE(review): two lines are missing from this page: the third
// template-parameter line (listing line 623; the body uses `coherence`) and the
// line with the function name and first parameter (listing line 625; the body
// uses `p_src_wave`).
617 // buffer_load requires:
618 // 1) p_src_wave must point to global memory space
619 // 2) p_src_wave must be a wavewise pointer.
620 // It is user's responsibility to make sure that is true.
621 template <typename T,
622  index_t N,
624 __device__ typename vector_type_maker<T, N>::type::type
626  index_t src_thread_element_offset,
627  bool src_thread_element_valid,
628  index_t src_element_space_size)
629 {
630  const __amdgpu_buffer_rsrc_t src_wave_buffer_resource =
631  make_wave_buffer_resource_new(p_src_wave, src_element_space_size);
632 
633  index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
634 
635  using vector_t = typename vector_type_maker<T, N>::type::type;
636  using scalar_t = typename scalar_type<vector_t>::type;
637 
638  constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
639 
640 #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
641  uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
642  return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
643  src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
644 
645 #else
646 
647  vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
648  src_wave_buffer_resource, src_thread_addr_offset, 0)};
649  return src_thread_element_valid ? tmp : vector_t(0);
650 #endif
651 }
652 
// Loads a vector of N elements of T through a freshly built buffer resource and
// yields a caller-chosen fill value for elements flagged as invalid.
//   src_thread_element_offset - offset in elements; multiplied by sizeof(T) to get bytes
//   src_thread_element_valid  - when false, vector_t(customized_value) is returned
//   src_element_space_size    - element count used to build the buffer resource
//   customized_value          - fill value for an invalid element
// The load is always issued; validity only selects which value is returned.
// NOTE(review): two lines are missing from this page: the third
// template-parameter line (listing line 659; the body uses `coherence`) and the
// line with the function name and first parameter (listing line 661; the body
// uses `p_src_wave`).
653 // buffer_load requires:
654 // 1) p_src_wave must point to global memory space
655 // 2) p_src_wave must be a wavewise pointer.
656 // It is user's responsibility to make sure that is true.
657 template <typename T,
658  index_t N,
660 __device__ typename vector_type_maker<T, N>::type::type
662  index_t src_thread_element_offset,
663  bool src_thread_element_valid,
664  index_t src_element_space_size,
665  T customized_value)
666 {
667  const __amdgpu_buffer_rsrc_t src_wave_buffer_resource =
668  make_wave_buffer_resource_new(p_src_wave, src_element_space_size);
669 
670  index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
671 
672  using vector_t = typename vector_type_maker<T, N>::type::type;
673  using scalar_t = typename scalar_type<vector_t>::type;
674 
675  constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
676 
677  vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
678  src_wave_buffer_resource, src_thread_addr_offset, 0)};
679 
680  return src_thread_element_valid ? tmp : vector_t(customized_value);
681 }
682 
// Stores a vector of N elements of T through a freshly built buffer resource,
// skipping elements flagged as invalid.
//   src_thread_data           - values to store
//   p_dst_wave                - base pointer used to build the buffer resource
//   dst_thread_element_offset - offset in elements; multiplied by sizeof(T) to get bytes
//   dst_thread_element_valid  - false suppresses the store (see below)
//   dst_element_space_size    - element count used to build the buffer resource
// With CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK enabled, the store
// is always issued and an invalid element has 0x80000000 added to its byte offset
// (presumably so the shifted offset falls outside the resource range and the
// store is dropped - confirm against the ISA documentation). Otherwise the store
// is issued only when the element is valid.
// NOTE(review): the third template-parameter line (listing line 689) is missing
// from this page; the body uses a template argument named `coherence`.
683 // buffer_store requires:
684 // 1) p_dst_wave must point to global memory
685 // 2) p_dst_wave must be a wavewise pointer.
686 // It is user's responsibility to make sure that is true.
687 template <typename T,
688  index_t N,
690 __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
691  T* p_dst_wave,
692  const index_t dst_thread_element_offset,
693  const bool dst_thread_element_valid,
694  const index_t dst_element_space_size)
695 {
696  const __amdgpu_buffer_rsrc_t dst_wave_buffer_resource =
697  make_wave_buffer_resource_new(p_dst_wave, dst_element_space_size);
698 
699  index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
700 
701  using vector_t = typename vector_type_maker<T, N>::type::type;
702  using scalar_t = typename scalar_type<vector_t>::type;
703  constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
704 
705 #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
706  uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
707  amd_buffer_store_impl<scalar_t, vector_size, coherence>(
708  src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
709 #else
710  if(dst_thread_element_valid)
711  {
712  amd_buffer_store_impl<scalar_t, vector_size, coherence>(
713  src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
714  }
715 #endif
716 }
717 
718 // buffer_atomic_add requires:
719 // 1) p_dst_wave must point to global memory
720 // 2) p_dst_wave must be a wavewise pointer.
721 // It is user's responsibility to make sure that is true.
722 template <typename T, index_t N>
723 __device__ void
724 amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thread_data,
725  T* p_dst_wave,
726  const index_t dst_thread_element_offset,
727  const bool dst_thread_element_valid,
728  const index_t dst_element_space_size)
729 {
730  const int32x4_t dst_wave_buffer_resource =
731  make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
732 
733  index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
734 
735  using vector_t = typename vector_type_maker<T, N>::type::type;
736  using scalar_t = typename scalar_type<vector_t>::type;
737  constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
738 
739  if constexpr(is_same<T, bhalf_t>::value)
740  {
741  if(dst_thread_element_valid)
742  {
743  amd_global_atomic_add_impl<scalar_t, vector_size>(
744  src_thread_data, p_dst_wave + dst_thread_element_offset);
745  }
746  }
747  else
748  {
749 #if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
750  uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
751 
752  amd_buffer_atomic_add_impl<scalar_t, vector_size>(
753  src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
754 #else
755  if(dst_thread_element_valid)
756  {
757  amd_buffer_atomic_add_impl<scalar_t, vector_size>(
758  src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
759  }
760 #endif
761  }
762 }
763 
764 // buffer_atomic_max requires:
765 // 1) p_dst_wave must point to global memory
766 // 2) p_dst_wave must be a wavewise pointer.
767 // It is user's responsibility to make sure that is true.
768 template <typename T, index_t N>
769 __device__ void
770 amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thread_data,
771  T* p_dst_wave,
772  const index_t dst_thread_element_offset,
773  const bool dst_thread_element_valid,
774  const index_t dst_element_space_size)
775 {
776  const int32x4_t dst_wave_buffer_resource =
777  make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
778 
779  index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
780 
781  using vector_t = typename vector_type_maker<T, N>::type::type;
782  using scalar_t = typename scalar_type<vector_t>::type;
783  constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
784 
785 #if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
786  uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
787 
788  amd_buffer_atomic_max_impl<scalar_t, vector_size>(
789  src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
790 #else
791  if(dst_thread_element_valid)
792  {
793  amd_buffer_atomic_max_impl<scalar_t, vector_size>(
794  src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
795  }
796 #endif
797 }
798 
// Declaration bound via the __asm label to the LLVM intrinsic
// llvm.amdgcn.raw.buffer.load.lds; the destination pointer parameter is
// qualified with address_space(3).
// NOTE(review): the line carrying the function name and first parameter
// (listing line 801) is missing from this page.
799 // Direct loads from global to LDS.
800 __device__ void
802  __attribute__((address_space(3))) uint32_t* lds_ptr,
803  index_t size,
804  index_t voffset,
805  index_t soffset,
806  index_t offset,
807  index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
808 
// Copies data for one thread from a global buffer directly into LDS.
// Compiled only when __HIPCC_RTC__ is not defined.
//   global_base_ptr        - base pointer used to build the source buffer resource
//   global_offset          - source offset in elements; multiplied by sizeof(T)
//   lds_base_ptr           - base of the destination
//   lds_offset             - destination offset in elements from lds_base_ptr
//   is_valid               - when false the source byte offset is replaced by 0x80000000
//   src_element_space_size - element count used to build the source buffer resource
// Two implementations are selected by CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM:
// inline assembly that writes the destination address to m0 and issues
// buffer_load_dword ... lds, or a call taking an address_space(3) pointer.
// NOTE(review): the inline-assembly path issues buffer_load_dword regardless of
// bytes_per_thread, while the gfx950 static_assert also admits 12 and 16 bytes -
// confirm that combination is intended.
// NOTE(review): the first line of the call in the non-asm path (listing line
// 859) is missing from this page.
809 #ifndef __HIPCC_RTC__
810 template <typename T, index_t NumElemsPerThread>
811 __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
812  const index_t global_offset,
813  T* lds_base_ptr,
814  const index_t lds_offset,
815  const bool is_valid,
816  const index_t src_element_space_size)
817 {
818  // Direct loads require that each thread reads and writes a multiple of DWORDs (4 bytes).
819  // For gfx950: supports 1, 3, or 4 DWORDs per thread
820  // For gfx942: supports exactly 1 DWORD per thread
821  constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
822 #if defined(__gfx950__)
823  constexpr auto dword_bytes = 4;
824  static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 ||
825  bytes_per_thread == dword_bytes * 4);
826 #elif defined(__gfx942__)
827  constexpr auto dword_bytes = 4;
828  static_assert(bytes_per_thread == dword_bytes);
829 #endif
830 
831  const int32x4_t src_resource =
832  make_wave_buffer_resource(global_base_ptr, src_element_space_size);
// An invalid element gets the fixed byte offset 0x80000000 instead of its real one.
833  const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
834 
835 #if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
836  T* lds_ptr = lds_base_ptr + lds_offset;
// readfirstlane yields a single value for the destination address, which the
// asm below binds to an "s" (scalar register) operand.
837 #ifndef CK_CODE_GEN_RTC
838  auto const lds_ptr_sgpr =
839  __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
840 #else
841  auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<size_t>(lds_ptr)));
842 #endif
843  asm volatile("s_mov_b32 m0, %0; \n\t"
844  "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
845  "v"(global_offset_bytes),
846  "s"(src_resource)
847  : "memory");
848 #else
849  // LDS pointer must be attributed with the LDS address space.
850  __attribute__((address_space(3))) uint32_t* lds_ptr =
851 #ifndef CK_CODE_GEN_RTC
852  reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
853  reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
854 #else
855  reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
856  reinterpret_cast<size_t>(lds_base_ptr + lds_offset));
857 #endif
858 
860  src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0);
861 #endif
862 }
863 #endif
864 
865 } // namespace ck
#define CK_BUFFER_RESOURCE_3RD_DWORD
Definition: ck.hpp:81
Definition: ck.hpp:270
__device__ int32x4_t make_wave_buffer_resource_with_default_range(T *p_wave)
Definition: amd_buffer_addressing.hpp:39
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__device__ void amd_buffer_store(const typename vector_type_maker< T, N >::type::type src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition: amd_buffer_addressing.hpp:871
__device__ void amd_direct_load_global_to_lds(const T *global_base_ptr, const index_t global_offset, T *lds_base_ptr, const index_t lds_offset, const bool is_valid, const index_t src_element_space_size)
Definition: amd_buffer_addressing.hpp:992
__device__ void amd_buffer_atomic_max(const typename vector_type_maker< T, N >::type::type src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition: amd_buffer_addressing.hpp:951
__device__ void amd_buffer_store_impl(const typename vector_type< T, N >::type src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset)
Definition: amd_buffer_addressing.hpp:521
AmdBufferCoherenceEnum
Definition: amd_buffer_coherence.hpp:9
__device__ int32x4_t make_wave_buffer_resource(T *p_wave, index_t element_space_size)
Definition: amd_buffer_addressing.hpp:24
typename vector_type< int32_t, 2 >::type int32x2_t
Definition: dtype_vector.hpp:2168
__device__ void llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, uint32_t *lds_ptr, index_t size, index_t voffset, index_t soffset, index_t offset, index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds")
__device__ void amd_buffer_atomic_add_impl(const typename vector_type< T, N >::type src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset)
Definition: amd_buffer_addressing.hpp:576
__device__ vector_type_maker< T, N >::type::type amd_buffer_load_invalid_element_return_customized_value(const T *p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, index_t src_element_space_size, T customized_value)
Definition: amd_buffer_addressing.hpp:842
__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(float vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32")
__device__ void amd_global_atomic_add_impl(const typename vector_type< T, N >::type src_thread_data, T *addr)
Definition: amd_buffer_addressing.hpp:548
__device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(half2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16")
__device__ vector_type< T, N >::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset)
Definition: amd_buffer_addressing.hpp:396
__device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_new(T *p_wave, index_t element_space_size)
Definition: amd_buffer_addressing_builtins.hpp:54
__device__ void amd_buffer_atomic_add(const typename vector_type_maker< T, N >::type::type src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition: amd_buffer_addressing.hpp:905
__device__ double llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int32x4_t rsrc, int voffset, int soffset, int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64")
typename vector_type< half_t, 2 >::type half2_t
Definition: dtype_vector.hpp:2154
__device__ vector_type_maker< T, N >::type::type amd_buffer_load_invalid_element_return_zero(const T *p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, index_t src_element_space_size)
Definition: amd_buffer_addressing.hpp:806
__device__ void amd_buffer_atomic_max_impl(const typename vector_type< T, N >::type src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset)
Definition: amd_buffer_addressing.hpp:734
__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(int32_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32")
typename vector_type< int32_t, 4 >::type int32x4_t
Definition: dtype_vector.hpp:2169
int32_t index_t
Definition: ck.hpp:301
__device__ void amd_buffer_store_impl_raw(const typename vector_type< int8_t, N >::type src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset)
Definition: amd_buffer_addressing.hpp:423
__device__ vector_type< int8_t, N >::type amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset)
Definition: amd_buffer_addressing.hpp:292
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:295
__device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_with_default_range_new(T *p_wave)
Definition: amd_buffer_addressing_builtins.hpp:67
signed short int16_t
Definition: stdint.h:122
_W64 unsigned int uintptr_t
Definition: stdint.h:164
unsigned int uint32_t
Definition: stdint.h:126
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: integral_constant.hpp:20
static constexpr bool value
Definition: integral_constant.hpp:21
Definition: type.hpp:177
Definition: dtype_vector.hpp:11
int32x4_t content
Definition: amd_buffer_addressing.hpp:17
StaticallyIndexedArray< int32_t, 4 > config
Definition: amd_buffer_addressing.hpp:20
constexpr __device__ BufferResource()
Definition: amd_buffer_addressing_builtins.hpp:13
StaticallyIndexedArray< int32_t, 4 > range
Definition: amd_buffer_addressing.hpp:19
StaticallyIndexedArray< T *, 2 > address
Definition: amd_buffer_addressing.hpp:18