/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_wmma.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_wmma.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_wmma.hpp Source File
amd_wmma.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #ifndef CK_AMD_WMMA_HPP
5 #define CK_AMD_WMMA_HPP
6 
8 #include "data_type.hpp"
9 // TODO: Add arch limitation
10 namespace ck {
11 
// Collapse the individual gfx11-family target macros into one family macro so the
// WMMA wrappers below only need a single architecture check. gfx1150/1151/1152 are
// included because they expose the same gfx11 WMMA builtins used below; without
// them those targets would silently take the no-op fallback path.
// The !defined() guard avoids redefining the macro if another header already did.
#if !defined(__gfx11__) &&                                                   \
    (defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
     defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
     defined(__gfx1152__) || defined(__gfx11_generic__))
#define __gfx11__
#endif

// Same for the gfx12 family, which uses the *_gfx12 WMMA builtins below.
#if !defined(__gfx12__) && \
    (defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__))
#define __gfx12__
#endif
20 
21 /********************************WAVE32 MODE***********************************************/
22 
23 // src: fp16, dst: fp32
24 template <index_t MPerWave, index_t NPerWave>
26 
27 template <>
29 {
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
    // NOTE: inline assembly would be needed to eliminate the duplicated data loads;
    // the compiler will not remove them by itself. Assembly variant kept for reference:
    //   amd_assembly_wmma_f32_16x16x16_f16_w32(
    //       reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
#if defined(__gfx11__)
    // Read the 8 x fp32 accumulator view of reg_c, issue the WMMA, and store the
    // updated accumulator back into the same slot.
    const float8_t acc = reg_c.template AsType<float8_t>()[Number<0>{}];
    reg_c.template AsType<float8_t>()(Number<0>{}) =
        __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(reg_a, reg_b, acc);
#else
    // No gfx11 WMMA on this target: the operands are only handed to `ignore`,
    // so reg_c is not updated.
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
46 };
47 
48 // src: bf16, dst: fp32
49 template <index_t MPerWave, index_t NPerWave>
51 
52 template <>
54 {
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx11__)
    // bf16 x bf16 -> fp32 WMMA (wave32): accumulate into the 8 x fp32 view of reg_c.
    const float8_t acc = reg_c.template AsType<float8_t>()[Number<0>{}];
    reg_c.template AsType<float8_t>()(Number<0>{}) =
        __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(reg_a, reg_b, acc);
#else
    // No gfx11 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
68 };
69 
70 // src: fp16, dst: fp16
71 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
73 
template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
{
    // The builtin's opsel operand is a compile-time bool. Reject anything other than
    // 0/1 instead of letting e.g. 2 silently convert to `true`.
    static_assert(Opsel == 0 || Opsel == 1, "Opsel must be 0 (false) or 1 (true)");

    // fp16 x fp16 -> fp16 WMMA (wave32); the accumulator is the 16 x half view of reg_c.
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
#if defined(__gfx11__)
        reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
#else
        // No gfx11 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
93 
94 // src: bf16, dst: bf16
95 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
97 
98 template <index_t Opsel>
100 {
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
    // The builtin's opsel operand is a compile-time bool. Reject anything other than
    // 0/1 instead of letting e.g. 2 silently convert to `true`.
    static_assert(Opsel == 0 || Opsel == 1, "Opsel must be 0 (false) or 1 (true)");

    // opsel usage
    // false: D0.[0:15] = result
    // true : D0.[16:31]= result
#if defined(__gfx11__)
    // bf16 x bf16 -> bf16 WMMA (wave32); the accumulator is the 16 x bf16 view of reg_c.
    reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
        __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
            reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
#else
    // No gfx11 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
117 };
118 
119 // src: iu8, dst: i32
120 template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
122 
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
{
    // 8-bit integer WMMA (wave32) accumulating into the 8 x int32 view of reg_c.
    // neg_a / neg_b / clamp are forwarded verbatim to the builtin's flag operands.
    // NOTE(review): for IU8 WMMA the "neg" bits are believed to select signed vs.
    // unsigned operands rather than negation - confirm against the ISA guide.
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx11__)
        // The builtin consumes each 16 x int8 operand reinterpreted as 4 x int32.
        const int32x4_t a_packed = bit_cast<int32x4_t>(reg_a);
        const int32x4_t b_packed = bit_cast<int32x4_t>(reg_b);
        const int32x8_t acc      = reg_c.template AsType<int32x8_t>()[Number<0>{}];

        reg_c.template AsType<int32x8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
                neg_a, a_packed, neg_b, b_packed, acc, clamp);
#else
        // No gfx11 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
145 
146 /********************************WAVE64 MODE***********************************************/
147 
148 template <index_t MPerWave, index_t NPerWave>
150 
151 template <>
153 {
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx11__)
    // fp16 x fp16 -> fp32 WMMA (wave64): the accumulator is the 4 x fp32 view of reg_c.
    const float4_t acc = reg_c.template AsType<float4_t>()[Number<0>{}];
    reg_c.template AsType<float4_t>()(Number<0>{}) =
        __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(reg_a, reg_b, acc);
