4 #ifndef CK_AMD_INLINE_ASM_HPP
5 #define CK_AMD_INLINE_ASM_HPP
// Inline assembly that emits a single v_and_or_b32 instruction.
// Operand binding: %0 = c (output, "=v"), %1 = a, %2 = b, %3 = d (inputs, "v").
// The "v" constraint letter requests a vector register for each operand
// (AMDGPU inline-asm constraint; see the LLVM AMDGPU backend documentation).
// NOTE(review): per the AMD ISA description of V_AND_OR_B32 the result written
// to c is expected to be (a & b) | d - confirm against the ISA guide in use.
17 asm volatile(
"v_and_or_b32 %0, %1, %2, %3" :
"=v"(c) :
"v"(a),
"v"(b),
"v"(d));
// Inline assembly that emits a single v_pk_fma_f16 instruction.
// Operand binding: %0 = d (output, "=v"), %1 = a, %2 = b, %3 = c (inputs, "v").
// Note that the output variable here is d and the third input is c.
// NOTE(review): V_PK_FMA_F16 is the packed half-precision fused multiply-add,
// so d is expected to receive a * b + c in each 16-bit lane - confirm against
// the AMD ISA guide for the targeted architecture.
24 asm volatile(
"v_pk_fma_f16 %0, %1, %2, %3" :
"=v"(d) :
"v"(a),
"v"(b),
"v"(c));
// Inline assembly that emits a single v_pk_add_f16 instruction.
// Operand binding: %0 = c (output, "=v"), %1 = a, %2 = b (inputs, "v").
// NOTE(review): V_PK_ADD_F16 is the packed half-precision add, so c is expected
// to receive a + b in each 16-bit lane - confirm against the AMD ISA guide.
31 asm volatile(
"v_pk_add_f16 %0, %1, %2" :
"=v"(c) :
"v"(a),
"v"(b));
40 v_fmac_f32 %0, %2, %3 \n \
41 v_fmac_f32 %1, %2, %4 \n \
44 :
"v"(a),
"v"(b0),
"v"(b1),
"0"(c0),
"1"(c1));
52 float a,
float b0,
float b1,
float b2,
float b3,
float& c0,
float& c1,
float& c2,
float& c3)
55 v_fmac_f32 %0, %4, %5 \n \
56 v_fmac_f32 %1, %4, %6 \n \
57 v_fmac_f32 %2, %4, %7 \n \
58 v_fmac_f32 %3, %4, %8 \n \
60 :
"=v"(c0),
"=v"(c1),
"=v"(c2),
"=v"(c3)
61 :
"v"(a),
"v"(b0),
"v"(b1),
"v"(b2),
"v"(b3),
"0"(c0),
"1"(c1),
"2"(c2),
"3"(c3));
70 v_dot2_f32_f16 %0, %2, %3, %0\n \
71 v_dot2_f32_f16 %1, %2, %4, %1\n \
74 :
"v"(a),
"v"(b0),
"v"(b1),
"0"(c0),
"1"(c1));
// Take the addresses of a, b0 and b1 and view them through const half2_t
// pointers, so that p_x_half2[i] names the i-th half2_t-sized piece of x.
// NOTE(review): c_style_pointer_cast is defined elsewhere; presumably it is a
// plain reinterpreting pointer cast - confirm in its header.
83 const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
84 const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
85 const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
89 v_dot2_f32_f16 %0, %2, %4, %0\n \
90 v_dot2_f32_f16 %1, %2, %6, %1\n \
91 v_dot2_f32_f16 %0, %3, %5, %0\n \
92 v_dot2_f32_f16 %1, %3, %7, %1\n \
120 v_dot2_f32_f16 %0, %4, %5, %0\n \
121 v_dot2_f32_f16 %1, %4, %6, %1\n \
122 v_dot2_f32_f16 %2, %4, %7, %2\n \
123 v_dot2_f32_f16 %3, %4, %8, %3\n \
125 :
"=v"(c0),
"=v"(c1),
"=v"(c2),
"=v"(c3)
126 :
"v"(a),
"v"(b0),
"v"(b1),
"v"(b2),
"v"(b3),
"0"(c0),
"1"(c1),
"2"(c2),
"3"(c3));
// Take the addresses of a and b0..b3 and view them through const half2_t
// pointers, so that p_x_half2[i] names the i-th half2_t-sized piece of x.
// NOTE(review): c_style_pointer_cast is defined elsewhere; presumably it is a
// plain reinterpreting pointer cast - confirm in its header.
144 const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
145 const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
146 const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
147 const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
148 const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);
152 v_dot2_f32_f16 %0, %4, %6, %0\n \
153 v_dot2_f32_f16 %1, %4, %8, %1\n \
154 v_dot2_f32_f16 %2, %4, %10, %2\n \
155 v_dot2_f32_f16 %3, %4, %12, %3\n \
156 v_dot2_f32_f16 %0, %5, %7, %0\n \
157 v_dot2_f32_f16 %1, %5, %9, %1\n \
158 v_dot2_f32_f16 %2, %5, %11, %2\n \
159 v_dot2_f32_f16 %3, %5, %13, %3\n \
161 :
"=v"(c0),
"=v"(c1),
"=v"(c2),
"=v"(c3)
// View a and b0..b3 through const half4_t pointers. The statements that follow
// pass element [0] and then element [1] of every pointer, together with
// c0..c3, so the operands are consumed one half4_t-sized piece at a time.
// NOTE(review): c_style_pointer_cast is defined elsewhere; presumably it is a
// plain reinterpreting pointer cast - confirm in its header.
190 const half4_t* p_a_half4 = c_style_pointer_cast<const half4_t*>(&a);
191 const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
192 const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
193 const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
194 const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);
197 p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
200 p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
// View a and b0..b3 through const half8_t pointers. The statements that follow
// pass element [0] and then element [1] of every pointer, together with
// c0..c3, so the operands are consumed one half8_t-sized piece at a time.
// NOTE(review): c_style_pointer_cast is defined elsewhere; presumably it is a
// plain reinterpreting pointer cast - confirm in its header.
214 const half8_t* p_a_half8 = c_style_pointer_cast<const half8_t*>(&a);
215 const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
216 const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
217 const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
218 const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);
221 p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
224 p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3);
234 v_dot4_i32_i8 %0, %2, %3, %0\n \
235 v_dot4_i32_i8 %1, %2, %4, %1\n \
238 :
"v"(bit_cast<int32_t>(a)),
239 "v"(bit_cast<int32_t>(b0)),
240 "v"(bit_cast<int32_t>(b1)),
// Compiler-builtin form of the two accumulations: each statement passes the
// current accumulator as the third argument and assigns the result back to it.
// a, b0 and b1 are reinterpreted as int32_t with bit_cast before the call.
// NOTE(review): per the clang AMDGPU builtin documentation, the last argument
// of __builtin_amdgcn_sdot4 is the clamp flag (false here) - confirm.
244 c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0,
false);
245 c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1,
false);
265 v_dot4_i32_i8 %0, %4, %5, %0\n \
266 v_dot4_i32_i8 %1, %4, %6, %1\n \
267 v_dot4_i32_i8 %2, %4, %7, %2\n \
268 v_dot4_i32_i8 %3, %4, %8, %3\n \
270 :
"=v"(c0),
"=v"(c1),
"=v"(c2),
"=v"(c3)
271 :
"v"(bit_cast<int32_t>(a)),
272 "v"(bit_cast<int32_t>(b0)),
273 "v"(bit_cast<int32_t>(b1)),
274 "v"(bit_cast<int32_t>(b2)),
275 "v"(bit_cast<int32_t>(b3)),
// Compiler-builtin form of the four accumulations: each statement passes the
// current accumulator as the third argument and assigns the result back to it.
// The same a is paired with b0, b1, b2 and b3 in turn; all of them are
// reinterpreted as int32_t with bit_cast before the call.
// NOTE(review): per the clang AMDGPU builtin documentation, the last argument
// of __builtin_amdgcn_sdot4 is the clamp flag (false here) - confirm.
281 c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0,
false);
282 c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1,
false);
283 c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2,
false);
284 c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3,
false);
// Compile-time index objects. The code that follows uses them as subscripts
// ([I0], [I1]) on vector_type<int8_t, 8>{...}.AsType<int8x4_t>() to pick the
// first and second int8x4_t piece of each 8-element operand.
// NOTE(review): Number<N> is declared elsewhere; presumably an
// integral-constant-style tag type - confirm.
298 constexpr
auto I0 = Number<0>{};
299 constexpr
auto I1 = Number<1>{};
302 vector_type<int8_t, 8>{b0}.AsType<
int8x4_t>()[I0],
303 vector_type<int8_t, 8>{b1}.AsType<
int8x4_t>()[I0],
304 vector_type<int8_t, 8>{b2}.AsType<
int8x4_t>()[I0],
305 vector_type<int8_t, 8>{b3}.AsType<
int8x4_t>()[I0],
312 vector_type<int8_t, 8>{b0}.AsType<
int8x4_t>()[I1],
313 vector_type<int8_t, 8>{b1}.AsType<
int8x4_t>()[I1],
314 vector_type<int8_t, 8>{b2}.AsType<
int8x4_t>()[I1],
315 vector_type<int8_t, 8>{b3}.AsType<
int8x4_t>()[I1],
// Compile-time index objects. The code that follows uses them as subscripts
// ([I0] .. [I3]) on vector_type<int8_t, 16>{...}.AsType<int8x4_t>() to pick
// each of the four int8x4_t pieces of a 16-element operand in turn.
// NOTE(review): Number<N> is declared elsewhere; presumably an
// integral-constant-style tag type - confirm.
333 constexpr
auto I0 = Number<0>{};
334 constexpr
auto I1 = Number<1>{};
335 constexpr
auto I2 = Number<2>{};
336 constexpr
auto I3 = Number<3>{};
339 vector_type<int8_t, 16>{b0}.AsType<
int8x4_t>()[I0],
340 vector_type<int8_t, 16>{b1}.AsType<
int8x4_t>()[I0],
341 vector_type<int8_t, 16>{b2}.AsType<
int8x4_t>()[I0],
342 vector_type<int8_t, 16>{b3}.AsType<
int8x4_t>()[I0],
349 vector_type<int8_t, 16>{b0}.AsType<
int8x4_t>()[I1],
350 vector_type<int8_t, 16>{b1}.AsType<
int8x4_t>()[I1],
351 vector_type<int8_t, 16>{b2}.AsType<
int8x4_t>()[I1],
352 vector_type<int8_t, 16>{b3}.AsType<
int8x4_t>()[I1],
359 vector_type<int8_t, 16>{b0}.AsType<
int8x4_t>()[I2],
360 vector_type<int8_t, 16>{b1}.AsType<
int8x4_t>()[I2],
361 vector_type<int8_t, 16>{b2}.AsType<
int8x4_t>()[I2],
362 vector_type<int8_t, 16>{b3}.AsType<
int8x4_t>()[I2],
369 vector_type<int8_t, 16>{b0}.AsType<
int8x4_t>()[I3],
370 vector_type<int8_t, 16>{b1}.AsType<
int8x4_t>()[I3],
371 vector_type<int8_t, 16>{b2}.AsType<
int8x4_t>()[I3],
372 vector_type<int8_t, 16>{b3}.AsType<
int8x4_t>()[I3],
__device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
Definition: amd_inline_asm.hpp:28
__device__ void amd_assembly_outer_product_1x4(float a, float b0, float b1, float b2, float b3, float &c0, float &c1, float &c2, float &c3)
Definition: amd_inline_asm.hpp:51
__device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
Definition: amd_inline_asm.hpp:21
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: data_type.hpp:2515
typename vector_type< half_t, 4 >::type half4_t
Definition: data_type.hpp:2490
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float &c0, float &c1)
Definition: amd_inline_asm.hpp:37
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: data_type.hpp:2516
typename vector_type< half_t, 2 >::type half2_t
Definition: data_type.hpp:2489
__device__ int amd_assembly_and_or_b32(int a, int b, int d)
Definition: amd_inline_asm.hpp:14
typename vector_type< int8_t, 4 >::type int8x4_t
Definition: data_type.hpp:2514
typename vector_type< half_t, 16 >::type half16_t
Definition: data_type.hpp:2492
typename vector_type< half_t, 8 >::type half8_t
Definition: data_type.hpp:2491