4 #ifndef CK_AMD_INLINE_ASM_HPP
5 #define CK_AMD_INLINE_ASM_HPP
17 asm volatile(
"v_and_b32 %0, %1, %2" :
"=v"(c) :
"v"(
a),
"v"(b));
24 asm volatile(
"v_and_or_b32 %0, %1, %2, %3" :
"=v"(c) :
"v"(
a),
"v"(b),
"v"(d));
31 asm volatile(
"v_pk_fma_f16 %0, %1, %2, %3" :
"=v"(d) :
"v"(
a),
"v"(b),
"v"(c));
38 asm volatile(
"v_pk_add_f16 %0, %1, %2" :
"=v"(c) :
"v"(
a),
"v"(b));
45 asm volatile(
"v_cvt_off_f32_i4 %0, %1" :
"=v"(
a) :
"v"(b));
52 asm volatile(
"v_cvt_pk_fp8_f32 %0, %1, %2\n"
53 "v_cvt_pk_fp8_f32 %0, %3, %4, op_sel:[0, 0, 1]\n"
55 :
"v"(b0),
"v"(b1),
"v"(b2),
"v"(b3));
64 float tmp_0, tmp_1, tmp_2;
66 asm volatile(
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
67 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_2\n"
68 "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1]\n"
69 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
70 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_3\n"
71 "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1]\n"
72 "v_lshrrev_b32 %[v_tmp_2], 4, %[v_src]\n"
73 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2]\n"
74 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_2\n"
75 "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n"
76 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2], src0_sel:BYTE_1\n"
77 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_3\n"
78 "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n"
79 : [v_tmp_0]
"+v"(tmp_0),
80 [v_tmp_1]
"+v"(tmp_1),
81 [v_tmp_2]
"+v"(tmp_2),
82 [v_dst_0]
"+v"(fp8x4_0),
83 [v_dst_1]
"+v"(fp8x4_1),
87 return bit_cast<f8x8_t>(((
static_cast<uint64_t>(fp8x4_1) << 32) | fp8x4_0));
95 v_fmac_f32 %0, %2, %3 \n \
96 v_fmac_f32 %1, %2, %4 \n \
99 :
"v"(
a),
"v"(b0),
"v"(b1),
"0"(c0),
"1"(c1));
107 float a,
float b0,
float b1,
float b2,
float b3,
float& c0,
float& c1,
float& c2,
float& c3)
110 v_fmac_f32 %0, %4, %5 \n \
111 v_fmac_f32 %1, %4, %6 \n \
112 v_fmac_f32 %2, %4, %7 \n \
113 v_fmac_f32 %3, %4, %8 \n \
115 :
"=v"(c0),
"=v"(c1),
"=v"(c2),
"=v"(c3)
116 :
"v"(
a),
"v"(b0),
"v"(b1),
"v"(b2),
"v"(b3),
"0"(c0),
"1"(c1),
"2"(c2),
"3"(c3));
125 v_dot2_f32_f16 %0, %2, %3, %0\n \
126 v_dot2_f32_f16 %1, %2, %4, %1\n \
129 :
"v"(
a),
"v"(b0),
"v"(b1),
"0"(c0),
"1"(c1));
138 const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&
a);
139 const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
140 const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
144 v_dot2_f32_f16 %0, %2, %4, %0\n \
145 v_dot2_f32_f16 %1, %2, %6, %1\n \
146 v_dot2_f32_f16 %0, %3, %5, %0\n \
147 v_dot2_f32_f16 %1, %3, %7, %1\n \
175 v_dot2_f32_f16 %0, %4, %5, %0\n \
176 v_dot2_f32_f16 %1, %4, %6, %1\n \
177 v_dot2_f32_f16 %2, %4, %7, %2\n \
178 v_dot2_f32_f16 %3, %4, %8, %3\n \
180 :
"=v"(c0),
"=v"(c1),
"=v"(c2),
"=v"(c3)
181 :
"v"(
a),
"v"(b0),
"v"(b1),
"v"(b2),
"v"(b3),
"0"(c0),
"1"(c1),
"2"(c2),
"3"(c3));
199 const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&
a);
200 const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
201 const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
202 const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
203 const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);
207 v_dot2_f32_f16 %0, %4, %6, %0\n \
208 v_dot2_f32_f16 %1, %4, %8, %1\n \
209 v_dot2_f32_f16 %2, %4, %10, %2\n \
210 v_dot2_f32_f16 %3, %4, %12, %3\n \
211 v_dot2_f32_f16 %0, %5, %7, %0\n \
212 v_dot2_f32_f16 %1, %5, %9, %1\n \
213 v_dot2_f32_f16 %2, %5, %11, %2\n \
214 v_dot2_f32_f16 %3, %5, %13, %3\n \
216 :
"=v"(c0),
"=v"(c1),
"=v"(c2),
"=v"(c3)
245 const half4_t* p_a_half4 = c_style_pointer_cast<const half4_t*>(&
a);
246 const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
247 const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
248 const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
249 const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);
252 p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
255 p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
269 const half8_t* p_a_half8 = c_style_pointer_cast<const half8_t*>(&
a);
270 const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
271 const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
272 const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
273 const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);
276 p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
279 p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3);
289 v_dot4_i32_i8 %0, %2, %3, %0\n \
290 v_dot4_i32_i8 %1, %2, %4, %1\n \
293 :
"v"(bit_cast<int32_t>(
a)),
294 "v"(bit_cast<int32_t>(b0)),
295 "v"(bit_cast<int32_t>(b1)),
299 c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(
a), bit_cast<int32_t>(b0), c0,
false);
300 c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(
a), bit_cast<int32_t>(b1), c1,
false);
320 v_dot4_i32_i8 %0, %4, %5, %0\n \
321 v_dot4_i32_i8 %1, %4, %6, %1\n \
322 v_dot4_i32_i8 %2, %4, %7, %2\n \
323 v_dot4_i32_i8 %3, %4, %8, %3\n \
325 :
"=v"(c0),
"=v"(c1),
"=v"(c2),
"=v"(c3)
326 :
"v"(bit_cast<int32_t>(
a)),
327 "v"(bit_cast<int32_t>(b0)),
328 "v"(bit_cast<int32_t>(b1)),
329 "v"(bit_cast<int32_t>(b2)),
330 "v"(bit_cast<int32_t>(b3)),
336 c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(
a), bit_cast<int32_t>(b0), c0,
false);
337 c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(
a), bit_cast<int32_t>(b1), c1,
false);
338 c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(
a), bit_cast<int32_t>(b2), c2,
false);
339 c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(
a), bit_cast<int32_t>(b3), c3,
false);
353 constexpr
auto I0 = Number<0>{};
354 constexpr
auto I1 = Number<1>{};
357 vector_type<int8_t, 8>{b0}.AsType<
int8x4_t>()[I0],
358 vector_type<int8_t, 8>{b1}.AsType<
int8x4_t>()[I0],
359 vector_type<int8_t, 8>{b2}.AsType<
int8x4_t>()[I0],
360 vector_type<int8_t, 8>{b3}.AsType<
int8x4_t>()[I0],
367 vector_type<int8_t, 8>{b0}.AsType<
int8x4_t>()[I1],
368 vector_type<int8_t, 8>{b1}.AsType<
int8x4_t>()[I1],
369 vector_type<int8_t, 8>{b2}.AsType<
int8x4_t>()[I1],
370 vector_type<int8_t, 8>{b3}.AsType<
int8x4_t>()[I1],
388 constexpr
auto I0 = Number<0>{};
389 constexpr
auto I1 = Number<1>{};
390 constexpr
auto I2 = Number<2>{};
391 constexpr
auto I3 = Number<3>{};
394 vector_type<int8_t, 16>{b0}.AsType<
int8x4_t>()[I0],
395 vector_type<int8_t, 16>{b1}.AsType<
int8x4_t>()[I0],
396 vector_type<int8_t, 16>{b2}.AsType<
int8x4_t>()[I0],
397 vector_type<int8_t, 16>{b3}.AsType<
int8x4_t>()[I0],
404 vector_type<int8_t, 16>{b0}.AsType<
int8x4_t>()[I1],
405 vector_type<int8_t, 16>{b1}.AsType<
int8x4_t>()[I1],
406 vector_type<int8_t, 16>{b2}.AsType<
int8x4_t>()[I1],
407 vector_type<int8_t, 16>{b3}.AsType<
int8x4_t>()[I1],
414 vector_type<int8_t, 16>{b0}.AsType<
int8x4_t>()[I2],
415 vector_type<int8_t, 16>{b1}.AsType<
int8x4_t>()[I2],
416 vector_type<int8_t, 16>{b2}.AsType<
int8x4_t>()[I2],
417 vector_type<int8_t, 16>{b3}.AsType<
int8x4_t>()[I2],
424 vector_type<int8_t, 16>{b0}.AsType<
int8x4_t>()[I3],
425 vector_type<int8_t, 16>{b1}.AsType<
int8x4_t>()[I3],
426 vector_type<int8_t, 16>{b2}.AsType<
int8x4_t>()[I3],
427 vector_type<int8_t, 16>{b3}.AsType<
int8x4_t>()[I3],
__device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
Definition: amd_inline_asm.hpp:35
__device__ f8x8_t amd_assembly_i4_to_fp8x8(int a)
Definition: amd_inline_asm.hpp:59
__device__ void amd_assembly_outer_product_1x4(float a, float b0, float b1, float b2, float b3, float &c0, float &c1, float &c2, float &c3)
Definition: amd_inline_asm.hpp:106
__device__ f8x4_t amd_assembly_cvt_f8_to_f32(float b0, float b1, float b2, float b3)
Definition: amd_inline_asm.hpp:49
__device__ int amd_assembly_and_b32(int a, int b)
Definition: amd_inline_asm.hpp:14
__device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
Definition: amd_inline_asm.hpp:28
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: dtype_vector.hpp:2164
typename vector_type< half_t, 4 >::type half4_t
Definition: dtype_vector.hpp:2140
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float &c0, float &c1)
Definition: amd_inline_asm.hpp:92
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: dtype_vector.hpp:2165
typename vector_type< half_t, 2 >::type half2_t
Definition: dtype_vector.hpp:2139
__device__ int amd_assembly_and_or_b32(int a, int b, int d)
Definition: amd_inline_asm.hpp:21
typename vector_type< int8_t, 4 >::type int8x4_t
Definition: dtype_vector.hpp:2163
typename vector_type< half_t, 16 >::type half16_t
Definition: dtype_vector.hpp:2142
__device__ float amd_assemble_cvt_f32_i4(int b)
Definition: amd_inline_asm.hpp:42
typename vector_type< half_t, 8 >::type half8_t
Definition: dtype_vector.hpp:2141
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
unsigned int uint32_t
Definition: stdint.h:126
signed int int32_t
Definition: stdint.h:123
unsigned __int64 uint64_t
Definition: stdint.h:136