#else
    // No gfx11 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
166 };
167 
168 // src: bf16, dst: fp32
169 template <index_t MPerWave, index_t NPerWave>
171 
172 template <>
174 {
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx11__)
    // bf16 x bf16 -> fp32 WMMA (wave64): the accumulator is the 4 x fp32 view of reg_c.
    const float4_t acc = reg_c.template AsType<float4_t>()[Number<0>{}];
    reg_c.template AsType<float4_t>()(Number<0>{}) =
        __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(reg_a, reg_b, acc);
#else
    // No gfx11 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
188 };
189 
190 // src: fp16, dst: fp16
191 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
193 
template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
{
    // The builtin's opsel operand is a compile-time bool. Reject anything other than
    // 0/1 instead of letting e.g. 2 silently convert to `true`.
    static_assert(Opsel == 0 || Opsel == 1, "Opsel must be 0 (false) or 1 (true)");

    // fp16 x fp16 -> fp16 WMMA (wave64); the accumulator is the 8 x half view of reg_c.
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
#if defined(__gfx11__)
        reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
            reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
#else
        // No gfx11 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
213 
214 // src: bf16, dst: bf16
215 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
217 
218 template <index_t Opsel>
220 {
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
    // The builtin's opsel operand is a compile-time bool. Reject anything other than
    // 0/1 instead of letting e.g. 2 silently convert to `true`.
    static_assert(Opsel == 0 || Opsel == 1, "Opsel must be 0 (false) or 1 (true)");

    // opsel usage
    // false: D0.[0:15] = result
    // true : D0.[16:31]= result
#if defined(__gfx11__)
    // bf16 x bf16 -> bf16 WMMA (wave64); the accumulator is the 8 x bf16 view of reg_c.
    reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
        __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
            reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
#else
    // No gfx11 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
237 };
238 
239 // src: iu8, dst: i32
240 template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
242 
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
{
    // 8-bit integer WMMA (wave64) accumulating into the 4 x int32 view of reg_c.
    // neg_a / neg_b / clamp are forwarded verbatim to the builtin's flag operands.
    // NOTE(review): for IU8 WMMA the "neg" bits are believed to select signed vs.
    // unsigned operands rather than negation - confirm against the ISA guide.
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx11__)
        // The builtin consumes each 16 x int8 operand reinterpreted as 4 x int32.
        const int32x4_t a_packed = bit_cast<int32x4_t>(reg_a);
        const int32x4_t b_packed = bit_cast<int32x4_t>(reg_b);
        const int32x4_t acc      = reg_c.template AsType<int32x4_t>()[Number<0>{}];

        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
                neg_a, a_packed, neg_b, b_packed, acc, clamp);
#else
        // No gfx11 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
265 
266 // gfx12
267 /********************************WAVE32 MODE***********************************************/
268 
269 // src: fp16, dst: fp32
270 template <index_t MPerWave, index_t NPerWave>
272 
273 template <>
275 {
template <class FloatC>
__device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
{
    // NOTE: inline assembly would be needed to eliminate the duplicated data loads;
    // the compiler will not remove them by itself. Assembly variant kept for reference:
    //   amd_assembly_wmma_f32_16x16x16_f16_w32(
    //       reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
#if defined(__gfx12__)
    // gfx12 fp16 x fp16 -> fp32 WMMA (wave32): 8 x half operands, 8 x fp32 accumulator.
    const float8_t acc = reg_c.template AsType<float8_t>()[Number<0>{}];
    reg_c.template AsType<float8_t>()(Number<0>{}) =
        __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(reg_a, reg_b, acc);
#else
    // No gfx12 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
293 };
294 
295 // src: bf16, dst: fp32
296 template <index_t MPerWave, index_t NPerWave>
298 
299 template <>
301 {
template <class FloatC>
__device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx12__)
    // gfx12 bf16 x bf16 -> fp32 WMMA (wave32): 8 x bf16 operands, 8 x fp32 accumulator.
    const float8_t acc = reg_c.template AsType<float8_t>()[Number<0>{}];
    reg_c.template AsType<float8_t>()(Number<0>{}) =
        __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(reg_a, reg_b, acc);
#else
    // No gfx12 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
315 };
316 
317 // src: iu8, dst: i32
318 template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
320 
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp>
{
    // gfx12 8-bit integer WMMA (wave32) accumulating into the 8 x int32 view of reg_c.
    // neg_a / neg_b / clamp are forwarded verbatim to the builtin's flag operands.
    // NOTE(review): for IU8 WMMA the "neg" bits are believed to select signed vs.
    // unsigned operands rather than negation - confirm against the ISA guide.
    template <class FloatC>
    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx12__)
        // The builtin consumes each 8 x int8 operand reinterpreted as 2 x int32.
        const int32x2_t a_packed = bit_cast<int32x2_t>(reg_a);
        const int32x2_t b_packed = bit_cast<int32x2_t>(reg_b);
        const int32x8_t acc      = reg_c.template AsType<int32x8_t>()[Number<0>{}];

        reg_c.template AsType<int32x8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
                neg_a, a_packed, neg_b, b_packed, acc, clamp);
