/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_inline_asm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_inline_asm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_inline_asm.hpp Source File
amd_inline_asm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #ifndef CK_AMD_INLINE_ASM_HPP
5 #define CK_AMD_INLINE_ASM_HPP
6 
8 #include "dtype_vector.hpp"
9 
10 // TODO: deprecate all amd_assembly_outer_product_xxx
11 
12 namespace ck {
13 
14 inline __device__ int amd_assembly_and_b32(int a, int b)
15 {
16  int c;
17  asm volatile("v_and_b32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
18  return c;
19 }
20 
21 inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
22 {
23  int c;
24  asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(c) : "v"(a), "v"(b), "v"(d));
25  return c;
26 }
27 
29 {
30  half2_t d;
31  asm volatile("v_pk_fma_f16 %0, %1, %2, %3" : "=v"(d) : "v"(a), "v"(b), "v"(c));
32  return d;
33 }
34 
36 {
37  half2_t c;
38  asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
39  return c;
40 }
41 
42 inline __device__ float amd_assemble_cvt_f32_i4(int b)
43 {
44  float a;
45  asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(a) : "v"(b));
46  return a;
47 }
48 
49 inline __device__ f8x4_t amd_assembly_cvt_f8_to_f32(float b0, float b1, float b2, float b3)
50 {
51  f8x4_t a;
52  asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2\n"
53  "v_cvt_pk_fp8_f32 %0, %3, %4, op_sel:[0, 0, 1]\n"
54  : "=v"(a)
55  : "v"(b0), "v"(b1), "v"(b2), "v"(b3));
56  return a;
57 }
58 
59 inline __device__ f8x8_t amd_assembly_i4_to_fp8x8(int a)
60 {
61  uint32_t i4x8 = static_cast<uint32_t>(a);
62  uint32_t fp8x4_0;
63  uint32_t fp8x4_1;
64  float tmp_0, tmp_1, tmp_2;
65 
66  asm volatile("v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
67  "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_2\n"
68  "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1]\n"
69  "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
70  "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_3\n"
71  "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1]\n"
72  "v_lshrrev_b32 %[v_tmp_2], 4, %[v_src]\n"
73  "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2]\n"
74  "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_2\n"
75  "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n"
76  "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2], src0_sel:BYTE_1\n"
77  "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_3\n"
78  "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n"
79  : [v_tmp_0] "+v"(tmp_0),
80  [v_tmp_1] "+v"(tmp_1),
81  [v_tmp_2] "+v"(tmp_2),
82  [v_dst_0] "+v"(fp8x4_0),
83  [v_dst_1] "+v"(fp8x4_1),
84  [v_src] "+v"(i4x8)
85  :);
86 
87  return bit_cast<f8x8_t>(((static_cast<uint64_t>(fp8x4_1) << 32) | fp8x4_0));
88 }
89 
90 // c0 += inner_product(a, b0)
91 // c1 += inner_product(a, b1)
92 __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
93 {
94  asm volatile("\n \
95  v_fmac_f32 %0, %2, %3 \n \
96  v_fmac_f32 %1, %2, %4 \n \
97  "
98  : "=v"(c0), "=v"(c1)
99  : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
100 }
101 
102 // c0 += inner_product(a, b0)
103 // c1 += inner_product(a, b1)
104 // c2 += inner_product(a, b2)
105 // c3 += inner_product(a, b3)
107  float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
108 {
109  asm volatile("\n \
110  v_fmac_f32 %0, %4, %5 \n \
111  v_fmac_f32 %1, %4, %6 \n \
112  v_fmac_f32 %2, %4, %7 \n \
113  v_fmac_f32 %3, %4, %8 \n \
114  "
115  : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
116  : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
117 }
118 
119 // c0 += inner_product(a, b0)
120 // c1 += inner_product(a, b1)
121 __device__ void
123 {
124  asm volatile("\n \
125  v_dot2_f32_f16 %0, %2, %3, %0\n \
126  v_dot2_f32_f16 %1, %2, %4, %1\n \
127  "
128  : "=v"(c0), "=v"(c1)
129  : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
130 }
131 
132 // c0 += inner_product(a, b0)
133 // c1 += inner_product(a, b1)
134 __device__ void
135 amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
136 {
137  // TODO remove pointer casting
138  const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
139  const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
140  const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
141 
142  // do dot2 two times
143  asm volatile("\n \
144  v_dot2_f32_f16 %0, %2, %4, %0\n \
145  v_dot2_f32_f16 %1, %2, %6, %1\n \
146  v_dot2_f32_f16 %0, %3, %5, %0\n \
147  v_dot2_f32_f16 %1, %3, %7, %1\n \
148  "
149  : "=v"(c0), "=v"(c1)
150  : "v"(p_a_half2[0]),
151  "v"(p_a_half2[1]),
152  "v"(p_b0_half2[0]),
153  "v"(p_b0_half2[1]),
154  "v"(p_b1_half2[0]),
155  "v"(p_b1_half2[1]),
156  "0"(c0),
157  "1"(c1));
158 }
159 
160 // c0 += inner_product(a, b0)
161 // c1 += inner_product(a, b1)
162 // c2 += inner_product(a, b2)
163 // c3 += inner_product(a, b3)
165  half2_t b0,
166  half2_t b1,
167  half2_t b2,
168  half2_t b3,
169  float& c0,
170  float& c1,
171  float& c2,
172  float& c3)
173 {
174  asm volatile("\n \
175  v_dot2_f32_f16 %0, %4, %5, %0\n \
176  v_dot2_f32_f16 %1, %4, %6, %1\n \
177  v_dot2_f32_f16 %2, %4, %7, %2\n \
178  v_dot2_f32_f16 %3, %4, %8, %3\n \
179  "
180  : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
181  : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
182 }
183 
184 // c0 += inner_product(a, b0)
185 // c1 += inner_product(a, b1)
186 // c2 += inner_product(a, b2)
187 // c3 += inner_product(a, b3)
189  half4_t b0,
190  half4_t b1,
191  half4_t b2,
192  half4_t b3,
193  float& c0,
194  float& c1,
195  float& c2,
196  float& c3)
197 {
198  // TODO remove pointer casting
199  const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
200  const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
201  const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
202  const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
203  const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);
204 
205  // do dot2 two times
206  asm volatile("\n \
207  v_dot2_f32_f16 %0, %4, %6, %0\n \
208  v_dot2_f32_f16 %1, %4, %8, %1\n \
209  v_dot2_f32_f16 %2, %4, %10, %2\n \
210  v_dot2_f32_f16 %3, %4, %12, %3\n \
211  v_dot2_f32_f16 %0, %5, %7, %0\n \
212  v_dot2_f32_f16 %1, %5, %9, %1\n \
213  v_dot2_f32_f16 %2, %5, %11, %2\n \
214  v_dot2_f32_f16 %3, %5, %13, %3\n \
215  "
216  : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
217  : "v"(p_a_half2[0]),
218  "v"(p_a_half2[1]),
219  "v"(p_b0_half2[0]),
220  "v"(p_b0_half2[1]),
221  "v"(p_b1_half2[0]),
222  "v"(p_b1_half2[1]),
223  "v"(p_b2_half2[0]),
224  "v"(p_b2_half2[1]),
225  "v"(p_b3_half2[0]),
226  "v"(p_b3_half2[1]),
227  "0"(c0),
228  "1"(c1),
229  "2"(c2),
230  "3"(c3));
231 }
232 
234  half8_t b0,
235  half8_t b1,
236  half8_t b2,
237  half8_t b3,
238  float& c0,
239  float& c1,
240  float& c2,
241  float& c3)
242 {
243 
244  // TODO remove pointer casting
245  const half4_t* p_a_half4 = c_style_pointer_cast<const half4_t*>(&a);
246  const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
247  const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
248  const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
249  const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);
250 
252  p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
253 
255  p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
256 }
257 
259  half16_t b0,
260  half16_t b1,
261  half16_t b2,
262  half16_t b3,
263  float& c0,
264  float& c1,
265  float& c2,
266  float& c3)
267 {
268  // TODO remove pointer casting
269  const half8_t* p_a_half8 = c_style_pointer_cast<const half8_t*>(&a);
270  const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
271  const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
272  const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
273  const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);
274 
276  p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
277 
279  p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3);
280 }
281 
282 // c0 += inner_product(a, b0)
283 // c1 += inner_product(a, b1)
284 __device__ void
286 {
287 #if 1
288  asm volatile("\n \
289  v_dot4_i32_i8 %0, %2, %3, %0\n \
290  v_dot4_i32_i8 %1, %2, %4, %1\n \
291  "
292  : "=v"(c0), "=v"(c1)
293  : "v"(bit_cast<int32_t>(a)),
294  "v"(bit_cast<int32_t>(b0)),
295  "v"(bit_cast<int32_t>(b1)),
296  "0"(c0),
297  "1"(c1));
298 #else
299  c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
300  c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
301 #endif
302 }
303 
304 // c0 += inner_product(a, b0)
305 // c1 += inner_product(a, b1)
306 // c2 += inner_product(a, b2)
307 // c3 += inner_product(a, b3)
309  int8x4_t b0,
310  int8x4_t b1,
311  int8x4_t b2,
312  int8x4_t b3,
313  int32_t& c0,
314  int32_t& c1,
315  int32_t& c2,
316  int32_t& c3)
317 {
318 #if 1
319  asm volatile("\n \
320  v_dot4_i32_i8 %0, %4, %5, %0\n \
321  v_dot4_i32_i8 %1, %4, %6, %1\n \
322  v_dot4_i32_i8 %2, %4, %7, %2\n \
323  v_dot4_i32_i8 %3, %4, %8, %3\n \
324  "
325  : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
326  : "v"(bit_cast<int32_t>(a)),
327  "v"(bit_cast<int32_t>(b0)),
328  "v"(bit_cast<int32_t>(b1)),
329  "v"(bit_cast<int32_t>(b2)),
330  "v"(bit_cast<int32_t>(b3)),
331  "0"(c0),
332  "1"(c1),
333  "2"(c2),
334  "3"(c3));
335 #else
336  c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
337  c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
338  c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
339  c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
340 #endif
341 }
342 
344  int8x8_t b0,
345  int8x8_t b1,
346  int8x8_t b2,
347  int8x8_t b3,
348  int32_t& c0,
349  int32_t& c1,
350  int32_t& c2,
351  int32_t& c3)
352 {
353  constexpr auto I0 = Number<0>{};
354  constexpr auto I1 = Number<1>{};
355 
356  amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
357  vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
358  vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
359  vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
360  vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
361  c0,
362  c1,
363  c2,
364  c3);
365 
366  amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
367  vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
368  vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
369  vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
370  vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
371  c0,
372  c1,
373  c2,
374  c3);
375 }
376 
378  int8x16_t b0,
379  int8x16_t b1,
380  int8x16_t b2,
381  int8x16_t b3,
382  int32_t& c0,
383  int32_t& c1,
384  int32_t& c2,
385  int32_t& c3)
386 
387 {
388  constexpr auto I0 = Number<0>{};
389  constexpr auto I1 = Number<1>{};
390  constexpr auto I2 = Number<2>{};
391  constexpr auto I3 = Number<3>{};
392 
393  amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
394  vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
395  vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
396  vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
397  vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
398  c0,
399  c1,
400  c2,
401  c3);
402 
403  amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
404  vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
405  vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
406  vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
407  vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
408  c0,
409  c1,
410  c2,
411  c3);
412 
413  amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
414  vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
415  vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
416  vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
417  vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
418  c0,
419  c1,
420  c2,
421  c3);
422 
423  amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
424  vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
425  vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
426  vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
427  vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
428  c0,
429  c1,
430  c2,
431  c3);
432 }
433 
434 } // namespace ck
435 #endif
Definition: ck.hpp:267
__device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
Definition: amd_inline_asm.hpp:35
__device__ f8x8_t amd_assembly_i4_to_fp8x8(int a)
Definition: amd_inline_asm.hpp:59
__device__ void amd_assembly_outer_product_1x4(float a, float b0, float b1, float b2, float b3, float &c0, float &c1, float &c2, float &c3)
Definition: amd_inline_asm.hpp:106
__device__ f8x4_t amd_assembly_cvt_f8_to_f32(float b0, float b1, float b2, float b3)
Definition: amd_inline_asm.hpp:49
__device__ int amd_assembly_and_b32(int a, int b)
Definition: amd_inline_asm.hpp:14
__device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
Definition: amd_inline_asm.hpp:28
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: dtype_vector.hpp:2164
typename vector_type< half_t, 4 >::type half4_t
Definition: dtype_vector.hpp:2140
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float &c0, float &c1)
Definition: amd_inline_asm.hpp:92
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: dtype_vector.hpp:2165
typename vector_type< half_t, 2 >::type half2_t
Definition: dtype_vector.hpp:2139
__device__ int amd_assembly_and_or_b32(int a, int b, int d)
Definition: amd_inline_asm.hpp:21
typename vector_type< int8_t, 4 >::type int8x4_t
Definition: dtype_vector.hpp:2163
typename vector_type< half_t, 16 >::type half16_t
Definition: dtype_vector.hpp:2142
__device__ float amd_assemble_cvt_f32_i4(int b)
Definition: amd_inline_asm.hpp:42
typename vector_type< half_t, 8 >::type half8_t
Definition: dtype_vector.hpp:2141
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
unsigned int uint32_t
Definition: stdint.h:126
signed int int32_t
Definition: stdint.h:123
unsigned __int64 uint64_t
Definition: stdint.h:136