16 static constexpr
bool is_scale_mfma_data_type()
18 using U = element_type_t<T>;
19 return is_same_v<U, f8_ocp_t> || is_same_v<U, bf8_ocp_t> || is_same_v<U, f6_t> ||
20 is_same_v<U, bf6_t> || is_same_v<U, f4_t>;
23 #ifndef CK_CODE_GEN_RTC
28 static constexpr
bool is_scale_mfma_scale_type()
30 return is_same_v<T, e8m0_bexp_t>;
37 template <
typename ADataType,
typename BDataType,
typename AScaleDataType,
typename BScaleDataType>
38 static constexpr
bool scale_mfma_hw_support()
40 return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
41 is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
101 template <MfmaInstr instr>
108 static constexpr
index_t num_groups_per_blk = 4;
109 static constexpr
index_t num_regs_per_blk = 16;
110 static constexpr
index_t num_threads_per_blk = 32;
117 static constexpr
bool is_k_reduction =
false;
119 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
120 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
130 static constexpr
index_t num_groups_per_blk = 4;
131 static constexpr
index_t num_regs_per_blk = 16;
132 static constexpr
index_t num_threads_per_blk = 32;
139 static constexpr
bool is_k_reduction =
true;
141 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
142 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
152 static constexpr
index_t num_groups_per_blk = 1;
153 static constexpr
index_t num_regs_per_blk = 4;
154 static constexpr
index_t num_threads_per_blk = 16;
161 static constexpr
bool is_k_reduction =
true;
163 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
164 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
174 static constexpr
index_t num_groups_per_blk = 1;
175 static constexpr
index_t num_regs_per_blk = 4;
176 static constexpr
index_t num_threads_per_blk = 16;
183 static constexpr
bool is_k_reduction =
false;
185 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
186 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
197 static constexpr
index_t num_groups_per_blk = 1;
198 static constexpr
index_t num_regs_per_blk = 4;
199 static constexpr
index_t num_threads_per_blk = 64;
206 static constexpr
bool is_k_reduction =
false;
208 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
209 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
219 static constexpr
index_t num_groups_per_blk = 4;
220 static constexpr
index_t num_regs_per_blk = 16;
221 static constexpr
index_t num_threads_per_blk = 32;
228 static constexpr
bool is_k_reduction =
false;
230 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
231 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
241 static constexpr
index_t num_groups_per_blk = 4;
242 static constexpr
index_t num_regs_per_blk = 16;
243 static constexpr
index_t num_threads_per_blk = 32;
250 static constexpr
bool is_k_reduction =
true;
252 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
253 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
263 static constexpr
index_t num_groups_per_blk = 4;
264 static constexpr
index_t num_regs_per_blk = 16;
265 static constexpr
index_t num_threads_per_blk = 32;
272 static constexpr
bool is_k_reduction =
true;
274 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
275 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
285 static constexpr
index_t num_groups_per_blk = 1;
286 static constexpr
index_t num_regs_per_blk = 4;
287 static constexpr
index_t num_threads_per_blk = 16;
294 static constexpr
bool is_k_reduction =
true;
296 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
297 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
307 static constexpr
index_t num_groups_per_blk = 1;
308 static constexpr
index_t num_regs_per_blk = 4;
309 static constexpr
index_t num_threads_per_blk = 16;
316 static constexpr
bool is_k_reduction =
true;
318 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
319 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
329 static constexpr
index_t num_groups_per_blk = 1;
330 static constexpr
index_t num_regs_per_blk = 4;
331 static constexpr
index_t num_threads_per_blk = 16;
338 static constexpr
bool is_k_reduction =
false;
340 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
341 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
351 static constexpr
index_t num_groups_per_blk = 1;
352 static constexpr
index_t num_regs_per_blk = 4;
353 static constexpr
index_t num_threads_per_blk = 64;
360 static constexpr
bool is_k_reduction =
false;
362 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
363 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
373 static constexpr
index_t num_groups_per_blk = 4;
374 static constexpr
index_t num_regs_per_blk = 16;
375 static constexpr
index_t num_threads_per_blk = 32;
382 static constexpr
bool is_k_reduction =
true;
384 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
385 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
395 static constexpr
index_t num_groups_per_blk = 4;
396 static constexpr
index_t num_regs_per_blk = 16;
397 static constexpr
index_t num_threads_per_blk = 32;
404 static constexpr
bool is_k_reduction =
true;
406 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
407 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
417 static constexpr
index_t num_groups_per_blk = 1;
418 static constexpr
index_t num_regs_per_blk = 4;
419 static constexpr
index_t num_threads_per_blk = 16;
426 static constexpr
bool is_k_reduction =
true;
428 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
429 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
439 static constexpr
index_t num_groups_per_blk = 1;
440 static constexpr
index_t num_regs_per_blk = 4;
441 static constexpr
index_t num_threads_per_blk = 16;
448 static constexpr
bool is_k_reduction =
true;
450 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
451 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
461 static constexpr
index_t num_groups_per_blk = 4;
462 static constexpr
index_t num_regs_per_blk = 16;
463 static constexpr
index_t num_threads_per_blk = 32;
470 static constexpr
bool is_k_reduction =
true;
472 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
473 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
483 static constexpr
index_t num_groups_per_blk = 1;
484 static constexpr
index_t num_regs_per_blk = 4;
485 static constexpr
index_t num_threads_per_blk = 16;
492 static constexpr
bool is_k_reduction =
true;
494 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
495 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
505 static constexpr
index_t num_groups_per_blk = 4;
506 static constexpr
index_t num_regs_per_blk = 16;
507 static constexpr
index_t num_threads_per_blk = 32;
514 static constexpr
bool is_k_reduction =
true;
516 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
517 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
527 static constexpr
index_t num_groups_per_blk = 1;
528 static constexpr
index_t num_regs_per_blk = 4;
529 static constexpr
index_t num_threads_per_blk = 16;
536 static constexpr
bool is_k_reduction =
true;
538 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
539 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
549 static constexpr
index_t num_groups_per_blk = 4;
550 static constexpr
index_t num_regs_per_blk = 16;
551 static constexpr
index_t num_threads_per_blk = 32;
558 static constexpr
bool is_k_reduction =
true;
560 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
561 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
571 static constexpr
index_t num_groups_per_blk = 1;
572 static constexpr
index_t num_regs_per_blk = 4;
573 static constexpr
index_t num_threads_per_blk = 16;
580 static constexpr
bool is_k_reduction =
true;
582 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
583 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
593 static constexpr
index_t num_groups_per_blk = 4;
594 static constexpr
index_t num_regs_per_blk = 16;
595 static constexpr
index_t num_threads_per_blk = 32;
602 static constexpr
bool is_k_reduction =
true;
604 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
605 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
615 static constexpr
index_t num_groups_per_blk = 1;
616 static constexpr
index_t num_regs_per_blk = 4;
617 static constexpr
index_t num_threads_per_blk = 16;
624 static constexpr
bool is_k_reduction =
true;
626 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
627 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
637 static constexpr
index_t num_groups_per_blk = 4;
638 static constexpr
index_t num_regs_per_blk = 4;
639 static constexpr
index_t num_threads_per_blk = 16;
646 static constexpr
bool is_k_reduction =
true;
648 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
649 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
659 static constexpr
index_t num_groups_per_blk = 4;
660 static constexpr
index_t num_regs_per_blk = 16;
661 static constexpr
index_t num_threads_per_blk = 32;
668 static constexpr
bool is_k_reduction =
true;
670 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
671 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
681 static constexpr
index_t num_groups_per_blk = 1;
682 static constexpr
index_t num_regs_per_blk = 4;
683 static constexpr
index_t num_threads_per_blk = 16;
690 static constexpr
bool is_k_reduction =
true;
692 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
693 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
703 static constexpr
index_t num_groups_per_blk = 4;
704 static constexpr
index_t num_regs_per_blk = 16;
705 static constexpr
index_t num_threads_per_blk = 32;
712 static constexpr
bool is_k_reduction =
true;
714 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
715 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
725 static constexpr
index_t num_groups_per_blk = 1;
726 static constexpr
index_t num_regs_per_blk = 4;
727 static constexpr
index_t num_threads_per_blk = 16;
734 static constexpr
bool is_k_reduction =
true;
736 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
737 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
747 static constexpr
index_t num_groups_per_blk = 4;
748 static constexpr
index_t num_regs_per_blk = 16;
749 static constexpr
index_t num_threads_per_blk = 32;
756 static constexpr
bool is_k_reduction =
true;
758 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
759 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
769 static constexpr
index_t num_groups_per_blk = 1;
770 static constexpr
index_t num_regs_per_blk = 4;
771 static constexpr
index_t num_threads_per_blk = 16;
778 static constexpr
bool is_k_reduction =
true;
780 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
781 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
791 static constexpr
index_t num_groups_per_blk = 4;
792 static constexpr
index_t num_regs_per_blk = 16;
793 static constexpr
index_t num_threads_per_blk = 32;
800 static constexpr
bool is_k_reduction =
true;
802 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
803 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
813 static constexpr
index_t num_groups_per_blk = 1;
814 static constexpr
index_t num_regs_per_blk = 4;
815 static constexpr
index_t num_threads_per_blk = 16;
822 static constexpr
bool is_k_reduction =
true;
824 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
825 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
836 static constexpr
index_t num_groups_per_blk = 4;
837 static constexpr
index_t num_regs_per_blk = 16;
838 static constexpr
index_t num_threads_per_blk = 32;
845 static constexpr
bool is_k_reduction =
true;
848 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
849 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
860 static constexpr
index_t num_groups_per_blk = 1;
861 static constexpr
index_t num_regs_per_blk = 4;
862 static constexpr
index_t num_threads_per_blk = 16;
869 static constexpr
bool is_k_reduction =
true;
872 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
873 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
884 static constexpr
index_t num_groups_per_blk = 4;
885 static constexpr
index_t num_regs_per_blk = 16;
886 static constexpr
index_t num_threads_per_blk = 32;
893 static constexpr
bool is_k_reduction =
true;
905 __device__
void run(
const FloatA&
a,
906 const ScaleA& scale_a,
908 const ScaleB& scale_b,
912 a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
921 static constexpr
index_t num_groups_per_blk = 1;
922 static constexpr
index_t num_regs_per_blk = 4;
923 static constexpr
index_t num_threads_per_blk = 16;
930 static constexpr
bool is_k_reduction =
true;
942 __device__
void run(
const FloatA&
a,
943 const ScaleA& scale_a,
945 const ScaleB& scale_b,
950 a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
979 static constexpr
index_t num_threads_per_blk = n_per_blk;
980 static constexpr
index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size;
981 static constexpr
index_t num_input_blks = m_per_blk / num_regs_per_blk;
983 static constexpr
index_t num_groups_per_blk = 1;
986 static constexpr
bool is_k_reduction =
true;
989 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
990 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1002 static constexpr
index_t num_threads_per_blk = n_per_blk;
1003 static constexpr
index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size;
1004 static constexpr
index_t num_input_blks = m_per_blk / num_regs_per_blk;
1009 static constexpr
bool is_k_reduction =
true;
1011 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
1012 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1037 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1038 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1047 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1048 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1065 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1075 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1076 __device__
void run(
const FloatA&,
const FloatB&, FloatC&)
const
1101 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1102 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1111 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1112 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1129 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1139 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1140 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1149 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1150 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1159 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1160 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1169 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1170 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1180 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1181 __device__
void run(
const FloatA&,
const FloatB&, FloatC&)
const
1201 template <
typename base_type,
1204 typename additional_type = base_type,
1205 bool is_single_rate_mfma =
false,
1206 bool is_scale_mfma =
false>
1209 template <
typename base_type_,
1212 typename additional_type_ = base_type_,
1213 bool is_single_rate_mfma_ =
false,
1214 bool is_scale_mfma_ =
false>
1218 constexpr
auto GetMfma<double, 16, 16>()
1220 #if defined(__gfx12__)
1222 #elif defined(__gfx11__)
1230 constexpr
auto GetMfma<float, 64, 64>()
1236 constexpr
auto GetMfma<float, 32, 64>()
1242 constexpr
auto GetMfma<float, 16, 64>()
1248 constexpr
auto GetMfma<float, 8, 64>()
1254 constexpr
auto GetMfma<float, 4, 64>()
1260 constexpr
auto GetMfma<float, 32, 32>()
1266 constexpr
auto GetMfma<float, 16, 16>()
1268 #if defined(__gfx12__)
1270 #elif defined(__gfx11__)
1278 constexpr
auto GetMfma<tf32_t, 32, 32>()
1280 #if defined(__gfx12__)
1282 #elif defined(__gfx11__)
1284 #elif defined(__gfx942__)
1292 constexpr
auto GetMfma<tf32_t, 16, 16>()
1294 #if defined(__gfx12__)
1296 #elif defined(__gfx11__)
1298 #elif defined(__gfx942__)
1306 constexpr
auto GetMfma<half_t, 64, 64>()
1312 constexpr
auto GetMfma<half_t, 32, 64>()
1318 constexpr
auto GetMfma<half_t, 32, 32, half_t, false>()
1320 #if defined(__gfx950__)
1327 constexpr
auto GetMfma<half_t, 32, 32, half_t, true>()
1333 constexpr
auto GetMfma<half_t, 16, 16, half_t, false>()
1335 #if defined(__gfx12__)
1337 #elif defined(__gfx11__)
1339 #elif defined(__gfx950__)
1347 constexpr
auto GetMfma<half_t, 16, 16, half_t, true>()
1349 #if defined(__gfx12__)
1351 #elif defined(__gfx11__)
1359 constexpr
auto GetMfma<half_t, 16, 64>()
1365 constexpr
auto GetMfma<half_t, 8, 64>()
1371 constexpr
auto GetMfma<half_t, 4, 64>()
1377 constexpr
auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
1379 #if defined(__gfx950__)
1381 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1389 constexpr
auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
1391 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1399 constexpr
auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
1401 #if defined(__gfx12__)
1403 #elif defined(__gfx11__)
1405 #elif defined(__gfx950__)
1407 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1415 constexpr
auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
1417 #if defined(__gfx12__)
1419 #elif defined(__gfx11__)
1421 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1429 constexpr
auto GetMfma<int8_t, 32, 32, int8_t, false>()
1431 #if defined(__gfx950__)
1433 #elif defined(__gfx942__)
1441 constexpr
auto GetMfma<int8_t, 32, 32, int8_t, true>()
1443 #if defined(__gfx942__) || defined(__gfx950__)
1451 constexpr
auto GetMfma<int8_t, 16, 16, int8_t, false>()
1453 #if defined(__gfx12__)
1455 #elif defined(__gfx11__)
1457 #elif defined(__gfx950__)
1459 #elif defined(__gfx942__)
1467 constexpr
auto GetMfma<int8_t, 16, 16, int8_t, true>()
1469 #if defined(__gfx12__)
1471 #elif defined(__gfx11__)
1473 #elif defined(__gfx942__) || defined(__gfx950__)
1481 constexpr
auto GetMfma<f8_t, 32, 32, f8_t, true, false>()
1487 constexpr
auto GetMfma<f8_t, 32, 32, f8_t, false, false>()
1489 #if defined(__gfx950__)
1497 constexpr
auto GetMfma<f8_t, 32, 32, f8_t, is_single_rate_mfma, true>()
1503 constexpr
auto GetMfma<bf8_t, 32, 32, f8_t, is_single_rate_mfma, true>()
1508 constexpr
auto GetMfma<f4_t, 32, 32, f4_t, is_single_rate_mfma, true>()
1513 constexpr
auto GetMfma<f4_t, 16, 16, f4_t, is_single_rate_mfma, true>()
1515 #if defined(__gfx12__)
1517 #elif defined(__gfx11__)
1525 constexpr
auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
1527 #if defined(__gfx12__)
1529 #elif defined(__gfx11__)
1537 constexpr
auto GetMfma<f8_t, 16, 16, f8_t, false, false>()
1539 #if defined(__gfx12__)
1541 #elif defined(__gfx11__)
1543 #elif defined(__gfx950__)
1551 constexpr
auto GetMfma<f8_t, 16, 16, f8_t, is_single_rate_mfma, true>()
1553 #if defined(__gfx12__)
1555 #elif defined(__gfx11__)
1563 constexpr
auto GetMfma<bf8_t, 16, 16, bf8_t, is_single_rate_mfma, true>()
1565 #if defined(__gfx12__)
1567 #elif defined(__gfx11__)
1575 constexpr
auto GetMfma<f8_t, 16, 16, bf8_t, is_single_rate_mfma, true>()
1577 #if defined(__gfx12__)
1579 #elif defined(__gfx11__)
1587 constexpr
auto GetMfma<bf8_t, 16, 16, f8_t, is_single_rate_mfma, true>()
1589 #if defined(__gfx12__)
1591 #elif defined(__gfx11__)
1599 constexpr
auto GetMfma<f6_t, 32, 32, f6_t, is_single_rate_mfma, true>()
1604 constexpr
auto GetMfma<f6_t, 16, 16, f6_t, is_single_rate_mfma, true>()
1606 #if defined(__gfx12__)
1608 #elif defined(__gfx11__)
1615 constexpr
auto GetMfma<bf6_t, 32, 32, bf6_t, is_single_rate_mfma, true>()
1620 constexpr
auto GetMfma<bf6_t, 16, 16, bf6_t, is_single_rate_mfma, true>()
1622 #if defined(__gfx12__)
1624 #elif defined(__gfx11__)
1632 constexpr
auto GetMfma<bf8_t, 32, 32, bf8_t, true, false>()
1638 constexpr
auto GetMfma<bf8_t, 32, 32, bf8_t, false, false>()
1640 #if defined(__gfx950__)
1648 constexpr
auto GetMfma<bf8_t, 16, 16, bf8_t, true, false>()
1650 #if defined(__gfx12__)
1652 #elif defined(__gfx11__)
1660 constexpr
auto GetMfma<bf8_t, 16, 16, bf8_t, false, false>()
1662 #if defined(__gfx12__)
1664 #elif defined(__gfx11__)
1666 #elif defined(__gfx950__)
1674 constexpr
auto GetMfma<f8_t, 32, 32, bf8_t, true, false>()
1680 constexpr
auto GetMfma<f8_t, 32, 32, bf8_t, false, false>()
1682 #if defined(__gfx950__)
1690 constexpr
auto GetMfma<f8_t, 16, 16, bf8_t, true, false>()
1692 #if defined(__gfx12__)
1694 #elif defined(__gfx11__)
1702 constexpr
auto GetMfma<f8_t, 16, 16, bf8_t, false, false>()
1704 #if defined(__gfx12__)
1706 #elif defined(__gfx11__)
1708 #elif defined(__gfx950__)
1716 constexpr
auto GetMfma<bf8_t, 32, 32, f8_t, true, false>()
1722 constexpr
auto GetMfma<bf8_t, 32, 32, f8_t, false, false>()
1724 #if defined(__gfx950__)
1732 constexpr
auto GetMfma<bf8_t, 16, 16, f8_t, true, false>()
1734 #if defined(__gfx12__)
1736 #elif defined(__gfx11__)
1744 constexpr
auto GetMfma<bf8_t, 16, 16, f8_t, false, false>()
1746 #if defined(__gfx12__)
1748 #elif defined(__gfx11__)
1750 #elif defined(__gfx950__)
1761 is_single_rate_mfma,
1762 is_scale_mfma>()>{};
1768 "wrong! num_regs_per_blk");
1771 "n_per_blk != num_threads_per_blk");
1772 #if defined(__gfx11__)
1773 if constexpr(MPerXdlops == 16 && NPerXdlops == 16)
1777 "m_per_blk != num_input_blks * num_regs_per_blk");
1782 "m_per_blk != num_input_blks * num_regs_per_blk");
1787 "incorrect num_output_blks");
1791 "num_regs_per_blk incorrect");
1795 "is_k_reduction wrong!");
1800 static_assert(NPerXdlops >= MPerXdlops,
"only support ABroadcast");
1813 template <
typename base_type,
1817 typename additional_type = base_type,
1818 bool TransposeC =
false,
1819 bool is_scale_mfma =
false>
1836 return MPerXdlops * NPerXdlops /
1842 static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
1844 "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1846 static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
1848 "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1849 #if defined(__HIP_DEVICE_COMPILE__)
1850 static_assert(KPack %
mfma_instr.k_per_blk == 0,
"KPack should be a multiple of k_per_blk");
1856 template <
typename CDesc_M0_N0_M1_N1_M2_N2>
1857 __host__ __device__
static constexpr
auto
1860 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1861 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1862 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1863 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1867 c_desc_m0_n0_m1_n1_m2_n2,
1892 template <
typename CDesc_M0_N0_M1_N1_M2_N2>
1894 const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1896 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1897 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1898 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1899 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1900 const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I4);
1901 const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I5);
1905 c_desc_m0_n0_m1_n1_m2_n2,
1936 template <
typename CDesc_M0_N0_M1_N1_M2_N2>
1937 __host__ __device__
static constexpr
auto
1940 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1941 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1942 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1943 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1947 c_desc_m0_n0_m1_n1_m2_n2,
1970 template <
typename CDesc_G_M0_N0_M1_N1_M2_N2>
1972 const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
1974 const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1975 const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1976 const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1977 const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1978 const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I4);
1982 c_desc_g_m0_n0_m1_n1_m2_n2,
2014 template <
class FloatA,
class FloatB,
class FloatC>
2015 __device__
void Run(
const FloatA& p_a_wave,
const FloatB& p_b_wave, FloatC& p_c_thread)
const
2024 "base_type must be double, float, tf32_t, half, bfloat16, int8_t, f8_t or bf8_t!");
2027 if constexpr(!TransposeC)
2029 mfma_instr.template run<MPerXdlops, NPerXdlops>(
2030 p_a_wave[k], p_b_wave[k], p_c_thread);
2034 mfma_instr.template run<MPerXdlops, NPerXdlops>(
2035 p_b_wave[k], p_a_wave[k], p_c_thread);
2047 __device__
void Run(
const FloatA& p_a_wave,
2048 const ScaleA& a_scale_thread,
2049 const FloatB& p_b_wave,
2050 const ScaleB& b_scale_thread,
2051 FloatC& p_c_thread)
const
2054 if constexpr(!TransposeC)
2056 mfma_instr.template run<MPerXdlops, NPerXdlops, OpselA, OpselB>(
2057 p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
2061 mfma_instr.template run<MPerXdlops, NPerXdlops, OpselB, OpselA>(
2062 p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
2080 const auto blk_idx =
2081 threadidx_to_blk_idx_adaptor.CalculateBottomIndex(
make_multi_index(laneId));
2083 const auto blk_id = blk_idx[
I1];
2084 const auto blk_td = blk_idx[
I2];
2089 template <
bool SwizzleA>
2093 if constexpr(SwizzleA)
2095 laneId = ((laneId & 1) << 3) | (laneId >> 1);
2103 const auto blk_idx =
2104 threadidx_to_blk_idx_adaptor.CalculateBottomIndex(
make_multi_index(laneId));
2106 const auto blk_id = blk_idx[
I1];
2107 const auto blk_td = blk_idx[
I2];
2115 #if defined(__gfx11__)
2116 const auto blk_idx = GetGfx11InputBlkIdx<!TransposeC>();
2121 const auto blk_id = blk_idx[
I0];
2122 const auto blk_td = blk_idx[
I1];
2137 #if defined(__gfx11__)
2138 const auto blk_idx = GetGfx11InputBlkIdx<TransposeC>();
2143 const auto blk_id = blk_idx[
I0];
2144 const auto blk_td = blk_idx[
I1];
2160 const auto blk_id = blk_idx[
I0];
2161 const auto blk_td = blk_idx[
I1];
2166 return TransposeC ?
CIndex{n_offset, m_offset} :
CIndex{m_offset, n_offset};
2173 const auto blk_id = blk_idx[
I0];
2174 const auto blk_td = blk_idx[
I1];
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
MfmaInstr
Definition: xdlops_gemm.hpp:45
@ wmma_f32_16x16x16_bf16_gfx12
@ mfma_f32_32x32x64f8f6f4
@ wmma_unsupport_16x16_gfx11
@ wmma_i32_16x16x16_iu8_gfx12
@ mfma_scale_f32_32x32x64f8f6f4
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_unsupport_16x16_gfx12
@ mfma_f32_16x16x16bf16_1k
@ wmma_f32_16x16x16_f8f8_gfx12
@ mfma_scale_f32_16x16x128f8f6f4
@ mfma_f32_16x16x32bf8bf8
@ mfma_f32_16x16x128f8f6f4
@ mfma_f32_32x32x16bf8bf8
@ mfma_f32_32x32x8bf16_1k
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
typename packed_type_info< T >::element_type element_type_t
Definition: data_type.hpp:408
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:299
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
@ wmma_f32_16x16x16_bf16_gfx12
@ wmma_i32_16x16x16_iu8_gfx12
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_f32_16x16x16_f8f8_gfx12
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition: xdlops_gemm.hpp:1208
__host__ constexpr __device__ MfmaSelector()
Definition: xdlops_gemm.hpp:1764
static constexpr bool IsABroadcast()
Definition: xdlops_gemm.hpp:1798
static constexpr index_t GetK1PerXdlops()
Definition: xdlops_gemm.hpp:1810
static constexpr auto GetMfma()
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1757
static constexpr index_t GetKPerXdlops()
Definition: xdlops_gemm.hpp:1804
Definition: sequence.hpp:43
Definition: xdlops_gemm.hpp:1821
static constexpr auto mfma_instr
Definition: xdlops_gemm.hpp:2197
__host__ constexpr __device__ XdlopsGemm()
Definition: xdlops_gemm.hpp:1840
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:2134
static __device__ auto GetBlkIdx()
Definition: xdlops_gemm.hpp:2069
__device__ static constexpr __host__ index_t GetRegSizePerXdlops()
Definition: xdlops_gemm.hpp:2007
static constexpr auto I2
Definition: xdlops_gemm.hpp:1824
static constexpr __device__ index_t GetNumBlks()
Definition: xdlops_gemm.hpp:1832
static __device__ auto GetLaneId()
Definition: xdlops_gemm.hpp:2067
static constexpr auto K0PerXdlops
Definition: xdlops_gemm.hpp:2201
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1893
static constexpr __device__ index_t GetNumXdlops()
Definition: xdlops_gemm.hpp:1834
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:2112
static constexpr bool is_single_rate_mfma
Definition: xdlops_gemm.hpp:2182
static __device__ CIndex4D GetBeginOfThreadBlk4D(index_t, index_t)
Definition: xdlops_gemm.hpp:2169
static constexpr __device__ index_t GetWaveSize()
Definition: xdlops_gemm.hpp:2012
static __device__ auto GetGfx11InputBlkIdx()
Definition: xdlops_gemm.hpp:2090
static constexpr auto I5
Definition: xdlops_gemm.hpp:1827
static constexpr auto I3
Definition: xdlops_gemm.hpp:1825
static constexpr auto I0
Definition: xdlops_gemm.hpp:1822
__device__ void Run(const FloatA &p_a_wave, const ScaleA &a_scale_thread, const FloatB &p_b_wave, const ScaleB &b_scale_thread, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:2047
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1858
__host__ static constexpr __device__ auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1971
static constexpr auto I1
Definition: xdlops_gemm.hpp:1823
static constexpr auto K1PerXdlops
Definition: xdlops_gemm.hpp:2200
static constexpr auto KPerXdlops
Definition: xdlops_gemm.hpp:2199
static constexpr auto I4
Definition: xdlops_gemm.hpp:1826
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:2015
static constexpr auto mfma
Definition: xdlops_gemm.hpp:2190
static __device__ CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
Definition: xdlops_gemm.hpp:2156
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1938
__host__ static constexpr __device__ auto GetCM0M1M2NThreadBlkLengths()
Definition: xdlops_gemm.hpp:2203
Definition: integral_constant.hpp:20
Definition: amd_xdlops.hpp:1202
Definition: amd_xdlops.hpp:303
Definition: amd_xdlops.hpp:193
Definition: amd_xdlops.hpp:70
Definition: amd_xdlops.hpp:269
Definition: amd_xdlops.hpp:1483
Definition: amd_xdlops.hpp:1609
Definition: amd_xdlops.hpp:159
Definition: amd_xdlops.hpp:1546
Definition: amd_xdlops.hpp:1420
Definition: amd_xdlops.hpp:207
Definition: amd_xdlops.hpp:56
Definition: amd_xdlops.hpp:331
Definition: amd_xdlops.hpp:1641
Definition: amd_xdlops.hpp:249
Definition: amd_xdlops.hpp:1451
Definition: amd_xdlops.hpp:1577
Definition: amd_xdlops.hpp:139
Definition: amd_xdlops.hpp:1514
Definition: amd_xdlops.hpp:1388
Definition: amd_xdlops.hpp:15
Definition: amd_xdlops.hpp:42
Definition: amd_xdlops.hpp:317
Definition: amd_xdlops.hpp:112
Definition: amd_xdlops.hpp:1661
Definition: amd_xdlops.hpp:481
Definition: amd_xdlops.hpp:289
Definition: amd_xdlops.hpp:179
Definition: amd_xdlops.hpp:84
Definition: amd_xdlops.hpp:221
Definition: amd_xdlops.hpp:461
Definition: amd_xdlops.hpp:364
Definition: amd_xdlops.hpp:442
Definition: amd_xdlops.hpp:403
Definition: amd_xdlops.hpp:423
Definition: amd_xdlops.hpp:383
Definition: amd_xdlops.hpp:345
Definition: amd_xdlops.hpp:886
Definition: amd_xdlops.hpp:666
Definition: amd_wmma.hpp:297
Definition: amd_wmma.hpp:50
Definition: amd_wmma.hpp:418
Definition: amd_wmma.hpp:394
Definition: amd_wmma.hpp:271
Definition: amd_wmma.hpp:25
Definition: amd_wmma.hpp:370
Definition: amd_wmma.hpp:346
Definition: amd_wmma.hpp:319
Definition: amd_wmma.hpp:121
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:873
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:451
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:319
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:186
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:429
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:737
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:825
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:297
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:781
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:693
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:341
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:164
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:495
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:990
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:385
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:715
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:803
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:275
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:759
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:671
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:120
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:142
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:473
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:231
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1012
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:849
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:407
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:253
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:209
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:363
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:649
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:539
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:583
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:627
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:561
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:605
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:517
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:942
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:905
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1048
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1112
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1170
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1160
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1038
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1102
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1150
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1140
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1065
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1129
__device__ void run(const FloatA &, const FloatB &, FloatC &) const
Definition: xdlops_gemm.hpp:1076
__device__ void run(const FloatA &, const FloatB &, FloatC &) const
Definition: xdlops_gemm.hpp:1181
Definition: xdlops_gemm.hpp:1020
static constexpr index_t n_per_blk
Definition: xdlops_gemm.hpp:1029
static constexpr index_t group_size
Definition: xdlops_gemm.hpp:1021
static constexpr index_t m_per_blk
Definition: xdlops_gemm.hpp:1028
static constexpr bool is_k_reduction
Definition: xdlops_gemm.hpp:1031
static constexpr index_t num_threads_per_blk
Definition: xdlops_gemm.hpp:1024
static constexpr index_t num_output_blks
Definition: xdlops_gemm.hpp:1027
static constexpr index_t wave_size
Definition: xdlops_gemm.hpp:1025
static constexpr index_t num_input_blks
Definition: xdlops_gemm.hpp:1026
static constexpr index_t num_groups_per_blk
Definition: xdlops_gemm.hpp:1022
static constexpr index_t num_regs_per_blk
Definition: xdlops_gemm.hpp:1023
static constexpr index_t k_per_blk
Definition: xdlops_gemm.hpp:1030
Definition: xdlops_gemm.hpp:1084
static constexpr index_t n_per_blk
Definition: xdlops_gemm.hpp:1093
static constexpr index_t group_size
Definition: xdlops_gemm.hpp:1085
static constexpr index_t num_output_blks
Definition: xdlops_gemm.hpp:1091
static constexpr index_t m_per_blk
Definition: xdlops_gemm.hpp:1092
static constexpr index_t num_threads_per_blk
Definition: xdlops_gemm.hpp:1088
static constexpr bool is_k_reduction
Definition: xdlops_gemm.hpp:1095
static constexpr index_t num_regs_per_blk
Definition: xdlops_gemm.hpp:1087
static constexpr index_t num_groups_per_blk
Definition: xdlops_gemm.hpp:1086
static constexpr index_t num_input_blks
Definition: xdlops_gemm.hpp:1090
static constexpr index_t wave_size
Definition: xdlops_gemm.hpp:1089
static constexpr index_t k_per_blk
Definition: xdlops_gemm.hpp:1094
Definition: xdlops_gemm.hpp:102
Definition: functional2.hpp:33