include/ck/utility/amd_inline_asm.hpp Source File

include/ck/utility/amd_inline_asm.hpp Source File

Composable Kernel: include/ck/utility/amd_inline_asm.hpp Source File
amd_inline_asm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #ifndef CK_AMD_INLINE_ASM_HPP
5 #define CK_AMD_INLINE_ASM_HPP
6 
8 #include "data_type.hpp"
9 
10 // TODO: deprecate all amd_assembly_outer_product_xxx
11 
12 namespace ck {
13 
14 inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
15 {
16  int c;
17  asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(c) : "v"(a), "v"(b), "v"(d));
18  return c;
19 }
20 
22 {
23  half2_t d;
24  asm volatile("v_pk_fma_f16 %0, %1, %2, %3" : "=v"(d) : "v"(a), "v"(b), "v"(c));
25  return d;
26 }
27 
29 {
30  half2_t c;
31  asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
32  return c;
33 }
34 
35 // c0 += inner_product(a, b0)
36 // c1 += inner_product(a, b1)
37 __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
38 {
39  asm volatile("\n \
40  v_fmac_f32 %0, %2, %3 \n \
41  v_fmac_f32 %1, %2, %4 \n \
42  "
43  : "=v"(c0), "=v"(c1)
44  : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
45 }
46 
47 // c0 += inner_product(a, b0)
48 // c1 += inner_product(a, b1)
49 // c2 += inner_product(a, b2)
50 // c3 += inner_product(a, b3)
52  float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
53 {
54  asm volatile("\n \
55  v_fmac_f32 %0, %4, %5 \n \
56  v_fmac_f32 %1, %4, %6 \n \
57  v_fmac_f32 %2, %4, %7 \n \
58  v_fmac_f32 %3, %4, %8 \n \
59  "
60  : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
61  : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
62 }
63 
64 // c0 += inner_product(a, b0)
65 // c1 += inner_product(a, b1)
66 __device__ void
67 amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
68 {
69  asm volatile("\n \
70  v_dot2_f32_f16 %0, %2, %3, %0\n \
71  v_dot2_f32_f16 %1, %2, %4, %1\n \
72  "
73  : "=v"(c0), "=v"(c1)
74  : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
75 }
76 
77 // c0 += inner_product(a, b0)
78 // c1 += inner_product(a, b1)
79 __device__ void
80 amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
81 {
82  // TODO remove pointer casting
83  const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
84  const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
85  const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
86 
87  // do dot2 two times
88  asm volatile("\n \
89  v_dot2_f32_f16 %0, %2, %4, %0\n \
90  v_dot2_f32_f16 %1, %2, %6, %1\n \
91  v_dot2_f32_f16 %0, %3, %5, %0\n \
92  v_dot2_f32_f16 %1, %3, %7, %1\n \
93  "
94  : "=v"(c0), "=v"(c1)
95  : "v"(p_a_half2[0]),
96  "v"(p_a_half2[1]),
97  "v"(p_b0_half2[0]),
98  "v"(p_b0_half2[1]),
99  "v"(p_b1_half2[0]),
100  "v"(p_b1_half2[1]),
101  "0"(c0),
102  "1"(c1));
103 }
104 
105 // c0 += inner_product(a, b0)
106 // c1 += inner_product(a, b1)
107 // c2 += inner_product(a, b2)
108 // c3 += inner_product(a, b3)
110  half2_t b0,
111  half2_t b1,
112  half2_t b2,
113  half2_t b3,
114  float& c0,
115  float& c1,
116  float& c2,
117  float& c3)
118 {
119  asm volatile("\n \
120  v_dot2_f32_f16 %0, %4, %5, %0\n \
121  v_dot2_f32_f16 %1, %4, %6, %1\n \
122  v_dot2_f32_f16 %2, %4, %7, %2\n \
123  v_dot2_f32_f16 %3, %4, %8, %3\n \
124  "
125  : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
126  : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
127 }
128 
129 // c0 += inner_product(a, b0)
130 // c1 += inner_product(a, b1)
131 // c2 += inner_product(a, b2)
132 // c3 += inner_product(a, b3)
133 __device__ void amd_assembly_outer_product_1x4(half4_t a,
134  half4_t b0,
135  half4_t b1,
136  half4_t b2,
137  half4_t b3,
138  float& c0,
139  float& c1,
140  float& c2,
141  float& c3)
142 {
143  // TODO remove pointer casting
144  const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
145  const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
146  const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
147  const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
148  const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);
149 
150  // do dot2 two times
151  asm volatile("\n \
152  v_dot2_f32_f16 %0, %4, %6, %0\n \
153  v_dot2_f32_f16 %1, %4, %8, %1\n \
154  v_dot2_f32_f16 %2, %4, %10, %2\n \
155  v_dot2_f32_f16 %3, %4, %12, %3\n \
156  v_dot2_f32_f16 %0, %5, %7, %0\n \
157  v_dot2_f32_f16 %1, %5, %9, %1\n \
158  v_dot2_f32_f16 %2, %5, %11, %2\n \
159  v_dot2_f32_f16 %3, %5, %13, %3\n \
160  "
161  : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
162  : "v"(p_a_half2[0]),
163  "v"(p_a_half2[1]),
164  "v"(p_b0_half2[0]),
165  "v"(p_b0_half2[1]),
166  "v"(p_b1_half2[0]),
167  "v"(p_b1_half2[1]),
168  "v"(p_b2_half2[0]),
169  "v"(p_b2_half2[1]),
170  "v"(p_b3_half2[0]),
171  "v"(p_b3_half2[1]),
172  "0"(c0),
173  "1"(c1),
174  "2"(c2),
175  "3"(c3));
176 }
177 
178 __device__ void amd_assembly_outer_product_1x4(half8_t a,
179  half8_t b0,
180  half8_t b1,
181  half8_t b2,
182  half8_t b3,
183  float& c0,
184  float& c1,
185  float& c2,
186  float& c3)
187 {
188 
189  // TODO remove pointer casting
190  const half4_t* p_a_half4 = c_style_pointer_cast<const half4_t*>(&a);
191  const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
192  const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
193  const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
194  const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);
195 
197  p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
198 
200  p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
201 }
202 
203 __device__ void amd_assembly_outer_product_1x4(half16_t a,
204  half16_t b0,
205  half16_t b1,
206  half16_t b2,
207  half16_t b3,
208  float& c0,
209  float& c1,
210  float& c2,
211  float& c3)
212 {
213  // TODO remove pointer casting
214  const half8_t* p_a_half8 = c_style_pointer_cast<const half8_t*>(&a);
215  const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
216  const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
217  const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
218  const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);
219 
221  p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
222 
224  p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3);
225 }
226 
227 // c0 += inner_product(a, b0)
228 // c1 += inner_product(a, b1)
229 __device__ void
230 amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
231 {
232 #if 1
233  asm volatile("\n \
234  v_dot4_i32_i8 %0, %2, %3, %0\n \
235  v_dot4_i32_i8 %1, %2, %4, %1\n \
236  "
237  : "=v"(c0), "=v"(c1)
238  : "v"(bit_cast<int32_t>(a)),
239  "v"(bit_cast<int32_t>(b0)),
240  "v"(bit_cast<int32_t>(b1)),
241  "0"(c0),
242  "1"(c1));
243 #else
244  c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
245  c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
246 #endif
247 }
248 
249 // c0 += inner_product(a, b0)
250 // c1 += inner_product(a, b1)
251 // c2 += inner_product(a, b2)
252 // c3 += inner_product(a, b3)
254  int8x4_t b0,
255  int8x4_t b1,
256  int8x4_t b2,
257  int8x4_t b3,
258  int32_t& c0,
259  int32_t& c1,
260  int32_t& c2,
261  int32_t& c3)
262 {
263 #if 1
264  asm volatile("\n \
265  v_dot4_i32_i8 %0, %4, %5, %0\n \
266  v_dot4_i32_i8 %1, %4, %6, %1\n \
267  v_dot4_i32_i8 %2, %4, %7, %2\n \
268  v_dot4_i32_i8 %3, %4, %8, %3\n \
269  "
270  : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
271  : "v"(bit_cast<int32_t>(a)),
272  "v"(bit_cast<int32_t>(b0)),
273  "v"(bit_cast<int32_t>(b1)),
274  "v"(bit_cast<int32_t>(b2)),
275  "v"(bit_cast<int32_t>(b3)),
276  "0"(c0),
277  "1"(c1),
278  "2"(c2),
279  "3"(c3));
280 #else
281  c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
282  c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
283  c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
284  c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
285 #endif
286 }
287 
288 __device__ void amd_assembly_outer_product_1x4(int8x8_t a,
289  int8x8_t b0,
290  int8x8_t b1,
291  int8x8_t b2,
292  int8x8_t b3,
293  int32_t& c0,
294  int32_t& c1,
295  int32_t& c2,
296  int32_t& c3)
297 {
298  constexpr auto I0 = Number<0>{};
299  constexpr auto I1 = Number<1>{};
300 
301  amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
302  vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
303  vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
304  vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
305  vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
306  c0,
307  c1,
308  c2,
309  c3);
310 
311  amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
312  vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
313  vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
314  vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
315  vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
316  c0,
317  c1,
318  c2,
319  c3);
320 }
321 
322 __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
323  int8x16_t b0,
324  int8x16_t b1,
325  int8x16_t b2,
326  int8x16_t b3,
327  int32_t& c0,
328  int32_t& c1,
329  int32_t& c2,
330  int32_t& c3)
331 
332 {
333  constexpr auto I0 = Number<0>{};
334  constexpr auto I1 = Number<1>{};
335  constexpr auto I2 = Number<2>{};
336  constexpr auto I3 = Number<3>{};
337 
338  amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
339  vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
340  vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
341  vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
342  vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
343  c0,
344  c1,
345  c2,
346  c3);
347 
348  amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
349  vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
350  vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
351  vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
352  vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
353  c0,
354  c1,
355  c2,
356  c3);
357 
358  amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
359  vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
360  vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
361  vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
362  vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
363  c0,
364  c1,
365  c2,
366  c3);
367 
368  amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
369  vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
370  vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
371  vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
372  vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
373  c0,
374  c1,
375  c2,
376  c3);
377 }
378 
379 } // namespace ck
380 #endif
Definition: ck.hpp:264
__device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
Definition: amd_inline_asm.hpp:28
__device__ void amd_assembly_outer_product_1x4(float a, float b0, float b1, float b2, float b3, float &c0, float &c1, float &c2, float &c3)
Definition: amd_inline_asm.hpp:51
__device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
Definition: amd_inline_asm.hpp:21
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: data_type.hpp:2515
typename vector_type< half_t, 4 >::type half4_t
Definition: data_type.hpp:2490
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float &c0, float &c1)
Definition: amd_inline_asm.hpp:37
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: data_type.hpp:2516
typename vector_type< half_t, 2 >::type half2_t
Definition: data_type.hpp:2489
__device__ int amd_assembly_and_or_b32(int a, int b, int d)
Definition: amd_inline_asm.hpp:14
typename vector_type< int8_t, 4 >::type int8x4_t
Definition: data_type.hpp:2514
typename vector_type< half_t, 16 >::type half16_t
Definition: data_type.hpp:2492
typename vector_type< half_t, 8 >::type half8_t
Definition: data_type.hpp:2491