Composable Kernel: include/ck/utility/amd_xdlops.hpp

1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
6 
7 namespace ck {
8 // Define a common macro for the MI300-series (gfx942) and MI350-series (gfx950) targets
9 #if defined(__gfx942__) || defined(__gfx950__)
10 #define __gfx94__
11 #endif
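// Each intrin_mfma_* specialization below wraps a single MFMA compiler builtin. Run()
// multiplies the packed A/B register fragments (reg_a, reg_b) and accumulates the result
// into the C fragment reg_c, which is read and written through vector_type::AsType<...>.
// The __gfx94__ macro defined above guards builtins available on both gfx942 and gfx950.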
12 
13 // fp32
14 template <index_t MPerWave, index_t NPerWave>
15 struct intrin_mfma_f32_32x32x1f32;
16 
17 template <>
18 struct intrin_mfma_f32_32x32x1f32<64, 64>
19 {
20  template <class FloatC>
21  __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
22  {
23  reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
24  reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
25  reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
26  reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
27  }
28 };
29 
30 template <>
31 struct intrin_mfma_f32_32x32x1f32<32, 64>
32 {
33  template <class FloatC>
34  __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
35  {
36  reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
37  reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
38  }
39 };
40 
41 template <index_t MPerWave, index_t NPerWave>
42 struct intrin_mfma_f32_32x32x2f32;
43 
44 template <>
45 struct intrin_mfma_f32_32x32x2f32<32, 32>
46 {
47  template <class FloatC>
48  __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
49  {
50  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32(
51  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
52  }
53 };
54 
55 template <index_t MPerWave, index_t NPerWave>
56 struct intrin_mfma_f32_16x16x4f32;
57 
58 template <>
59 struct intrin_mfma_f32_16x16x4f32<16, 16>
60 {
61  template <class FloatC>
62  __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
63  {
64  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32(
65  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
66  }
67 };
68 
69 template <index_t MPerWave, index_t NPerWave>
70 struct intrin_mfma_f32_16x16x1f32;
71 
72 template <>
73 struct intrin_mfma_f32_16x16x1f32<16, 64>
74 {
75  template <class FloatC>
76  __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
77  {
78  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32(
79  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
80  }
81 };
82 
83 template <index_t MPerWave, index_t NPerWave>
84 struct intrin_mfma_f32_4x4x1f32;
85 
86 template <>
87 struct intrin_mfma_f32_4x4x1f32<4, 64>
88 {
89  template <class FloatC>
90  __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
91  {
92  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
93  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
94  }
95 };
96 
97 template <>
98 struct intrin_mfma_f32_4x4x1f32<8, 64>
99 {
100  template <class FloatC>
101  __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
102  {
103  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
104  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
105  reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
106  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
107  }
108 };
109 
110 // fp16
111 template <index_t MPerWave, index_t NPerWave>
112 struct intrin_mfma_f32_32x32x4f16;
113 
114 template <>
115 struct intrin_mfma_f32_32x32x4f16<64, 64>
116 {
117  template <class FloatC>
118  __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
119  {
120  reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
121  reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
122  reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
123  reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
124  }
125 };
126 
127 template <>
128 struct intrin_mfma_f32_32x32x4f16<32, 64>
129 {
130  template <class FloatC>
131  __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
132  {
133  reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
134  reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
135  }
136 };
137 
138 template <index_t MPerWave, index_t NPerWave>
139 struct intrin_mfma_f32_32x32x16f16;
140 
141 template <>
142 struct intrin_mfma_f32_32x32x16f16<32, 32>
143 {
144  template <class FloatC>
145  __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
146  {
147 #if defined(__gfx950__)
148  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
149  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
150 #else
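 // The builtin exists only on gfx950; on other targets, consume the arguments so the
 // specialization still compiles cleanly (ignore is a write-only sink).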
151  ignore = reg_a;
152  ignore = reg_b;
153  ignore = reg_c;
154 #endif // defined(__gfx950__)
155  }
156 };
157 
158 template <index_t MPerWave, index_t NPerWave>
159 struct intrin_mfma_f32_16x16x32f16;
160 
161 template <>
162 struct intrin_mfma_f32_16x16x32f16<16, 16>
163 {
164  template <class FloatC>
165  __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
166  {
167 #if defined(__gfx950__)
168  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
169  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
170 #else
171  ignore = reg_a;
172  ignore = reg_b;
173  ignore = reg_c;
174 #endif // defined(__gfx950__)
175  }
176 };
177 
178 template <index_t MPerWave, index_t NPerWave>
179 struct intrin_mfma_f32_32x32x8f16;
180 
181 template <>
182 struct intrin_mfma_f32_32x32x8f16<32, 32>
183 {
184  template <class FloatC>
185  __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
186  {
187  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
188  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
189  }
190 };
191 
192 template <index_t MPerWave, index_t NPerWave>
193 struct intrin_mfma_f32_16x16x16f16;
194 
195 template <>
196 struct intrin_mfma_f32_16x16x16f16<16, 16>
197 {
198  template <class FloatC>
199  __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
200  {
201  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
202  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
203  }
204 };
205 
206 template <index_t MPerWave, index_t NPerWave>
207 struct intrin_mfma_f32_16x16x4f16;
208 
209 template <>
210 struct intrin_mfma_f32_16x16x4f16<16, 64>
211 {
212  template <class FloatC>
213  __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
214  {
215  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
216  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
217  }
218 };
219 
220 template <index_t MPerWave, index_t NPerWave>
221 struct intrin_mfma_f32_4x4x4f16;
222 
223 template <>
224 struct intrin_mfma_f32_4x4x4f16<4, 64>
225 {
226  template <class FloatC>
227  __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
228  {
229  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
230  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
231  }
232 };
233 
234 template <>
235 struct intrin_mfma_f32_4x4x4f16<8, 64>
236 {
237  template <class FloatC>
238  __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
239  {
240  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
241  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
242  reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
243  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
244  }
245 };
246 
247 // bf16
248 template <index_t MPerWave, index_t NPerWave>
249 struct intrin_mfma_f32_32x32x16bf16;
250 
251 template <>
252 struct intrin_mfma_f32_32x32x16bf16<32, 32>
253 {
254  template <class FloatC>
255  __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
256  {
257 #if defined(__gfx950__)
258  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
259  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
260 #else
261  ignore = reg_a;
262  ignore = reg_b;
263  ignore = reg_c;
264 #endif // defined(__gfx950__)
265  }
266 };
267 
268 template <index_t MPerWave, index_t NPerWave>
269 struct intrin_mfma_f32_16x16x32bf16;
270 
271 template <>
272 struct intrin_mfma_f32_16x16x32bf16<16, 16>
273 {
274  template <class FloatC>
275  __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
276  {
277 #if defined(__gfx950__)
278  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
279  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
280 #else
281  ignore = reg_a;
282  ignore = reg_b;
283  ignore = reg_c;
284 #endif // defined(__gfx950__)
285  }
286 };
287 
288 template <index_t MPerWave, index_t NPerWave>
289 struct intrin_mfma_f32_32x32x8bf16_1k;
290 
291 template <>
292 struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
293 {
294  template <class FloatC>
295  __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
296  {
297  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
298  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
299  }
300 };
301 
302 template <index_t MPerWave, index_t NPerWave>
303 struct intrin_mfma_f32_16x16x16bf16_1k;
304 
305 template <>
306 struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
307 {
308  template <class FloatC>
309  __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
310  {
311  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
312  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
313  }
314 };
315 
316 template <index_t MPerWave, index_t NPerWave>
317 struct intrin_mfma_f32_32x32x4bf16;
318 
319 template <>
320 struct intrin_mfma_f32_32x32x4bf16<32, 32>
321 {
322  template <class FloatC>
323  __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
324  {
325  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
326  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
327  }
328 };
329 
330 template <index_t MPerWave, index_t NPerWave>
331 struct intrin_mfma_f32_16x16x8bf16;
332 
333 template <>
334 struct intrin_mfma_f32_16x16x8bf16<16, 16>
335 {
336  template <class FloatC>
337  __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
338  {
339  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
340  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
341  }
342 };
343 
344 template <index_t MPerWave, index_t NPerWave>
345 struct intrin_mfma_i32_32x32x8i8;
346 
347 template <>
348 struct intrin_mfma_i32_32x32x8i8<32, 32>
349 {
350  template <class FloatC>
351  __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
352  {
353  reg_c.template AsType<int32x16_t>()(Number<0>{}) =
354  __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int32_t>(reg_a),
355  bit_cast<int32_t>(reg_b),
356  reg_c.template AsType<int32x16_t>()[Number<0>{}],
357  0,
358  0,
359  0);
360  }
361 };
362 
363 template <index_t MPerWave, index_t NPerWave>
364 struct intrin_mfma_i32_16x16x16i8;
365 
366 template <>
367 struct intrin_mfma_i32_16x16x16i8<16, 16>
368 {
369  template <class FloatC>
370  __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
371  {
372  reg_c.template AsType<int32x4_t>()(Number<0>{}) =
373  __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a),
374  bit_cast<int32_t>(reg_b),
375  reg_c.template AsType<int32x4_t>()[Number<0>{}],
376  0,
377  0,
378  0);
379  }
380 };
381 
382 template <index_t MPerWave, index_t NPerWave>
383 struct intrin_mfma_i32_32x32x32i8;
384 
385 template <>
386 struct intrin_mfma_i32_32x32x32i8<32, 32>
387 {
388  template <class FloatC>
389  __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
390  {
391 #if defined(__gfx950__)
392  reg_c.template AsType<int32x16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
393  reg_a, reg_b, reg_c.template AsType<int32x16_t>()[Number<0>{}], 0, 0, 0);
394 #else
395  ignore = reg_a;
396  ignore = reg_b;
397  ignore = reg_c;
398 #endif // defined(__gfx950__)
399  }
400 };
401 
402 template <index_t MPerWave, index_t NPerWave>
403 struct intrin_mfma_i32_16x16x64i8;
404 
405 template <>
406 struct intrin_mfma_i32_16x16x64i8<16, 16>
407 {
408  template <class FloatC>
409  __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
410  {
411 #if defined(__gfx950__)
412  reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
413  reg_a, reg_b, reg_c.template AsType<int32x4_t>()[Number<0>{}], 0, 0, 0);
414 #else
415  ignore = reg_a;
416  ignore = reg_b;
417  ignore = reg_c;
418 #endif // defined(__gfx950__)
419  }
420 };
421 
422 template <index_t MPerWave, index_t NPerWave>
423 struct intrin_mfma_i32_32x32x16i8;
424 
425 template <>
426 struct intrin_mfma_i32_32x32x16i8<32, 32>
427 {
428  template <class FloatC>
429  __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
430  {
431  reg_c.template AsType<int32x16_t>()(Number<0>{}) =
432  __builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast<int64_t>(reg_a),
433  bit_cast<int64_t>(reg_b),
434  reg_c.template AsType<int32x16_t>()[Number<0>{}],
435  0,
436  0,
437  0);
438  }
439 };
440 
441 template <index_t MPerWave, index_t NPerWave>
442 struct intrin_mfma_i32_16x16x32i8;
443 
444 template <>
445 struct intrin_mfma_i32_16x16x32i8<16, 16>
446 {
447  template <class FloatC>
448  __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
449  {
450  reg_c.template AsType<int32x4_t>()(Number<0>{}) =
451  __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
452  bit_cast<int64_t>(reg_b),
453  reg_c.template AsType<int32x4_t>()[Number<0>{}],
454  0,
455  0,
456  0);
457  }
458 };
459 
460 template <index_t MPerWave, index_t NPerWave>
461 struct intrin_mfma_f64_16x16x4f64;
462 
463 template <>
464 struct intrin_mfma_f64_16x16x4f64<16, 16>
465 {
466  template <class FloatC>
467  __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
468  {
469 #if defined(__gfx90a__) || defined(__gfx94__)
470  reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
471  reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
472 #else
473  ignore = reg_a;
474  ignore = reg_b;
475  ignore = reg_c;
476 #endif
477  }
478 };
479 
480 template <index_t MPerWave, index_t NPerWave>
481 struct intrin_mfma_f32_32x32x64f8f6f4;
482 
489 template <>
490 struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
491 {
492  template <class FloatC>
493  __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
494  {
495 #if defined(__gfx950__)
496  reg_c.template AsType<float16_t>()(Number<0>{}) =
497  __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
498  reg_a,
499  reg_b,
500  reg_c.template AsType<float16_t>()[Number<0>{}],
501  0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
502  0, // blgp
503  0,
504  0,
505  0,
506  0);
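 // cbsz encodes the A-operand format and blgp the B-operand format (same encoding);
 // the mixed f8/bf8 overloads below differ only in these two arguments.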
507 #else
508  ignore = reg_a;
509  ignore = reg_b;
510  ignore = reg_c;
511 #endif
512  }
513 
514  template <class FloatC>
515  __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
516  {
517 #if defined(__gfx950__)
518  reg_c.template AsType<float16_t>()(Number<0>{}) =
519  __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
520  reg_a,
521  reg_b,
522  reg_c.template AsType<float16_t>()[Number<0>{}],
523  1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
524  1, // blgp
525  0,
526  0,
527  0,
528  0);
529 #else
530  ignore = reg_a;
531  ignore = reg_b;
532  ignore = reg_c;
533 #endif
534  }
535 
536  template <class FloatC>
537  __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
538  {
539 #if defined(__gfx950__)
540  reg_c.template AsType<float16_t>()(Number<0>{}) =
541  __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
542  reg_a,
543  reg_b,
544  reg_c.template AsType<float16_t>()[Number<0>{}],
545  1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
546  0, // blgp
547  0,
548  0,
549  0,
550  0);
551 #else
552  ignore = reg_a;
553  ignore = reg_b;
554  ignore = reg_c;
555 #endif
556  }
557 
558  template <class FloatC>
559  __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
560  {
561 #if defined(__gfx950__)
562  reg_c.template AsType<float16_t>()(Number<0>{}) =
563  __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
564  reg_a,
565  reg_b,
566  reg_c.template AsType<float16_t>()[Number<0>{}],
567  0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
568  1, // blgp
569  0,
570  0,
571  0,
572  0);
573 #else
574  ignore = reg_a;
575  ignore = reg_b;
576  ignore = reg_c;
577 #endif
578  }
579 
580  template <class FloatC>
581  __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
582  {
583 #if defined(__gfx950__)
584 
585  int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
586  int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
587 
588  using arg_type = int32x8_t;
589 
590  reg_c.template AsType<float16_t>()(Number<0>{}) =
591  __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
592  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
593  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
594  reg_c.template AsType<float16_t>()[Number<0>{}],
595  4, // cbsz
596  4, // blgp
597  0, // OPSEL
598  0,
599  0, // OPSEL
600  0);
601 #else
602  ignore = reg_a;
603  ignore = reg_b;
604  ignore = reg_c;
605 #endif
606  }
607 
608  template <class FloatC>
609  __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
610  {
611 #if defined(__gfx950__)
612 
613  int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
614  int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
615 
616  using arg_type = int32x8_t;
617 
618  reg_c.template AsType<float16_t>()(Number<0>{}) =
619  __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
620  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
621  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
622  reg_c.template AsType<float16_t>()[Number<0>{}],
623  2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
624  2, // blgp
625  0, // OPSEL
626  0,
627  0, // OPSEL
628  0);
629 #else
630  ignore = reg_a;
631  ignore = reg_b;
632  ignore = reg_c;
633 #endif
634  }
635 
636  template <class FloatC>
637  __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
638  {
639 #if defined(__gfx950__)
640 
641  int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
642  int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
643 
644  using arg_type = int32x8_t;
645 
646  reg_c.template AsType<float16_t>()(Number<0>{}) =
647  __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
648  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
649  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
650  reg_c.template AsType<float16_t>()[Number<0>{}],
651  3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
652  3, // blgp
653  0, // OPSEL
654  0,
655  0, // OPSEL
656  0);
657 #else
658  ignore = reg_a;
659  ignore = reg_b;
660  ignore = reg_c;
661 #endif
662  }
663 };
664 
665 template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
666 struct intrin_mfma_scale_f32_32x32x64f8f6f4;
667 
668 template <index_t OpselA, index_t OpselB>
669 struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32, OpselA, OpselB>
670 {
671  template <class FloatC>
672  __device__ static void Run(const f8x32_t& reg_a,
673  const int32_t& scale_a,
674  const f8x32_t& reg_b,
675  const int32_t& scale_b,
676  FloatC& reg_c)
677  {
678 #if defined(__gfx950__)
679  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
680  reg_c.template AsType<float16_t>()(Number<0>{}) =
681  __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
682  reg_a,
683  reg_b,
684  reg_c.template AsType<float16_t>()[Number<0>{}],
685  0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
686  0, // blgp
687  OpselA, // OPSEL
688  scale_a,
689  OpselB, // OPSEL
690  scale_b);
691  // XXX: Note on the scale_a and scale_b parameters:
692  // If the compiler detects that one or both scales are compile-time constants, it treats
693  // the constant as an F32 constant. For example, if scale_a was declared as
694  // `e8m0_bexp_t a_scale{1.0f}`, the instruction only works if the scale_a parameter is
695  // assigned the value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
696 
697  // XXX: Note on the OPSEL parameters: the instruction always takes byte 0 as the scale
698  // value, even when OPSEL selects a different byte.
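 // For example: to pass a unit scale, declare `e8m0_bexp_t a_scale{1.0f};` and supply
 // `bit_cast<int32_t>(static_cast<float>(a_scale))` as the scale_a argument.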
699 #else
700  ignore = reg_a;
701  ignore = scale_a;
702  ignore = reg_b;
703  ignore = scale_b;
704  ignore = reg_c;
705 #endif
706  }
707 
708  template <class FloatC>
709  __device__ static void Run(const bf8x32_t& reg_a,
710  const int32_t& scale_a,
711  const bf8x32_t& reg_b,
712  const int32_t& scale_b,
713  FloatC& reg_c)
714  {
715 #if defined(__gfx950__)
716  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
717  reg_c.template AsType<float16_t>()(Number<0>{}) =
718  __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
719  reg_a,
720  reg_b,
721  reg_c.template AsType<float16_t>()[Number<0>{}],
722  1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
723  1, // blgp
724  OpselA, // OPSEL
725  scale_a,
726  OpselB, // OPSEL
727  scale_b);
728  // XXX: Note on the scale_a and scale_b parameters:
729  // If the compiler detects that one or both scales are compile-time constants, it treats
730  // the constant as an F32 constant. For example, if scale_a was declared as
731  // `e8m0_bexp_t a_scale{1.0f}`, the instruction only works if the scale_a parameter is
732  // assigned the value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
733 
734  // XXX: Note on the OPSEL parameters: the instruction always takes byte 0 as the scale
735  // value, even when OPSEL selects a different byte.
736 #else
737  ignore = reg_a;
738  ignore = scale_a;
739  ignore = reg_b;
740  ignore = scale_b;
741  ignore = reg_c;
742 #endif
743  }
744 
745  template <class FloatC>
746  __device__ static void Run(const bf8x32_t& reg_a,
747  const int32_t& scale_a,
748  const f8x32_t& reg_b,
749  const int32_t& scale_b,
750  FloatC& reg_c)
751  {
752 #if defined(__gfx950__)
753  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
754  reg_c.template AsType<float16_t>()(Number<0>{}) =
755  __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
756  reg_a,
757  reg_b,
758  reg_c.template AsType<float16_t>()[Number<0>{}],
759  1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
760  0, // blgp
761  OpselA, // OPSEL
762  scale_a,
763  OpselB, // OPSEL
764  scale_b);
765  // XXX: Note on the scale_a and scale_b parameters:
766  // If the compiler detects that one or both scales are compile-time constants, it treats
767  // the constant as an F32 constant. For example, if scale_a was declared as
768  // `e8m0_bexp_t a_scale{1.0f}`, the instruction only works if the scale_a parameter is
769  // assigned the value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
770 
771  // XXX: Note on the OPSEL parameters: the instruction always takes byte 0 as the scale
772  // value, even when OPSEL selects a different byte.
773 #else
774  ignore = reg_a;
775  ignore = scale_a;
776  ignore = reg_b;
777  ignore = scale_b;
778  ignore = reg_c;
779 #endif
780  }
781 
782  template <class FloatC>
783  __device__ static void Run(const f6x32_t& reg_a,
784  const int32_t scale_a,
785  const f6x32_t& reg_b,
786  const int32_t scale_b,
787  FloatC& reg_c)
788  {
789 #if defined(__gfx950__)
790 
791  int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
792  int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
793 
794  using arg_type = int32x8_t;
795 
796  reg_c.template AsType<float16_t>()(Number<0>{}) =
797  __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
798  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
799  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
800  reg_c.template AsType<float16_t>()[Number<0>{}],
801  2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
802  2, // blgp
803  OpselA, // OPSEL
804  scale_a,
805  OpselB, // OPSEL
806  scale_b);
807 #else
808  ignore = reg_a;
809  ignore = scale_a;
810  ignore = reg_b;
811  ignore = scale_b;
812  ignore = reg_c;
813 #endif
814  }
815 
816  template <class FloatC>
817  __device__ static void Run(const bf6x32_t& reg_a,
818  const int32_t scale_a,
819  const bf6x32_t& reg_b,
820  const int32_t scale_b,
821  FloatC& reg_c)
822  {
823 #if defined(__gfx950__)
824 
825  int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
826  int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
827 
828  using arg_type = int32x8_t;
829 
830  reg_c.template AsType<float16_t>()(Number<0>{}) =
831  __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
832  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
833  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
834  reg_c.template AsType<float16_t>()[Number<0>{}],
835  3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
836  3, // blgp
837  OpselA, // OPSEL
838  scale_a,
839  OpselB, // OPSEL
840  scale_b);
841 #else
842  ignore = reg_a;
843  ignore = scale_a;
844  ignore = reg_b;
845  ignore = scale_b;
846  ignore = reg_c;
847 #endif
848  }
849 
850  template <class FloatC>
851  __device__ static void Run(const f4x32_t& reg_a,
852  const int32_t scale_a,
853  const f4x32_t& reg_b,
854  const int32_t scale_b,
855  FloatC& reg_c)
856  {
857 #if defined(__gfx950__)
858 
859  int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
860  int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
861 
862  using arg_type = int32x8_t;
863 
864  reg_c.template AsType<float16_t>()(Number<0>{}) =
865  __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
866  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
867  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
868  reg_c.template AsType<float16_t>()[Number<0>{}],
869  4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
870  4, // blgp
871  OpselA, // OPSEL
872  scale_a,
873  OpselB, // OPSEL
874  scale_b);
875 #else
876  ignore = reg_a;
877  ignore = scale_a;
878  ignore = reg_b;
879  ignore = scale_b;
880  ignore = reg_c;
881 #endif
882  }
883 };
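// Usage sketch (illustrative; assumes the caller keeps the accumulator in a
// vector_type<float, 16>, which provides the AsType<float16_t> view that Run() above
// writes through, and that a/b are packed f8x32_t fragments with e8m0 scales already
// bit_cast to int32_t as described in the notes above):
//
//   vector_type<float, 16> acc{};
//   intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32, 0, 0>::Run(a, scale_a, b, scale_b, acc);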
884 
885 template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
886 struct intrin_mfma_scale_f32_16x16x128f8f6f4;
887 
888 template <index_t OpselA, index_t OpselB>
889 struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
890 {
891  template <class FloatC>
892  __device__ static void Run(const f8x32_t& reg_a,
893  const int32_t& scale_a,
894  const f8x32_t& reg_b,
895  const int32_t& scale_b,
896  FloatC& reg_c)
897  {
898 #if defined(__gfx950__)
899  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
900  reg_c.template AsType<float4_t>()(Number<0>{}) =
901  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
902  reg_a,
903  reg_b,
904  reg_c.template AsType<float4_t>()[Number<0>{}],
905  0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
906  0, // blgp
907  OpselA, // OPSEL
908  scale_a,
909  OpselB, // OPSEL
910  scale_b);
911 #else
912  ignore = reg_a;
913  ignore = scale_a;
914  ignore = reg_b;
915  ignore = scale_b;
916  ignore = reg_c;
917 #endif
918  }
919 
920  template <class FloatC>
921  __device__ static void Run(const bf8x32_t& reg_a,
922  const int32_t& scale_a,
923  const bf8x32_t& reg_b,
924  const int32_t& scale_b,
925  FloatC& reg_c)
926  {
927 #if defined(__gfx950__)
928  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
929  reg_c.template AsType<float4_t>()(Number<0>{}) =
930  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
931  reg_a,
932  reg_b,
933  reg_c.template AsType<float4_t>()[Number<0>{}],
934  1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
935  1, // blgp
936  OpselA, // OPSEL
937  scale_a,
938  OpselB, // OPSEL
939  scale_b);
940 #else
941  ignore = reg_a;
942  ignore = scale_a;
943  ignore = reg_b;
944  ignore = scale_b;
945  ignore = reg_c;
946 #endif
947  }
948 
949  template <class FloatC>
950  __device__ static void Run(const f8x32_t& reg_a,
951  const int32_t& scale_a,
952  const bf8x32_t& reg_b,
953  const int32_t& scale_b,
954  FloatC& reg_c)
955  {
956 #if defined(__gfx950__)
957  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
958  reg_c.template AsType<float4_t>()(Number<0>{}) =
959  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
960  reg_a,
961  reg_b,
962  reg_c.template AsType<float4_t>()[Number<0>{}],
963  0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
964  1, // blgp
965  OpselA, // OPSEL
966  scale_a,
967  OpselB, // OPSEL
968  scale_b);
969 #else
970  ignore = reg_a;
971  ignore = scale_a;
972  ignore = reg_b;
973  ignore = scale_b;
974  ignore = reg_c;
975 #endif
976  }
977 
978  template <class FloatC>
979  __device__ static void Run(const bf8x32_t& reg_a,
980  const int32_t& scale_a,
981  const f8x32_t& reg_b,
982  const int32_t& scale_b,
983  FloatC& reg_c)
984  {
985 #if defined(__gfx950__)
986  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
987  reg_c.template AsType<float4_t>()(Number<0>{}) =
988  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
989  reg_a,
990  reg_b,
991  reg_c.template AsType<float4_t>()[Number<0>{}],
992  1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
993  0, // blgp
994  OpselA, // OPSEL
995  scale_a,
996  OpselB, // OPSEL
997  scale_b);
998 #else
999  ignore = reg_a;
1000  ignore = scale_a;
1001  ignore = reg_b;
1002  ignore = scale_b;
1003  ignore = reg_c;
1004 #endif
1005  }
1006 
1007  template <class FloatC>
1008  __device__ static void Run(const f6x32_t& reg_a,
1009  const int32_t scale_a,
1010  const f6x32_t& reg_b,
1011  const int32_t scale_b,
1012  FloatC& reg_c)
1013  {
1014 #if defined(__gfx950__)
1015  int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1016  int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1017 
1018  using arg_type = int32x8_t;
1019 
1020  reg_c.template AsType<float4_t>()(Number<0>{}) =
1021  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1022  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1023  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1024  reg_c.template AsType<float4_t>()[Number<0>{}],
1025  2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1026  2, // blgp
1027  OpselA, // OPSEL
1028  scale_a,
1029  OpselB, // OPSEL
1030  scale_b);
1031 #else
1032  ignore = reg_a;
1033  ignore = scale_a;
1034  ignore = reg_b;
1035  ignore = scale_b;
1036  ignore = reg_c;
1037 #endif
1038  }
1039 
1040  template <class FloatC>
1041  __device__ static void Run(const f6x16x2_t& reg_a,
1042  const int32_t scale_a,
1043  const f6x16x2_t& reg_b,
1044  const int32_t scale_b,
1045  FloatC& reg_c)
1046  {
1047 #if defined(__gfx950__)
1048  using arg_type = int32x8_t;
1049  arg_type arg_a{
1050  static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<0>{}][0]),
1051  static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<0>{}][1]),
1052  static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<0>{}][2]),
1053  static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<1>{}][0]),
1054  static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<1>{}][1]),
1055  static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<1>{}][2]),
1056  0,
1057  0};
1058  arg_type arg_b{
1059  static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<0>{}][0]),
1060  static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<0>{}][1]),
1061  static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<0>{}][2]),
1062  static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<1>{}][0]),
1063  static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<1>{}][1]),
1064  static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<1>{}][2]),
1065  0,
1066  0};
1067 
1068  reg_c.template AsType<float4_t>()(Number<0>{}) =
1069  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1070  arg_a,
1071  arg_b,
1072  reg_c.template AsType<float4_t>()[Number<0>{}],
1073  2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1074  2, // blgp
1075  OpselA, // OPSEL
1076  scale_a,
1077  OpselB, // OPSEL
1078  scale_b);
1079 #else
1080  ignore = reg_a;
1081  ignore = scale_a;
1082  ignore = reg_b;
1083  ignore = scale_b;
1084  ignore = reg_c;
1085 #endif
1086  }
1087 
1088  template <class FloatC>
1089  __device__ static void Run(const bf6x32_t& reg_a,
1090  const int32_t scale_a,
1091  const bf6x32_t& reg_b,
1092  const int32_t scale_b,
1093  FloatC& reg_c)
1094  {
1095 #if defined(__gfx950__)
1096  int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1097  int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1098 
1099  using arg_type = int32x8_t;
1100 
1101  reg_c.template AsType<float4_t>()(Number<0>{}) =
1102  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1103  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1104  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1105  reg_c.template AsType<float4_t>()[Number<0>{}],
1106  3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1107  3, // blgp
1108  OpselA, // OPSEL
1109  scale_a,
1110  OpselB, // OPSEL
1111  scale_b);
1112 #else
1113  ignore = reg_a;
1114  ignore = scale_a;
1115  ignore = reg_b;
1116  ignore = scale_b;
1117  ignore = reg_c;
1118 #endif
1119  }
1120 
1121  template <class FloatC>
1122  __device__ static void Run(const bf6x16x2_t& reg_a,
1123  const int32_t scale_a,
1124  const bf6x16x2_t& reg_b,
1125  const int32_t scale_b,
1126  FloatC& reg_c)
1127  {
1128 #if defined(__gfx950__)
1129  using arg_type = int32x8_t;
1130  arg_type arg_a{
1131  static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][0]),
1132  static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][1]),
1133  static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][2]),
1134  static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][0]),
1135  static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][1]),
1136  static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][2]),
1137  0,
1138  0};
1139  arg_type arg_b{
1140  static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][0]),
1141  static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][1]),
1142  static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][2]),
1143  static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][0]),
1144  static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][1]),
1145  static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][2]),
1146  0,
1147  0};
1148 
1149  reg_c.template AsType<float4_t>()(Number<0>{}) =
1150  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1151  arg_a,
1152  arg_b,
1153  reg_c.template AsType<float4_t>()[Number<0>{}],
1154  3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1155  3, // blgp
1156  OpselA, // OPSEL
1157  scale_a,
1158  OpselB, // OPSEL
1159  scale_b);
1160 #else
1161  ignore = reg_a;
1162  ignore = scale_a;
1163  ignore = reg_b;
1164  ignore = scale_b;
1165  ignore = reg_c;
1166 #endif
1167  }
1168 
1169  template <class FloatC>
1170  __device__ static void Run(const f4x32_t& reg_a,
1171  const int32_t scale_a,
1172  const f4x32_t& reg_b,
1173  const int32_t scale_b,
1174  FloatC& reg_c)
1175  {
1176 #if defined(__gfx950__)
1177  int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
1178  int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
1179  using arg_type = int32x8_t;
1180  reg_c.template AsType<float4_t>()(Number<0>{}) =
1181  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1182  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1183  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1184  reg_c.template AsType<float4_t>()[Number<0>{}],
1185  4, // cbsz
1186  4, // blgp
1187  OpselA, // OPSEL
1188  scale_a,
1189  OpselB, // OPSEL
1190  scale_b);
1191 #else
1192  ignore = reg_a;
1193  ignore = scale_a;
1194  ignore = reg_b;
1195  ignore = scale_b;
1196  ignore = reg_c;
1197 #endif
1198  }
1199 };
1200 
1201 template <index_t MPerWave, index_t NPerWave>
1202 struct intrin_mfma_f32_16x16x128f8f6f4;
1203 
1210 template <>
1211 struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
1212 {
1213  template <class FloatC>
1214  __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
1215  {
1216 #if defined(__gfx950__)
1217  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
1218  reg_c.template AsType<float4_t>()(Number<0>{}) =
1219  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1220  reg_a,
1221  reg_b,
1222  reg_c.template AsType<float4_t>()[Number<0>{}],
1223  0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1224  0, // blgp
1225  0,
1226  0,
1227  0,
1228  0);
1229 #else
1230  ignore = reg_a;
1231  ignore = reg_b;
1232  ignore = reg_c;
1233 #endif
1234  }
1235 
1236  template <class FloatC>
1237  __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
1238  {
1239 #if defined(__gfx950__)
1240  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
1241  reg_c.template AsType<float4_t>()(Number<0>{}) =
1242  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1243  reg_a,
1244  reg_b,
1245  reg_c.template AsType<float4_t>()[Number<0>{}],
1246  1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1247  1, // blgp
1248  0,
1249  0,
1250  0,
1251  0);
1252 #else
1253  ignore = reg_a;
1254  ignore = reg_b;
1255  ignore = reg_c;
1256 #endif
1257  }
1258 
1259  template <class FloatC>
1260  __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
1261  {
1262 #if defined(__gfx950__)
1263  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
1264  reg_c.template AsType<float4_t>()(Number<0>{}) =
1265  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1266  reg_a,
1267  reg_b,
1268  reg_c.template AsType<float4_t>()[Number<0>{}],
1269  1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1270  0, // blgp
1271  0,
1272  0,
1273  0,
1274  0);
1275 #else
1276  ignore = reg_a;
1277  ignore = reg_b;
1278  ignore = reg_c;
1279 #endif
1280  }
1281 
1282  template <class FloatC>
1283  __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
1284  {
1285 #if defined(__gfx950__)
1286  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
1287  reg_c.template AsType<float4_t>()(Number<0>{}) =
1288  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1289  reg_a,
1290  reg_b,
1291  reg_c.template AsType<float4_t>()[Number<0>{}],
1292  0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1293  1, // blgp
1294  0,
1295  0,
1296  0,
1297  0);
1298 #else
1299  ignore = reg_a;
1300  ignore = reg_b;
1301  ignore = reg_c;
1302 #endif
1303  }
1304 
1305  template <class FloatC>
1306  __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
1307  {
1308 #if defined(__gfx950__)
1309  int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
1310  int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
1311 
1312  using arg_type = int32x8_t;
1313 
1314  reg_c.template AsType<float4_t>()(Number<0>{}) =
1315  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1316  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1317  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1318  reg_c.template AsType<float4_t>()[Number<0>{}],
1319  4, // cbsz
1320  4, // blgp
1321  0, // OPSEL
1322  0,
1323  0, // OPSEL
1324  0);
1325 #else
1326  ignore = reg_a;
1327  ignore = reg_b;
1328  ignore = reg_c;
1329 #endif
1330  }
1331 
1332  template <class FloatC>
1333  __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
1334  {
1335 #if defined(__gfx950__)
1336  int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1337  int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1338 
1339  using arg_type = int32x8_t;
1340 
1341  reg_c.template AsType<float4_t>()(Number<0>{}) =
1342  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1343  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1344  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1345  reg_c.template AsType<float4_t>()[Number<0>{}],
1346  2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1347  2, // blgp
1348  0, // OPSEL
1349  0,
1350  0, // OPSEL
1351  0);
1352 #else
1353  ignore = reg_a;
1354  ignore = reg_b;
1355  ignore = reg_c;
1356 #endif
1357  }
1358 
1359  template <class FloatC>
1360  __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
1361  {
1362 #if defined(__gfx950__)
1363  int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1364  int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1365 
1366  using arg_type = int32x8_t;
1367 
1368  reg_c.template AsType<float4_t>()(Number<0>{}) =
1369  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1370  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1371  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1372  reg_c.template AsType<float4_t>()[Number<0>{}],
1373  3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1374  3, // blgp
1375  0, // OPSEL
1376  0,
1377  0, // OPSEL
1378  0);
1379 #else
1380  ignore = reg_a;
1381  ignore = reg_b;
1382  ignore = reg_c;
1383 #endif
1384  }
1385 };
1386 
1387 template <index_t MPerWave, index_t NPerWave>
1388 struct intrin_mfma_f32_32x32x16f8f8;
1389 
1390 template <>
1391 struct intrin_mfma_f32_32x32x16f8f8<32, 32>
1392 {
1393  template <class FloatC>
1394  __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
1395  {
1396 #if defined(__gfx94__)
1397  reg_c.template AsType<float16_t>()(Number<0>{}) =
1398  __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1399  bit_cast<int64_t>(reg_a),
1400  bit_cast<int64_t>(reg_b),
1401  reg_c.template AsType<float16_t>()[Number<0>{}],
1402  0,
1403  0,
1404  0);
1405 #else
1406  vector_type<f8_t, 8> reg_a_v(reg_a);
1407  vector_type<f8_t, 8> reg_b_v(reg_b);
1408 
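 // Fallback for targets without the fp8 MFMA builtin: convert each of the eight packed
 // f8 values to float and accumulate them with the f32 32x32x2 MFMA wrapper above.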
1409  static_for<0, 8, 1>{}([&](auto k) {
1410  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
1411  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
1412 
1413  intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
1414  });
1415 #endif
1416  }
1417 };
1418 
1419 template <index_t MPerWave, index_t NPerWave>
1420 struct intrin_mfma_f32_16x16x32f8f8;
1421 
1422 template <>
1423 struct intrin_mfma_f32_16x16x32f8f8<16, 16>
1424 {
1425  template <class FloatC>
1426  __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
1427  {
1428 #if defined(__gfx94__)
1429  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1430  bit_cast<int64_t>(reg_a),
1431  bit_cast<int64_t>(reg_b),
1432  reg_c.template AsType<float4_t>()[Number<0>{}],
1433  0,
1434  0,
1435  0);
1436 #else
1437  vector_type<f8_t, 8> reg_a_v(reg_a);
1438  vector_type<f8_t, 8> reg_b_v(reg_b);
1439 
1440  static_for<0, 8, 1>{}([&](auto k) {
1441  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
1442  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
1443 
1444  intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
1445  });
1446 #endif
1447  }
1448 };
1449 
1450 template <index_t MPerWave, index_t NPerWave>
1451 struct intrin_mfma_f32_32x32x16bf8bf8;
1452 
1453 template <>
1454 struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
1455 {
1456  template <class FloatC>
1457  __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
1458  {
1459 #if defined(__gfx94__)
1460  reg_c.template AsType<float16_t>()(Number<0>{}) =
1461  __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1462  bit_cast<int64_t>(reg_a),
1463  bit_cast<int64_t>(reg_b),
1464  reg_c.template AsType<float16_t>()[Number<0>{}],
1465  0,
1466  0,
1467  0);
1468 #else
1469  vector_type<bf8_t, 8> reg_a_v(reg_a);
1470  vector_type<bf8_t, 8> reg_b_v(reg_b);
1471 
1472  static_for<0, 8, 1>{}([&](auto k) {
1473  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
1474  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
1475 
1476  intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
1477  });
1478 #endif
1479  }
1480 };
1481 
1482 template <index_t MPerWave, index_t NPerWave>
1483 struct intrin_mfma_f32_16x16x32bf8bf8;
1484 
1485 template <>
1486 struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
1487 {
1488  template <class FloatC>
1489  __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
1490  {
1491 #if defined(__gfx94__)
1492  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1493  bit_cast<int64_t>(reg_a),
1494  bit_cast<int64_t>(reg_b),
1495  reg_c.template AsType<float4_t>()[Number<0>{}],
1496  0,
1497  0,
1498  0);
1499 #else
1500  vector_type<bf8_t, 8> reg_a_v(reg_a);
1501  vector_type<bf8_t, 8> reg_b_v(reg_b);
1502 
1503  static_for<0, 8, 1>{}([&](auto k) {
1504  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
1505  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
1506 
1507  intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
1508  });
1509 #endif
1510  }
1511 };
1512 
1513 template <index_t MPerWave, index_t NPerWave>
1514 struct intrin_mfma_f32_32x32x16f8bf8;
1515 
1516 template <>
1517 struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
1518 {
1519  template <class FloatC>
1520  __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
1521  {
1522 #if defined(__gfx94__)
1523  reg_c.template AsType<float16_t>()(Number<0>{}) =
1524  __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1525  bit_cast<int64_t>(reg_a),
1526  bit_cast<int64_t>(reg_b),
1527  reg_c.template AsType<float16_t>()[Number<0>{}],
1528  0,
1529  0,
1530  0);
1531 #else
1532  vector_type<f8_t, 8> reg_a_v(reg_a);
1533  vector_type<bf8_t, 8> reg_b_v(reg_b);
1534 
1535  static_for<0, 8, 1>{}([&](auto k) {
1536  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
1537  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
1538 
1539  intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
1540  });
1541 #endif
1542  }
1543 };
1544 
1545 template <index_t MPerWave, index_t NPerWave>
1546 struct intrin_mfma_f32_16x16x32f8bf8;
1547 
1548 template <>
1549 struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
1550 {
1551  template <class FloatC>
1552  __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
1553  {
1554 #if defined(__gfx94__)
1555  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1556  bit_cast<int64_t>(reg_a),
1557  bit_cast<int64_t>(reg_b),
1558  reg_c.template AsType<float4_t>()[Number<0>{}],
1559  0,
1560  0,
1561  0);
1562 #else
1563  vector_type<f8_t, 8> reg_a_v(reg_a);
1564  vector_type<bf8_t, 8> reg_b_v(reg_b);
1565 
1566  static_for<0, 8, 1>{}([&](auto k) {
1567  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
1568  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
1569 
1570  intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
1571  });
1572 #endif
1573  }
1574 };
1575 
1576 template <index_t MPerWave, index_t NPerWave>
1577 struct intrin_mfma_f32_32x32x16bf8f8;
1578 
1579 template <>
1580 struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
1581 {
1582  template <class FloatC>
1583  __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
1584  {
1585 #if defined(__gfx94__)
1586  reg_c.template AsType<float16_t>()(Number<0>{}) =
1587  __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1588  bit_cast<int64_t>(reg_a),
1589  bit_cast<int64_t>(reg_b),
1590  reg_c.template AsType<float16_t>()[Number<0>{}],
1591  0,
1592  0,
1593  0);
1594 #else
1595  vector_type<bf8_t, 8> reg_a_v(reg_a);
1596  vector_type<f8_t, 8> reg_b_v(reg_b);
1597 
1598  static_for<0, 8, 1>{}([&](auto k) {
1599  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
1600  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
1601 
1602  intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
1603  });
1604 #endif
1605  }
1606 };
1607 
1608 template <index_t MPerWave, index_t NPerWave>
1609 struct intrin_mfma_f32_16x16x32bf8f8;
1610 
1611 template <>
1612 struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
1613 {
1614  template <class FloatC>
1615  __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
1616  {
1617 #if defined(__gfx94__)
1618  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1619  bit_cast<int64_t>(reg_a),
1620  bit_cast<int64_t>(reg_b),
1621  reg_c.template AsType<float4_t>()[Number<0>{}],
1622  0,
1623  0,
1624  0);
1625 #else
1626  vector_type<bf8_t, 8> reg_a_v(reg_a);
1627  vector_type<f8_t, 8> reg_b_v(reg_b);
1628 
1629  static_for<0, 8, 1>{}([&](auto k) {
1630  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
1631  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
1632 
1633  intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
1634  });
1635 #endif
1636  }
1637 };
1638 
1639 } // namespace ck