/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_wmma.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_wmma.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_wmma.hpp Source File
amd_wmma.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #ifndef CK_AMD_WMMA_HPP
5 #define CK_AMD_WMMA_HPP
6 
8 #include "data_type.hpp"
9 // TODO: Add arch limitation
10 namespace ck {
11 
12 #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
13  defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
14  defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__)
15 #define __gfx11__
16 #endif
17 
18 #if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
19 #define __gfx12__
20 #endif
21 
22 /********************************WAVE32 MODE***********************************************/
23 
24 // src: fp16, dst: fp32
25 template <index_t MPerWave, index_t NPerWave>
27 
28 template <>
30 {
31  template <class FloatC>
32  __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
33  {
34  // * Inline assembly needs to eliminate the duplicated data loads; the compiler won't help
35  // you delete them.
36  // amd_assembly_wmma_f32_16x16x16_f16_w32(
37  // reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
38 #if defined(__gfx11__)
39  reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
40  reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
41 #else
42  ignore = reg_a;
43  ignore = reg_b;
44  ignore = reg_c;
45 #endif
46  }
47 };
48 
49 // src: bf16, dst: fp32
50 template <index_t MPerWave, index_t NPerWave>
52 
53 template <>
55 {
56  template <class FloatC>
57  __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
58  {
59 #if defined(__gfx11__)
60  reg_c.template AsType<float8_t>()(Number<0>{}) =
61  __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
62  reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
63 #else
64  ignore = reg_a;
65  ignore = reg_b;
66  ignore = reg_c;
67 #endif
68  }
69 };
70 
71 // src: fp16, dst: fp16
72 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
74 
// Partial specialization with the first two template arguments fixed to 16, 16. Opsel stays a
// template parameter and is forwarded unchanged to the builtin below.
75 template <index_t Opsel>
76 struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
77 {
 // When __gfx11__ is defined: reads the half16_t view of reg_c at index 0, passes it together
 // with reg_a, reg_b and Opsel to __builtin_amdgcn_wmma_f16_16x16x16_f16_w32, and writes the
 // builtin's result back into that same slot of reg_c.
 // FloatC must provide AsType<half16_t>() that can be indexed with Number<0>{}.
78  template <class FloatC>
79  __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
80  {
81  // opsel usage
82  // false: D0.[0:15] = result
83  // true : D0.[16:31]= result
84 #if defined(__gfx11__)
85  reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
86  reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
87 #else
 // __gfx11__ is not defined: the builtin is not called. The arguments are only assigned to
 // `ignore` (presumably just to keep them referenced), so nothing is computed here.
88  ignore = reg_a;
89  ignore = reg_b;
90  ignore = reg_c;
91 #endif
92  }
93 };
94 
95 // src: bf16, dst: bf16
96 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
98 
99 template <index_t Opsel>
101 {
102  template <class FloatC>
103  __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
104  {
105  // opsel usage
106  // false: D0.[0:15] = result
107  // true : D0.[16:31]= result
108 #if defined(__gfx11__)
109  reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
110  __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
111  reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
112 #else
113  ignore = reg_a;
114  ignore = reg_b;
115  ignore = reg_c;
116 #endif
117  }
118 };
119 
120 // src: iu8, dst: i32
121 template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
123 
// Partial specialization with the first two template arguments fixed to 16, 16. neg_a, neg_b and
// clamp stay template parameters and are forwarded unchanged to the builtin below.
124 template <bool neg_a, bool neg_b, bool clamp>
125 struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
126 {
 // When __gfx11__ is defined: reinterprets reg_a and reg_b as int32x4_t via bit_cast, reads the
 // int32x8_t view of reg_c at index 0 as the third operand, calls
 // __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32 and writes the result back into that same slot.
 // FloatC must provide AsType<int32x8_t>() that can be indexed with Number<0>{}.
127  template <class FloatC>
128  __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
129  {
130 #if defined(__gfx11__)
131  reg_c.template AsType<int32x8_t>()(Number<0>{}) =
132  __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
133  neg_a,
134  bit_cast<int32x4_t>(reg_a),
135  neg_b,
136  bit_cast<int32x4_t>(reg_b),
137  reg_c.template AsType<int32x8_t>()[Number<0>{}],
138  clamp);
139 #else
 // __gfx11__ is not defined: the builtin is not called. The arguments are only assigned to
 // `ignore` (presumably just to keep them referenced), so nothing is computed here.
140  ignore = reg_a;
141  ignore = reg_b;
142  ignore = reg_c;
143 #endif
144  }
145 };
146 
147 /********************************WAVE64 MODE***********************************************/
148 
149 template <index_t MPerWave, index_t NPerWave>
151 
152 template <>
154 {
155  template <class FloatC>
156  __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
157  {
158 #if defined(__gfx11__)
159  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
160  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
161 #else
162  ignore = reg_a;
163  ignore = reg_b;
164  ignore = reg_c;
165 #endif
166  }
167 };
168 
169 // src: bf16, dst: fp32
170 template <index_t MPerWave, index_t NPerWave>
172 
173 template <>
175 {
176  template <class FloatC>
177  __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
178  {
179 #if defined(__gfx11__)
180  reg_c.template AsType<float4_t>()(Number<0>{}) =
181  __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
182  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
183 #else
184  ignore = reg_a;
185  ignore = reg_b;
186  ignore = reg_c;
187 #endif
188  }
189 };
190 
191 // src: fp16, dst: fp16
192 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
194 
// Partial specialization with the first two template arguments fixed to 16, 16. Opsel stays a
// template parameter and is forwarded unchanged to the builtin below.
195 template <index_t Opsel>
196 struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
197 {
 // When __gfx11__ is defined: reads the half8_t view of reg_c at index 0, passes it together
 // with reg_a, reg_b and Opsel to __builtin_amdgcn_wmma_f16_16x16x16_f16_w64, and writes the
 // builtin's result back into that same slot of reg_c.
 // FloatC must provide AsType<half8_t>() that can be indexed with Number<0>{}.
198  template <class FloatC>
199  __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
200  {
201  // opsel usage
202  // false: D0.[0:15] = result
203  // true : D0.[16:31]= result
204 #if defined(__gfx11__)
205  reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
206  reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
207 #else
 // __gfx11__ is not defined: the builtin is not called. The arguments are only assigned to
 // `ignore` (presumably just to keep them referenced), so nothing is computed here.
208  ignore = reg_a;
209  ignore = reg_b;
210  ignore = reg_c;
211 #endif
212  }
213 };
214 
215 // src: bf16, dst: bf16
216 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
218 
219 template <index_t Opsel>
221 {
222  template <class FloatC>
223  __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
224  {
225  // opsel usage
226  // false: D0.[0:15] = result
227  // true : D0.[16:31]= result
228 #if defined(__gfx11__)
229  reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
230  __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
231  reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
232 #else
233  ignore = reg_a;
234  ignore = reg_b;
235  ignore = reg_c;
236 #endif
237  }
238 };
239 
240 // src: iu8, dst: i32
241 template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
243 
// Partial specialization with the first two template arguments fixed to 16, 16. neg_a, neg_b and
// clamp stay template parameters and are forwarded unchanged to the builtin below.
244 template <bool neg_a, bool neg_b, bool clamp>
245 struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
246 {
 // When __gfx11__ is defined: reinterprets reg_a and reg_b as int32x4_t via bit_cast, reads the
 // int32x4_t view of reg_c at index 0 as the third operand, calls
 // __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64 and writes the result back into that same slot.
 // FloatC must provide AsType<int32x4_t>() that can be indexed with Number<0>{}.
247  template <class FloatC>
248  __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
249  {
250 #if defined(__gfx11__)
251  reg_c.template AsType<int32x4_t>()(Number<0>{}) =
252  __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
253  neg_a,
254  bit_cast<int32x4_t>(reg_a),
255  neg_b,
256  bit_cast<int32x4_t>(reg_b),
257  reg_c.template AsType<int32x4_t>()[Number<0>{}],
258  clamp);
259 #else
 // __gfx11__ is not defined: the builtin is not called. The arguments are only assigned to
 // `ignore` (presumably just to keep them referenced), so nothing is computed here.
260  ignore = reg_a;
261  ignore = reg_b;
262  ignore = reg_c;
263 #endif
264  }
265 };
266 
267 // gfx12
268 /********************************WAVE32 MODE***********************************************/
269 
270 // src: fp16, dst: fp32
271 template <index_t MPerWave, index_t NPerWave>
273 
274 template <>
276 {
277  template <class FloatC>
278  __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
279  {
280  // * Inline assembly needs to eliminate the duplicated data loads; the compiler won't help
281  // you delete them.
282  // amd_assembly_wmma_f32_16x16x16_f16_w32(
283  // reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
284 #if defined(__gfx12__)
285  reg_c.template AsType<float8_t>()(Number<0>{}) =
286  __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(
287  reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
288 #else
289  ignore = reg_a;
290  ignore = reg_b;
291  ignore = reg_c;
292 #endif
293  }
294 };
295 
296 // src: bf16, dst: fp32
297 template <index_t MPerWave, index_t NPerWave>
299 
300 template <>
302 {
303  template <class FloatC>
304  __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
305  {
306 #if defined(__gfx12__)
307  reg_c.template AsType<float8_t>()(Number<0>{}) =
308  __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(
309  reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
310 #else
311  ignore = reg_a;
312  ignore = reg_b;
313  ignore = reg_c;
314 #endif
315  }
316 };
317 
318 // src: iu8, dst: i32
319 template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
321 
// Partial specialization with the first two template arguments fixed to 16, 16. neg_a, neg_b and
// clamp stay template parameters and are forwarded unchanged to the builtin below.
322 template <bool neg_a, bool neg_b, bool clamp>
323 struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp>
324 {
 // When __gfx12__ is defined: reinterprets reg_a and reg_b (int8x8_t here, not int8x16_t as in
 // the non-gfx12 iu8 wrappers) as int32x2_t via bit_cast, reads the int32x8_t view of reg_c at
 // index 0 as the third operand, calls __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12 and
 // writes the result back into that same slot.
 // FloatC must provide AsType<int32x8_t>() that can be indexed with Number<0>{}.
325  template <class FloatC>
326  __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
327  {
328 #if defined(__gfx12__)
329  reg_c.template AsType<int32x8_t>()(Number<0>{}) =
330  __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
331  neg_a,
332  bit_cast<int32x2_t>(reg_a),
333  neg_b,
334  bit_cast<int32x2_t>(reg_b),
335  reg_c.template AsType<int32x8_t>()[Number<0>{}],
336  clamp);
337 #else
 // __gfx12__ is not defined: the builtin is not called. The arguments are only assigned to
 // `ignore` (presumably just to keep them referenced), so nothing is computed here.
338  ignore = reg_a;
339  ignore = reg_b;
340  ignore = reg_c;
341 #endif
342  }
343 };
344 
345 // src: f8, f8, dst: fp32
346 template <index_t MPerWave, index_t NPerWave>
348 
349 template <>
351 {
352  template <class FloatC>
353  __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
354  {
355 #if defined(__gfx12__)
356  reg_c.template AsType<float8_t>()(Number<0>{}) =
357  __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12(
358  bit_cast<int32x2_t>(reg_a),
359  bit_cast<int32x2_t>(reg_b),
360  reg_c.template AsType<float8_t>()[Number<0>{}]);
361 #else
362  ignore = reg_a;
363  ignore = reg_b;
364  ignore = reg_c;
365 #endif
366  }
367 };
368 
369 // src: f8, bf8, dst: fp32
370 template <index_t MPerWave, index_t NPerWave>
372 
373 template <>
375 {
376  template <class FloatC>
377  __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
378  {
379 #if defined(__gfx12__)
380  reg_c.template AsType<float8_t>()(Number<0>{}) =
381  __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12(
382  bit_cast<int32x2_t>(reg_a),
383  bit_cast<int32x2_t>(reg_b),
384  reg_c.template AsType<float8_t>()[Number<0>{}]);
385 #else
386  ignore = reg_a;
387  ignore = reg_b;
388  ignore = reg_c;
389 #endif
390  }
391 };
392 
393 // src: bf8, f8, dst: fp32
394 template <index_t MPerWave, index_t NPerWave>
396 
397 template <>
399 {
400  template <class FloatC>
401  __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
402  {
403 #if defined(__gfx12__)
404  reg_c.template AsType<float8_t>()(Number<0>{}) =
405  __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12(
406  bit_cast<int32x2_t>(reg_a),
407  bit_cast<int32x2_t>(reg_b),
408  reg_c.template AsType<float8_t>()[Number<0>{}]);
409 #else
410  ignore = reg_a;
411  ignore = reg_b;
412  ignore = reg_c;
413 #endif
414  }
415 };
416 
417 // src: bf8, bf8, dst: fp32
418 template <index_t MPerWave, index_t NPerWave>
420 
421 template <>
423 {
424  template <class FloatC>
425  __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
426  {
427 #if defined(__gfx12__)
428  reg_c.template AsType<float8_t>()(Number<0>{}) =
429  __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12(
430  bit_cast<int32x2_t>(reg_a),
431  bit_cast<int32x2_t>(reg_b),
432  reg_c.template AsType<float8_t>()[Number<0>{}]);
433 #else
434  ignore = reg_a;
435  ignore = reg_b;
436  ignore = reg_c;
437 #endif
438  }
439 };
440 
441 } // namespace ck
442 #endif
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
bf8_t bf8x8_t
Definition: vector_type.hpp:239
Definition: ck.hpp:270
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition: dtype_vector.hpp:2163
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: dtype_vector.hpp:2179
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: dtype_vector.hpp:2180
typename vector_type< bhalf_t, 16 >::type bhalf16_t
Definition: dtype_vector.hpp:2164
typename vector_type< half_t, 16 >::type half16_t
Definition: dtype_vector.hpp:2157
typename vector_type< half_t, 8 >::type half8_t
Definition: dtype_vector.hpp:2156
Definition: integral_constant.hpp:20
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:103
Definition: amd_wmma.hpp:97
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:223
Definition: amd_wmma.hpp:217
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:79
Definition: amd_wmma.hpp:73
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:199
Definition: amd_wmma.hpp:193
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:57
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:304
Definition: amd_wmma.hpp:51
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:177
Definition: amd_wmma.hpp:171
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:425
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:401
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:32
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:278
Definition: amd_wmma.hpp:272
Definition: amd_wmma.hpp:26
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:156
Definition: amd_wmma.hpp:150
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:377
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:353
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:128
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:326
Definition: amd_wmma.hpp:320
Definition: amd_wmma.hpp:122
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:248
Definition: amd_wmma.hpp:242