/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/inner_product.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/inner_product.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/inner_product.hpp Source File
inner_product.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 #include "data_type.hpp"
6 #include "type_convert.hpp"
7 
8 namespace ck {
9 
10 template <typename TA, typename TB, typename TC>
11 __device__ void inner_product(const TA& a, const TB& b, TC& c);
12 
13 template <>
14 __device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
15 {
16 #if CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32)
17  asm volatile("\n \
18  v_mac_f32 %0, %1, %2 \n \
19  "
20  : "=v"(c)
21  : "v"(a), "v"(b), "0"(c));
22 #elif CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32)
23  asm volatile("\n \
24  v_fmac_f32 %0, %1, %2 \n \
25  "
26  : "=v"(c)
27  : "v"(a), "v"(b), "0"(c));
28 #else
29  c += a * b;
30 #endif
31 }
32 
33 template <>
34 __device__ void
36 {
37  constexpr auto I0 = Number<0>{};
38  constexpr auto I1 = Number<1>{};
39 
40  inner_product(vector_type<float, 2>{a}.AsType<float>()[I0],
41  vector_type<float, 2>{b}.AsType<float>()[I0],
42  c);
43 
44  inner_product(vector_type<float, 2>{a}.AsType<float>()[I1],
45  vector_type<float, 2>{b}.AsType<float>()[I1],
46  c);
47 }
48 
49 template <>
50 __device__ void
52 {
53  constexpr auto I0 = Number<0>{};
54  constexpr auto I1 = Number<1>{};
55  constexpr auto I2 = Number<2>{};
56  constexpr auto I3 = Number<3>{};
57 
58  inner_product(vector_type<float, 4>{a}.AsType<float>()[I0],
59  vector_type<float, 4>{b}.AsType<float>()[I0],
60  c);
61 
62  inner_product(vector_type<float, 4>{a}.AsType<float>()[I1],
63  vector_type<float, 4>{b}.AsType<float>()[I1],
64  c);
65 
66  inner_product(vector_type<float, 4>{a}.AsType<float>()[I2],
67  vector_type<float, 4>{b}.AsType<float>()[I2],
68  c);
69 
70  inner_product(vector_type<float, 4>{a}.AsType<float>()[I3],
71  vector_type<float, 4>{b}.AsType<float>()[I3],
72  c);
73 }
74 
75 template <>
76 __device__ void inner_product<bhalf_t, bhalf_t, float>(const bhalf_t& a, const bhalf_t& b, float& c)
77 {
78  inner_product(type_convert<float>(a), type_convert<float>(b), c);
79 }
80 
81 template <>
82 __device__ void inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
83 {
84  inner_product(type_convert<float>(a), type_convert<float>(b), c);
85 }
86 
87 template <>
88 __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
89 {
90 #if defined(CK_USE_AMD_V_DOT2_F32_F16)
91 #if CK_USE_AMD_V_DOT_INLINE_ASM
92  // Use 3 x s_nop to avoid hazard (mi200 cdna2 isa page 47
93  // https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf
94  // ) s_nop with parameter 2 is equal to 3 x s_nop
95  asm volatile("\n \
96  v_dot2_f32_f16 %0, %1, %2, %0\n \
97  s_nop 2 \n \
98  "
99  : "=v"(c)
100  : "v"(a), "v"(b), "0"(c));
101 #else
102  c = __builtin_amdgcn_fdot2(a, b, c, false);
103 #endif
104 #else
105  const vector_type<half_t, 2> a_vector{a};
106  const vector_type<half_t, 2> b_vector{b};
107 
108  static_for<0, 2, 1>{}([&](auto i) {
109  c += type_convert<float>(a_vector.AsType<half_t>()[i]) *
110  type_convert<float>(b_vector.AsType<half_t>()[i]);
111  });
112 #endif
113 }
114 
115 template <>
116 __device__ void inner_product<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
117 {
118  constexpr auto I0 = Number<0>{};
119  constexpr auto I1 = Number<1>{};
120 
122  vector_type<half_t, 4>{b}.AsType<half2_t>()[I0],
123  c);
124 
126  vector_type<half_t, 4>{b}.AsType<half2_t>()[I1],
127  c);
128 }
129 
130 template <>
131 __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
132 {
133  constexpr auto I0 = Number<0>{};
134  constexpr auto I1 = Number<1>{};
135  constexpr auto I2 = Number<2>{};
136  constexpr auto I3 = Number<3>{};
137 
139  vector_type<half_t, 8>{b}.AsType<half2_t>()[I0],
140  c);
141 
143  vector_type<half_t, 8>{b}.AsType<half2_t>()[I1],
144  c);
145 
147  vector_type<half_t, 8>{b}.AsType<half2_t>()[I2],
148  c);
149 
151  vector_type<half_t, 8>{b}.AsType<half2_t>()[I3],
152  c);
153 }
154 
155 template <>
156 __device__ void inner_product<int8_t, int8_t, int32_t>(const int8_t& a, const int8_t& b, int32_t& c)
157 {
158  c += type_convert<int32_t>(a) * type_convert<int32_t>(b);
159 }
160 
161 template <>
162 __device__ void
164 {
165  constexpr auto I0 = Number<0>{};
166  constexpr auto I1 = Number<1>{};
167 
169  vector_type<int8_t, 2>{b}.AsType<int8_t>()[I0],
170  c);
171 
173  vector_type<int8_t, 2>{b}.AsType<int8_t>()[I1],
174  c);
175 }
176 
177 template <>
178 __device__ void
180 {
181 #if defined(CK_USE_AMD_V_DOT4_I32_I8)
182 #if CK_USE_AMD_V_DOT_INLINE_ASM
183  // Use 3 x s_nop to avoid hazard (mi200 cdna2 isa page 47
184  // https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf
185  // ) s_nop with parameter 2 is equal to 3 x s_nop
186  asm volatile("\n \
187  v_dot4_i32_i8 %0, %1, %2, %0\n \
188  s_nop 2 \n \
189  "
190  : "=v"(c)
191  : "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
192 #else
193  c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
194 #endif
195 #elif defined(CK_USE_AMD_V_DOT4_I32_I8_GFX11)
196  c = __builtin_amdgcn_sudot4(true, bit_cast<int32_t>(a), true, bit_cast<int32_t>(b), c, false);
197 #else
198  const vector_type<int8_t, 4> a_vector{a};
199  const vector_type<int8_t, 4> b_vector{b};
200 
201  static_for<0, 4, 1>{}([&](auto i) {
202  c += type_convert<int32_t>(a_vector.AsType<int8_t>()[i]) *
203  type_convert<int32_t>(b_vector.AsType<int8_t>()[i]);
204  });
205 #endif
206 }
207 
208 template <>
209 __device__ void
211 {
212  constexpr auto I0 = Number<0>{};
213  constexpr auto I1 = Number<1>{};
214 
216  vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
217  c);
218 
220  vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
221  c);
222 }
223 
224 template <>
225 __device__ void
227 {
228  constexpr auto I0 = Number<0>{};
229  constexpr auto I1 = Number<1>{};
230  constexpr auto I2 = Number<2>{};
231  constexpr auto I3 = Number<3>{};
232 
234  vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
235  c);
236 
238  vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
239  c);
240 
242  vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
243  c);
244 
246  vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
247  c);
248 }
249 
250 } // namespace ck
Definition: ck.hpp:267
__device__ void inner_product< half_t, half_t, float >(const half_t &a, const half_t &b, float &c)
Definition: inner_product.hpp:82
__device__ void inner_product< float2_t, float2_t, float >(const float2_t &a, const float2_t &b, float &c)
Definition: inner_product.hpp:35
typename vector_type< int8_t, 2 >::type int8x2_t
Definition: dtype_vector.hpp:2162
__device__ void inner_product< int8x2_t, int8x2_t, int32_t >(const int8x2_t &a, const int8x2_t &b, int32_t &c)
Definition: inner_product.hpp:163
typename vector_type< float, 2 >::type float2_t
Definition: dtype_vector.hpp:2131
__device__ void inner_product< half2_t, half2_t, float >(const half2_t &a, const half2_t &b, float &c)
Definition: inner_product.hpp:88
__device__ void inner_product< float4_t, float4_t, float >(const float4_t &a, const float4_t &b, float &c)
Definition: inner_product.hpp:51
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: dtype_vector.hpp:2164
typename vector_type< half_t, 4 >::type half4_t
Definition: dtype_vector.hpp:2140
__device__ void inner_product< int8x8_t, int8x8_t, int32_t >(const int8x8_t &a, const int8x8_t &b, int32_t &c)
Definition: inner_product.hpp:210
_Float16 half_t
Definition: data_type.hpp:30
ushort bhalf_t
Definition: data_type.hpp:29
__device__ void inner_product< half4_t, half4_t, float >(const half4_t &a, const half4_t &b, float &c)
Definition: inner_product.hpp:116
__device__ void inner_product< bhalf_t, bhalf_t, float >(const bhalf_t &a, const bhalf_t &b, float &c)
Definition: inner_product.hpp:76
typename vector_type< float, 4 >::type float4_t
Definition: dtype_vector.hpp:2132
__device__ void inner_product< int8x16_t, int8x16_t, int32_t >(const int8x16_t &a, const int8x16_t &b, int32_t &c)
Definition: inner_product.hpp:226
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: dtype_vector.hpp:2165
__device__ void inner_product< half8_t, half8_t, float >(const half8_t &a, const half8_t &b, float &c)
Definition: inner_product.hpp:131
typename vector_type< half_t, 2 >::type half2_t
Definition: dtype_vector.hpp:2139
typename vector_type< int8_t, 4 >::type int8x4_t
Definition: dtype_vector.hpp:2163
__device__ void inner_product< int8_t, int8_t, int32_t >(const int8_t &a, const int8_t &b, int32_t &c)
Definition: inner_product.hpp:156
__device__ void inner_product< int8x4_t, int8x4_t, int32_t >(const int8x4_t &a, const int8x4_t &b, int32_t &c)
Definition: inner_product.hpp:179
__device__ void inner_product(const TA &a, const TB &b, TC &c)
__device__ void inner_product< float, float, float >(const float &a, const float &b, float &c)
Definition: inner_product.hpp:14
typename vector_type< half_t, 8 >::type half8_t
Definition: dtype_vector.hpp:2141
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:10