9 #if defined(__gfx942__) || defined(__gfx950__)
14 template <index_t MPerWave, index_t NPerWave>
20 template <
class FloatC>
21 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
23 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
24 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
25 reg_c.template AsType<float32_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
26 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<1>{}], 1, 1, 0);
33 template <
class FloatC>
34 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
36 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
37 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
41 template <index_t MPerWave, index_t NPerWave>
47 template <
class FloatC>
48 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
50 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32(
51 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
55 template <index_t MPerWave, index_t NPerWave>
61 template <
class FloatC>
62 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
64 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32(
65 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
69 template <index_t MPerWave, index_t NPerWave>
75 template <
class FloatC>
76 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
78 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32(
79 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 2, 0, 0);
83 template <index_t MPerWave, index_t NPerWave>
89 template <
class FloatC>
90 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
92 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
93 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
100 template <
class FloatC>
101 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
103 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
104 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
105 reg_c.template AsType<float4_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
106 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<1>{}], 4, 1, 0);
111 template <index_t MPerWave, index_t NPerWave>
117 template <
class FloatC>
120 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
121 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
122 reg_c.template AsType<float32_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
123 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<1>{}], 1, 1, 0);
130 template <
class FloatC>
133 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
134 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
138 template <index_t MPerWave, index_t NPerWave>
144 template <
class FloatC>
147 #if defined(__gfx950__)
148 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
149 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
158 template <index_t MPerWave, index_t NPerWave>
164 template <
class FloatC>
167 #if defined(__gfx950__)
168 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
169 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
178 template <index_t MPerWave, index_t NPerWave>
184 template <
class FloatC>
187 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
188 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
192 template <index_t MPerWave, index_t NPerWave>
198 template <
class FloatC>
201 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
202 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
206 template <index_t MPerWave, index_t NPerWave>
212 template <
class FloatC>
215 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
216 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 2, 0, 0);
220 template <index_t MPerWave, index_t NPerWave>
226 template <
class FloatC>
229 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
230 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
237 template <
class FloatC>
240 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
241 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
242 reg_c.template AsType<float4_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
243 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<1>{}], 4, 1, 0);
248 template <index_t MPerWave, index_t NPerWave>
254 template <
class FloatC>
257 #if defined(__gfx950__)
258 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
259 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
268 template <index_t MPerWave, index_t NPerWave>
274 template <
class FloatC>
277 #if defined(__gfx950__)
278 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
279 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
288 template <index_t MPerWave, index_t NPerWave>
294 template <
class FloatC>
297 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
298 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
302 template <index_t MPerWave, index_t NPerWave>
308 template <
class FloatC>
311 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
312 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
316 template <index_t MPerWave, index_t NPerWave>
322 template <
class FloatC>
325 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
326 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
330 template <index_t MPerWave, index_t NPerWave>
336 template <
class FloatC>
339 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
340 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
344 template <index_t MPerWave, index_t NPerWave>
350 template <
class FloatC>
353 reg_c.template AsType<int32x16_t>()(
Number<0>{}) =
354 __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int32_t>(reg_a),
355 bit_cast<int32_t>(reg_b),
356 reg_c.template AsType<int32x16_t>()[
Number<0>{}],
363 template <index_t MPerWave, index_t NPerWave>
369 template <
class FloatC>
372 reg_c.template AsType<int32x4_t>()(
Number<0>{}) =
373 __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a),
374 bit_cast<int32_t>(reg_b),
375 reg_c.template AsType<int32x4_t>()[
Number<0>{}],
382 template <index_t MPerWave, index_t NPerWave>
388 template <
class FloatC>
391 #if defined(__gfx950__)
392 reg_c.template AsType<int32x16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
393 reg_a, reg_b, reg_c.template AsType<int32x16_t>()[
Number<0>{}], 0, 0, 0);
402 template <index_t MPerWave, index_t NPerWave>
408 template <
class FloatC>
411 #if defined(__gfx950__)
412 reg_c.template AsType<int32x4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
413 reg_a, reg_b, reg_c.template AsType<int32x4_t>()[
Number<0>{}], 0, 0, 0);
422 template <index_t MPerWave, index_t NPerWave>
428 template <
class FloatC>
431 reg_c.template AsType<int32x16_t>()(
Number<0>{}) =
432 __builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast<int64_t>(reg_a),
433 bit_cast<int64_t>(reg_b),
434 reg_c.template AsType<int32x16_t>()[
Number<0>{}],
441 template <index_t MPerWave, index_t NPerWave>
447 template <
class FloatC>
450 reg_c.template AsType<int32x4_t>()(
Number<0>{}) =
451 __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
452 bit_cast<int64_t>(reg_b),
453 reg_c.template AsType<int32x4_t>()[
Number<0>{}],
460 template <index_t MPerWave, index_t NPerWave>
466 template <
class FloatC>
467 __device__
static void Run(
const double& reg_a,
const double& reg_b, FloatC& reg_c)
469 #if defined(__gfx90a__) || defined(__gfx94__)
470 reg_c.template AsType<double4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
471 reg_a, reg_b, reg_c.template AsType<double4_t>()[
Number<0>{}], 0, 0, 0);
480 template <index_t MPerWave, index_t NPerWave>
492 template <
class FloatC>
493 __device__
static void Run(
const f8x32_t& reg_a,
const f8x32_t& reg_b, FloatC& reg_c)
495 #if defined(__gfx950__)
496 reg_c.template AsType<float16_t>()(
Number<0>{}) =
497 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
500 reg_c.template AsType<float16_t>()[
Number<0>{}],
514 template <
class FloatC>
517 #if defined(__gfx950__)
518 reg_c.template AsType<float16_t>()(
Number<0>{}) =
519 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
522 reg_c.template AsType<float16_t>()[
Number<0>{}],
536 template <
class FloatC>
537 __device__
static void Run(
const bf8x32_t& reg_a,
const f8x32_t& reg_b, FloatC& reg_c)
539 #if defined(__gfx950__)
540 reg_c.template AsType<float16_t>()(
Number<0>{}) =
541 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
544 reg_c.template AsType<float16_t>()[
Number<0>{}],
558 template <
class FloatC>
559 __device__
static void Run(
const f8x32_t& reg_a,
const bf8x32_t& reg_b, FloatC& reg_c)
561 #if defined(__gfx950__)
562 reg_c.template AsType<float16_t>()(
Number<0>{}) =
563 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
566 reg_c.template AsType<float16_t>()[
Number<0>{}],
580 template <
class FloatC>
583 #if defined(__gfx950__)
585 int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
586 int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
590 reg_c.template AsType<float16_t>()(
Number<0>{}) =
591 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
592 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
593 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
594 reg_c.template AsType<float16_t>()[
Number<0>{}],
608 template <
class FloatC>
611 #if defined(__gfx950__)
613 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
614 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
618 reg_c.template AsType<float16_t>()(
Number<0>{}) =
619 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
620 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
621 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
622 reg_c.template AsType<float16_t>()[
Number<0>{}],
636 template <
class FloatC>
639 #if defined(__gfx950__)
641 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
642 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
646 reg_c.template AsType<float16_t>()(
Number<0>{}) =
647 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
648 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
649 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
650 reg_c.template AsType<float16_t>()[
Number<0>{}],
665 template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
668 template <index_t OpselA, index_t OpselB>
671 template <
class FloatC>
672 __device__
static void Run(
const f8x32_t& reg_a,
674 const f8x32_t& reg_b,
678 #if defined(__gfx950__)
680 reg_c.template AsType<float16_t>()(
Number<0>{}) =
681 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
684 reg_c.template AsType<float16_t>()[
Number<0>{}],
708 template <
class FloatC>
715 #if defined(__gfx950__)
717 reg_c.template AsType<float16_t>()(
Number<0>{}) =
718 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
721 reg_c.template AsType<float16_t>()[
Number<0>{}],
745 template <
class FloatC>
748 const f8x32_t& reg_b,
752 #if defined(__gfx950__)
754 reg_c.template AsType<float16_t>()(
Number<0>{}) =
755 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
758 reg_c.template AsType<float16_t>()[
Number<0>{}],
782 template <
class FloatC>
789 #if defined(__gfx950__)
791 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
792 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
796 reg_c.template AsType<float16_t>()(
Number<0>{}) =
797 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
798 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
799 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
800 reg_c.template AsType<float16_t>()[
Number<0>{}],
816 template <
class FloatC>
823 #if defined(__gfx950__)
825 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
826 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
830 reg_c.template AsType<float16_t>()(
Number<0>{}) =
831 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
832 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
833 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
834 reg_c.template AsType<float16_t>()[
Number<0>{}],
850 template <
class FloatC>
857 #if defined(__gfx950__)
859 int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
860 int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
864 reg_c.template AsType<float16_t>()(
Number<0>{}) =
865 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
866 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
867 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
868 reg_c.template AsType<float16_t>()[
Number<0>{}],
885 template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
888 template <index_t OpselA, index_t OpselB>
891 template <
class FloatC>
892 __device__
static void Run(
const f8x32_t& reg_a,
894 const f8x32_t& reg_b,
898 #if defined(__gfx950__)
900 reg_c.template AsType<float4_t>()(
Number<0>{}) =
901 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
904 reg_c.template AsType<float4_t>()[
Number<0>{}],
920 template <
class FloatC>
927 #if defined(__gfx950__)
929 reg_c.template AsType<float4_t>()(
Number<0>{}) =
930 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
933 reg_c.template AsType<float4_t>()[
Number<0>{}],
949 template <
class FloatC>
950 __device__
static void Run(
const f8x32_t& reg_a,
956 #if defined(__gfx950__)
958 reg_c.template AsType<float4_t>()(
Number<0>{}) =
959 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
962 reg_c.template AsType<float4_t>()[
Number<0>{}],
978 template <
class FloatC>
981 const f8x32_t& reg_b,
985 #if defined(__gfx950__)
987 reg_c.template AsType<float4_t>()(
Number<0>{}) =
988 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
991 reg_c.template AsType<float4_t>()[
Number<0>{}],
1007 template <
class FloatC>
1014 #if defined(__gfx950__)
1015 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1016 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1020 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1021 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1022 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1023 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1024 reg_c.template AsType<float4_t>()[
Number<0>{}],
1040 template <
class FloatC>
1047 #if defined(__gfx950__)
1050 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][0]),
1051 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][1]),
1052 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][2]),
1053 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][0]),
1054 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][1]),
1055 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][2]),
1059 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][0]),
1060 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][1]),
1061 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][2]),
1062 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][0]),
1063 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][1]),
1064 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][2]),
1068 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1069 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1072 reg_c.template AsType<float4_t>()[
Number<0>{}],
1088 template <
class FloatC>
1095 #if defined(__gfx950__)
1096 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1097 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1101 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1102 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1103 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1104 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1105 reg_c.template AsType<float4_t>()[
Number<0>{}],
1121 template <
class FloatC>
1128 #if defined(__gfx950__)
1131 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][0]),
1132 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][1]),
1133 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][2]),
1134 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][0]),
1135 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][1]),
1136 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][2]),
1140 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][0]),
1141 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][1]),
1142 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][2]),
1143 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][0]),
1144 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][1]),
1145 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][2]),
1149 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1150 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1153 reg_c.template AsType<float4_t>()[
Number<0>{}],
1169 template <
class FloatC>
1176 #if defined(__gfx950__)
1177 int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
1178 int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
1180 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1181 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1182 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1183 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1184 reg_c.template AsType<float4_t>()[
Number<0>{}],
1201 template <index_t MPerWave, index_t NPerWave>
1213 template <
class FloatC>
1214 __device__
static void Run(
const f8x32_t& reg_a,
const f8x32_t& reg_b, FloatC& reg_c)
1216 #if defined(__gfx950__)
1218 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1219 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1222 reg_c.template AsType<float4_t>()[
Number<0>{}],
1236 template <
class FloatC>
1239 #if defined(__gfx950__)
1241 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1242 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1245 reg_c.template AsType<float4_t>()[
Number<0>{}],
1259 template <
class FloatC>
1260 __device__
static void Run(
const bf8x32_t& reg_a,
const f8x32_t& reg_b, FloatC& reg_c)
1262 #if defined(__gfx950__)
1264 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1265 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1268 reg_c.template AsType<float4_t>()[
Number<0>{}],
1282 template <
class FloatC>
1283 __device__
static void Run(
const f8x32_t& reg_a,
const bf8x32_t& reg_b, FloatC& reg_c)
1285 #if defined(__gfx950__)
1287 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1288 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1291 reg_c.template AsType<float4_t>()[
Number<0>{}],
1305 template <
class FloatC>
1308 #if defined(__gfx950__)
1309 int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
1310 int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
1314 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1315 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1316 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1317 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1318 reg_c.template AsType<float4_t>()[
Number<0>{}],
1332 template <
class FloatC>
1335 #if defined(__gfx950__)
1336 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1337 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1341 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1342 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1343 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1344 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1345 reg_c.template AsType<float4_t>()[
Number<0>{}],
1359 template <
class FloatC>
1362 #if defined(__gfx950__)
1363 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1364 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1368 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1369 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1370 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1371 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1372 reg_c.template AsType<float4_t>()[
Number<0>{}],
1387 template <index_t MPerWave, index_t NPerWave>
1393 template <
class FloatC>
1394 __device__
static void Run(
const f8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
1396 #if defined(__gfx94__)
1397 reg_c.template AsType<float16_t>()(
Number<0>{}) =
1398 __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1399 bit_cast<int64_t>(reg_a),
1400 bit_cast<int64_t>(reg_b),
1401 reg_c.template AsType<float16_t>()[
Number<0>{}],
1410 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[
Number<k>{}]);
1411 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[
Number<k>{}]);
1419 template <index_t MPerWave, index_t NPerWave>
1425 template <
class FloatC>
1426 __device__
static void Run(
const f8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
1428 #if defined(__gfx94__)
1429 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1430 bit_cast<int64_t>(reg_a),
1431 bit_cast<int64_t>(reg_b),
1432 reg_c.template AsType<float4_t>()[
Number<0>{}],
1441 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[
Number<k>{}]);
1442 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[
Number<k>{}]);
1450 template <index_t MPerWave, index_t NPerWave>
1456 template <
class FloatC>
1459 #if defined(__gfx94__)
1460 reg_c.template AsType<float16_t>()(
Number<0>{}) =
1461 __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1462 bit_cast<int64_t>(reg_a),
1463 bit_cast<int64_t>(reg_b),
1464 reg_c.template AsType<float16_t>()[
Number<0>{}],
1473 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[
Number<k>{}]);
1474 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[
Number<k>{}]);
1482 template <index_t MPerWave, index_t NPerWave>
1488 template <
class FloatC>
1491 #if defined(__gfx94__)
1492 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1493 bit_cast<int64_t>(reg_a),
1494 bit_cast<int64_t>(reg_b),
1495 reg_c.template AsType<float4_t>()[
Number<0>{}],
1504 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[
Number<k>{}]);
1505 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[
Number<k>{}]);
1513 template <index_t MPerWave, index_t NPerWave>
1519 template <
class FloatC>
1520 __device__
static void Run(
const f8x8_t& reg_a,
const bf8x8_t& reg_b, FloatC& reg_c)
1522 #if defined(__gfx94__)
1523 reg_c.template AsType<float16_t>()(
Number<0>{}) =
1524 __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1525 bit_cast<int64_t>(reg_a),
1526 bit_cast<int64_t>(reg_b),
1527 reg_c.template AsType<float16_t>()[
Number<0>{}],
1536 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[
Number<k>{}]);
1537 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[
Number<k>{}]);
1545 template <index_t MPerWave, index_t NPerWave>
1551 template <
class FloatC>
1552 __device__
static void Run(
const f8x8_t& reg_a,
const bf8x8_t& reg_b, FloatC& reg_c)
1554 #if defined(__gfx94__)
1555 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1556 bit_cast<int64_t>(reg_a),
1557 bit_cast<int64_t>(reg_b),
1558 reg_c.template AsType<float4_t>()[
Number<0>{}],
1567 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[
Number<k>{}]);
1568 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[
Number<k>{}]);
1576 template <index_t MPerWave, index_t NPerWave>
1582 template <
class FloatC>
1583 __device__
static void Run(
const bf8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
1585 #if defined(__gfx94__)
1586 reg_c.template AsType<float16_t>()(
Number<0>{}) =
1587 __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1588 bit_cast<int64_t>(reg_a),
1589 bit_cast<int64_t>(reg_b),
1590 reg_c.template AsType<float16_t>()[
Number<0>{}],
1599 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[
Number<k>{}]);
1600 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[
Number<k>{}]);
1608 template <index_t MPerWave, index_t NPerWave>
1614 template <
class FloatC>
1615 __device__
static void Run(
const bf8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
1617 #if defined(__gfx94__)
1618 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1619 bit_cast<int64_t>(reg_a),
1620 bit_cast<int64_t>(reg_b),
1621 reg_c.template AsType<float4_t>()[
Number<0>{}],
1630 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[
Number<k>{}]);
1631 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[
Number<k>{}]);
bf8_t bf8x32_t
Definition: vector_type.hpp:229
bf8_t bf8x8_t
Definition: vector_type.hpp:227
typename vector_type< bf6x16_pk_t, 2 >::type bf6x16x2_t
Definition: dtype_vector.hpp:2258
typename vector_type< f6x16_pk_t, 2 >::type f6x16x2_t
Definition: dtype_vector.hpp:2253
typename vector_type< f6x32_pk_t, 1 >::type f6x32_t
Definition: dtype_vector.hpp:2254
typename vector_type< bhalf_t, 4 >::type bhalf4_t
Definition: dtype_vector.hpp:2147
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition: dtype_vector.hpp:2148
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: dtype_vector.hpp:2164
typename vector_type< half_t, 4 >::type half4_t
Definition: dtype_vector.hpp:2140
typename vector_type< bf6x32_pk_t, 1 >::type bf6x32_t
Definition: dtype_vector.hpp:2259
typename vector_type< int32_t, 8 >::type int32x8_t
Definition: dtype_vector.hpp:2156
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename vector_type< f4x2_pk_t, 16 >::type f4x32_t
Definition: dtype_vector.hpp:2248
typename vector_type< bhalf_t, 2 >::type bhalf2_t
Definition: dtype_vector.hpp:2146
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: dtype_vector.hpp:2165
typename vector_type< int32_t, 4 >::type int32x4_t
Definition: dtype_vector.hpp:2154
typename vector_type< int8_t, 4 >::type int8x4_t
Definition: dtype_vector.hpp:2163
typename vector_type< int32_t, 6 >::type int32x6_t
Definition: dtype_vector.hpp:2155
typename vector_type< half_t, 8 >::type half8_t
Definition: dtype_vector.hpp:2141
signed int int32_t
Definition: stdint.h:123
Definition: integral_constant.hpp:20
static __device__ void Run(const bf6x32_t &reg_a, const bf6x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1360
static __device__ void Run(const f6x32_t &reg_a, const f6x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1333
static __device__ void Run(const f8x32_t &reg_a, const bf8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1283
static __device__ void Run(const bf8x32_t &reg_a, const bf8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1237
static __device__ void Run(const bf8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1260
static __device__ void Run(const f4x32_t &reg_a, const f4x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1306
static __device__ void Run(const f8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1214
Definition: amd_xdlops.hpp:1202
static __device__ void Run(const bhalf4_t &reg_a, const bhalf4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:309
Definition: amd_xdlops.hpp:303
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:199
Definition: amd_xdlops.hpp:193
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:76
Definition: amd_xdlops.hpp:70
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:275
Definition: amd_xdlops.hpp:269
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1489
Definition: amd_xdlops.hpp:1483
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1615
Definition: amd_xdlops.hpp:1609
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:165
Definition: amd_xdlops.hpp:159
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1552
Definition: amd_xdlops.hpp:1546
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1426
Definition: amd_xdlops.hpp:1420
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:213
Definition: amd_xdlops.hpp:207
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:62
Definition: amd_xdlops.hpp:56
static __device__ void Run(const bhalf2_t &reg_a, const bhalf2_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:337
Definition: amd_xdlops.hpp:331
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:255
Definition: amd_xdlops.hpp:249
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1457
Definition: amd_xdlops.hpp:1451
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1583
Definition: amd_xdlops.hpp:1577
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:145
Definition: amd_xdlops.hpp:139
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1520
Definition: amd_xdlops.hpp:1514
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1394
Definition: amd_xdlops.hpp:1388
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:34
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:21
Definition: amd_xdlops.hpp:15
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:48
Definition: amd_xdlops.hpp:42
static __device__ void Run(const bhalf2_t &reg_a, const bhalf2_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:323
Definition: amd_xdlops.hpp:317
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:131
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:118
Definition: amd_xdlops.hpp:112
static __device__ void Run(const bf8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:537
static __device__ void Run(const f8x32_t ®_a, const bf8x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:559
static __device__ void Run(const bf6x32_t ®_a, const bf6x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:637
static __device__ void Run(const f6x32_t ®_a, const f6x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:609
static __device__ void Run(const f8x32_t ®_a, const f8x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:493
static __device__ void Run(const f4x32_t ®_a, const f4x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:581
static __device__ void Run(const bf8x32_t ®_a, const bf8x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:515
Definition: amd_xdlops.hpp:481
static __device__ void Run(const bhalf4_t &reg_a, const bhalf4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:295
Definition: amd_xdlops.hpp:289
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:185
Definition: amd_xdlops.hpp:179
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:90
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:101
Definition: amd_xdlops.hpp:84
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:227
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:238
Definition: amd_xdlops.hpp:221
static __device__ void Run(const double &reg_a, const double &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:467
Definition: amd_xdlops.hpp:461
static __device__ void Run(const int8x4_t &reg_a, const int8x4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:370
Definition: amd_xdlops.hpp:364
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:448
Definition: amd_xdlops.hpp:442
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:409
Definition: amd_xdlops.hpp:403
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:429
Definition: amd_xdlops.hpp:423
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:389
Definition: amd_xdlops.hpp:383
static __device__ void Run(const int8x4_t &reg_a, const int8x4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:351
Definition: amd_xdlops.hpp:345
static __device__ void Run(const f6x16x2_t &reg_a, const int32_t scale_a, const f6x16x2_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1041
static __device__ void Run(const f4x32_t &reg_a, const int32_t scale_a, const f4x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1170
static __device__ void Run(const f6x32_t &reg_a, const int32_t scale_a, const f6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1008
static __device__ void Run(const f8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:892
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:979
static __device__ void Run(const f8x32_t &reg_a, const int32_t &scale_a, const bf8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:950
static __device__ void Run(const bf6x16x2_t &reg_a, const int32_t scale_a, const bf6x16x2_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1122
static __device__ void Run(const bf6x32_t &reg_a, const int32_t scale_a, const bf6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1089
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const bf8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:921
Definition: amd_xdlops.hpp:886
static __device__ void Run(const f8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:672
static __device__ void Run(const f6x32_t &reg_a, const int32_t scale_a, const f6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:783
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const bf8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:709
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:746
static __device__ void Run(const f4x32_t &reg_a, const int32_t scale_a, const f4x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:851
static __device__ void Run(const bf6x32_t &reg_a, const int32_t scale_a, const bf6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:817
Definition: amd_xdlops.hpp:666
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:10