#else
        // No gfx12 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
343 
344 // src: f8, f8, dst: fp32
345 template <index_t MPerWave, index_t NPerWave>
347 
348 template <>
350 {
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx12__)
    // gfx12 fp8 x fp8 -> fp32 WMMA (wave32). The builtin takes the raw 8-byte fp8
    // payloads reinterpreted as 2 x int32.
    const int32x2_t a_bits = bit_cast<int32x2_t>(reg_a);
    const int32x2_t b_bits = bit_cast<int32x2_t>(reg_b);
    const float8_t acc     = reg_c.template AsType<float8_t>()[Number<0>{}];

    reg_c.template AsType<float8_t>()(Number<0>{}) =
        __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12(a_bits, b_bits, acc);
#else
    // No gfx12 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
366 };
367 
368 // src: f8, bf8, dst: fp32
369 template <index_t MPerWave, index_t NPerWave>
371 
372 template <>
374 {
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx12__)
    // gfx12 fp8 (A) x bf8 (B) -> fp32 WMMA (wave32). The builtin takes the raw 8-byte
    // payloads reinterpreted as 2 x int32.
    const int32x2_t a_bits = bit_cast<int32x2_t>(reg_a);
    const int32x2_t b_bits = bit_cast<int32x2_t>(reg_b);
    const float8_t acc     = reg_c.template AsType<float8_t>()[Number<0>{}];

    reg_c.template AsType<float8_t>()(Number<0>{}) =
        __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12(a_bits, b_bits, acc);
#else
    // No gfx12 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
390 };
391 
392 // src: bf8, f8, dst: fp32
393 template <index_t MPerWave, index_t NPerWave>
395 
396 template <>
398 {
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx12__)
    // gfx12 bf8 (A) x fp8 (B) -> fp32 WMMA (wave32). The builtin takes the raw 8-byte
    // payloads reinterpreted as 2 x int32.
    const int32x2_t a_bits = bit_cast<int32x2_t>(reg_a);
    const int32x2_t b_bits = bit_cast<int32x2_t>(reg_b);
    const float8_t acc     = reg_c.template AsType<float8_t>()[Number<0>{}];

    reg_c.template AsType<float8_t>()(Number<0>{}) =
        __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12(a_bits, b_bits, acc);
#else
    // No gfx12 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
414 };
415 
416 // src: bf8, bf8, dst: fp32
417 template <index_t MPerWave, index_t NPerWave>
419 
420 template <>
422 {
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx12__)
    // gfx12 bf8 x bf8 -> fp32 WMMA (wave32). The builtin takes the raw 8-byte
    // payloads reinterpreted as 2 x int32.
    const int32x2_t a_bits = bit_cast<int32x2_t>(reg_a);
    const int32x2_t b_bits = bit_cast<int32x2_t>(reg_b);
    const float8_t acc     = reg_c.template AsType<float8_t>()[Number<0>{}];

    reg_c.template AsType<float8_t>()(Number<0>{}) =
        __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12(a_bits, b_bits, acc);
#else
    // No gfx12 WMMA on this target: operands only go to `ignore`; reg_c is not updated.
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
438 };
439 
440 } // namespace ck
441 #endif
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
bf8_t bf8x8_t
Definition: vector_type.hpp:227
Definition: ck.hpp:267
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition: dtype_vector.hpp:2148
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: dtype_vector.hpp:2164
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: dtype_vector.hpp:2165
typename vector_type< bhalf_t, 16 >::type bhalf16_t
Definition: dtype_vector.hpp:2149
typename vector_type< half_t, 16 >::type half16_t
Definition: dtype_vector.hpp:2142
typename vector_type< half_t, 8 >::type half8_t
Definition: dtype_vector.hpp:2141
Definition: integral_constant.hpp:20
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:102
Definition: amd_wmma.hpp:96
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:222
Definition: amd_wmma.hpp:216
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:78
Definition: amd_wmma.hpp:72
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:198
Definition: amd_wmma.hpp:192
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:56
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:303
Definition: amd_wmma.hpp:50
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:176
Definition: amd_wmma.hpp:170
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:424
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:400
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:31
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:277
Definition: amd_wmma.hpp:271
Definition: amd_wmma.hpp:25
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:155
Definition: amd_wmma.hpp:149
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:376
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:352
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:127
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:325
Definition: amd_wmma.hpp:319
Definition: amd_wmma.hpp:121
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:247
Definition: amd_wmma.hpp:241