16 static constexpr
bool is_scale_mfma_data_type()
18 using U = element_type_t<T>;
19 return is_same_v<U, f8_ocp_t> || is_same_v<U, bf8_ocp_t> || is_same_v<U, f6_t> ||
20 is_same_v<U, bf6_t> || is_same_v<U, f4_t>;
27 static constexpr
bool is_scale_mfma_scale_type()
29 return is_same_v<T, e8m0_bexp_t>;
35 template <
typename ADataType,
typename BDataType,
typename AScaleDataType,
typename BScaleDataType>
36 static constexpr
bool scale_mfma_hw_support()
38 return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
39 is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
97 template <MfmaInstr instr>
104 static constexpr
index_t num_groups_per_blk = 4;
105 static constexpr
index_t num_regs_per_blk = 16;
106 static constexpr
index_t num_threads_per_blk = 32;
113 static constexpr
bool is_k_reduction =
false;
115 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
116 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
126 static constexpr
index_t num_groups_per_blk = 4;
127 static constexpr
index_t num_regs_per_blk = 16;
128 static constexpr
index_t num_threads_per_blk = 32;
135 static constexpr
bool is_k_reduction =
true;
137 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
138 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
148 static constexpr
index_t num_groups_per_blk = 1;
149 static constexpr
index_t num_regs_per_blk = 4;
150 static constexpr
index_t num_threads_per_blk = 16;
157 static constexpr
bool is_k_reduction =
true;
159 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
160 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
170 static constexpr
index_t num_groups_per_blk = 1;
171 static constexpr
index_t num_regs_per_blk = 4;
172 static constexpr
index_t num_threads_per_blk = 16;
179 static constexpr
bool is_k_reduction =
false;
181 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
182 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
193 static constexpr
index_t num_groups_per_blk = 1;
194 static constexpr
index_t num_regs_per_blk = 4;
195 static constexpr
index_t num_threads_per_blk = 64;
202 static constexpr
bool is_k_reduction =
false;
204 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
205 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
215 static constexpr
index_t num_groups_per_blk = 4;
216 static constexpr
index_t num_regs_per_blk = 16;
217 static constexpr
index_t num_threads_per_blk = 32;
224 static constexpr
bool is_k_reduction =
false;
226 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
227 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
237 static constexpr
index_t num_groups_per_blk = 4;
238 static constexpr
index_t num_regs_per_blk = 16;
239 static constexpr
index_t num_threads_per_blk = 32;
246 static constexpr
bool is_k_reduction =
true;
248 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
249 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
259 static constexpr
index_t num_groups_per_blk = 4;
260 static constexpr
index_t num_regs_per_blk = 16;
261 static constexpr
index_t num_threads_per_blk = 32;
268 static constexpr
bool is_k_reduction =
true;
270 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
271 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
281 static constexpr
index_t num_groups_per_blk = 1;
282 static constexpr
index_t num_regs_per_blk = 4;
283 static constexpr
index_t num_threads_per_blk = 16;
290 static constexpr
bool is_k_reduction =
true;
292 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
293 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
303 static constexpr
index_t num_groups_per_blk = 1;
304 static constexpr
index_t num_regs_per_blk = 4;
305 static constexpr
index_t num_threads_per_blk = 16;
312 static constexpr
bool is_k_reduction =
true;
314 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
315 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
325 static constexpr
index_t num_groups_per_blk = 1;
326 static constexpr
index_t num_regs_per_blk = 4;
327 static constexpr
index_t num_threads_per_blk = 16;
334 static constexpr
bool is_k_reduction =
false;
336 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
337 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
347 static constexpr
index_t num_groups_per_blk = 1;
348 static constexpr
index_t num_regs_per_blk = 4;
349 static constexpr
index_t num_threads_per_blk = 64;
356 static constexpr
bool is_k_reduction =
false;
358 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
359 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
369 static constexpr
index_t num_groups_per_blk = 4;
370 static constexpr
index_t num_regs_per_blk = 16;
371 static constexpr
index_t num_threads_per_blk = 32;
378 static constexpr
bool is_k_reduction =
true;
380 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
381 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
391 static constexpr
index_t num_groups_per_blk = 4;
392 static constexpr
index_t num_regs_per_blk = 16;
393 static constexpr
index_t num_threads_per_blk = 32;
400 static constexpr
bool is_k_reduction =
true;
402 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
403 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
413 static constexpr
index_t num_groups_per_blk = 1;
414 static constexpr
index_t num_regs_per_blk = 4;
415 static constexpr
index_t num_threads_per_blk = 16;
422 static constexpr
bool is_k_reduction =
true;
424 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
425 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
435 static constexpr
index_t num_groups_per_blk = 1;
436 static constexpr
index_t num_regs_per_blk = 4;
437 static constexpr
index_t num_threads_per_blk = 16;
444 static constexpr
bool is_k_reduction =
true;
446 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
447 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
457 static constexpr
index_t num_groups_per_blk = 4;
458 static constexpr
index_t num_regs_per_blk = 16;
459 static constexpr
index_t num_threads_per_blk = 32;
466 static constexpr
bool is_k_reduction =
true;
468 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
469 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
479 static constexpr
index_t num_groups_per_blk = 1;
480 static constexpr
index_t num_regs_per_blk = 4;
481 static constexpr
index_t num_threads_per_blk = 16;
488 static constexpr
bool is_k_reduction =
true;
490 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
491 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
501 static constexpr
index_t num_groups_per_blk = 4;
502 static constexpr
index_t num_regs_per_blk = 16;
503 static constexpr
index_t num_threads_per_blk = 32;
510 static constexpr
bool is_k_reduction =
true;
512 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
513 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
523 static constexpr
index_t num_groups_per_blk = 1;
524 static constexpr
index_t num_regs_per_blk = 4;
525 static constexpr
index_t num_threads_per_blk = 16;
532 static constexpr
bool is_k_reduction =
true;
534 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
535 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
545 static constexpr
index_t num_groups_per_blk = 4;
546 static constexpr
index_t num_regs_per_blk = 16;
547 static constexpr
index_t num_threads_per_blk = 32;
554 static constexpr
bool is_k_reduction =
true;
556 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
557 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
567 static constexpr
index_t num_groups_per_blk = 1;
568 static constexpr
index_t num_regs_per_blk = 4;
569 static constexpr
index_t num_threads_per_blk = 16;
576 static constexpr
bool is_k_reduction =
true;
578 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
579 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
589 static constexpr
index_t num_groups_per_blk = 4;
590 static constexpr
index_t num_regs_per_blk = 16;
591 static constexpr
index_t num_threads_per_blk = 32;
598 static constexpr
bool is_k_reduction =
true;
600 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
601 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
611 static constexpr
index_t num_groups_per_blk = 1;
612 static constexpr
index_t num_regs_per_blk = 4;
613 static constexpr
index_t num_threads_per_blk = 16;
620 static constexpr
bool is_k_reduction =
true;
622 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
623 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
633 static constexpr
index_t num_groups_per_blk = 4;
634 static constexpr
index_t num_regs_per_blk = 4;
635 static constexpr
index_t num_threads_per_blk = 16;
642 static constexpr
bool is_k_reduction =
true;
644 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
645 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
655 static constexpr
index_t num_groups_per_blk = 4;
656 static constexpr
index_t num_regs_per_blk = 16;
657 static constexpr
index_t num_threads_per_blk = 32;
664 static constexpr
bool is_k_reduction =
true;
666 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
667 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
677 static constexpr
index_t num_groups_per_blk = 1;
678 static constexpr
index_t num_regs_per_blk = 4;
679 static constexpr
index_t num_threads_per_blk = 16;
686 static constexpr
bool is_k_reduction =
true;
688 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
689 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
699 static constexpr
index_t num_groups_per_blk = 4;
700 static constexpr
index_t num_regs_per_blk = 16;
701 static constexpr
index_t num_threads_per_blk = 32;
708 static constexpr
bool is_k_reduction =
true;
710 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
711 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
721 static constexpr
index_t num_groups_per_blk = 1;
722 static constexpr
index_t num_regs_per_blk = 4;
723 static constexpr
index_t num_threads_per_blk = 16;
730 static constexpr
bool is_k_reduction =
true;
732 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
733 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
743 static constexpr
index_t num_groups_per_blk = 4;
744 static constexpr
index_t num_regs_per_blk = 16;
745 static constexpr
index_t num_threads_per_blk = 32;
752 static constexpr
bool is_k_reduction =
true;
754 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
755 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
765 static constexpr
index_t num_groups_per_blk = 1;
766 static constexpr
index_t num_regs_per_blk = 4;
767 static constexpr
index_t num_threads_per_blk = 16;
774 static constexpr
bool is_k_reduction =
true;
776 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
777 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
787 static constexpr
index_t num_groups_per_blk = 4;
788 static constexpr
index_t num_regs_per_blk = 16;
789 static constexpr
index_t num_threads_per_blk = 32;
796 static constexpr
bool is_k_reduction =
true;
798 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
799 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
809 static constexpr
index_t num_groups_per_blk = 1;
810 static constexpr
index_t num_regs_per_blk = 4;
811 static constexpr
index_t num_threads_per_blk = 16;
818 static constexpr
bool is_k_reduction =
true;
820 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
821 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
832 static constexpr
index_t num_groups_per_blk = 4;
833 static constexpr
index_t num_regs_per_blk = 16;
834 static constexpr
index_t num_threads_per_blk = 32;
841 static constexpr
bool is_k_reduction =
true;
844 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
845 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
856 static constexpr
index_t num_groups_per_blk = 1;
857 static constexpr
index_t num_regs_per_blk = 4;
858 static constexpr
index_t num_threads_per_blk = 16;
865 static constexpr
bool is_k_reduction =
true;
868 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
869 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
880 static constexpr
index_t num_groups_per_blk = 4;
881 static constexpr
index_t num_regs_per_blk = 16;
882 static constexpr
index_t num_threads_per_blk = 32;
889 static constexpr
bool is_k_reduction =
true;
901 __device__
void run(
const FloatA&
a,
902 const ScaleA& scale_a,
904 const ScaleB& scale_b,
908 a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
917 static constexpr
index_t num_groups_per_blk = 1;
918 static constexpr
index_t num_regs_per_blk = 4;
919 static constexpr
index_t num_threads_per_blk = 16;
926 static constexpr
bool is_k_reduction =
true;
938 __device__
void run(
const FloatA&
a,
939 const ScaleA& scale_a,
941 const ScaleB& scale_b,
946 a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
969 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
970 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
979 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
980 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
997 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1007 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1008 __device__
void run(
const FloatA&,
const FloatB&, FloatC&)
const
1033 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1034 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1043 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1044 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1061 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1071 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1072 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1081 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1082 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1091 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1092 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1101 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1102 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1112 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1113 __device__
void run(
const FloatA&,
const FloatB&, FloatC&)
const
1119 template <
typename base_type,
1122 typename additional_type = base_type,
1123 bool is_single_rate_mfma =
false,
1124 bool is_scale_mfma =
false>
1127 template <
typename base_type_,
1130 typename additional_type_ = base_type_,
1131 bool is_single_rate_mfma_ =
false,
1132 bool is_scale_mfma_ =
false>
1136 constexpr
auto GetMfma<double, 16, 16>()
1138 #if defined(__gfx12__)
1140 #elif defined(__gfx11__)
1148 constexpr
auto GetMfma<float, 64, 64>()
1154 constexpr
auto GetMfma<float, 32, 64>()
1160 constexpr
auto GetMfma<float, 16, 64>()
1166 constexpr
auto GetMfma<float, 8, 64>()
1172 constexpr
auto GetMfma<float, 4, 64>()
1178 constexpr
auto GetMfma<float, 32, 32>()
1184 constexpr
auto GetMfma<float, 16, 16>()
1186 #if defined(__gfx12__)
1188 #elif defined(__gfx11__)
1196 constexpr
auto GetMfma<half_t, 64, 64>()
1202 constexpr
auto GetMfma<half_t, 32, 64>()
1208 constexpr
auto GetMfma<half_t, 32, 32, half_t, false>()
1210 #if defined(__gfx950__)
1217 constexpr
auto GetMfma<half_t, 32, 32, half_t, true>()
1223 constexpr
auto GetMfma<half_t, 16, 16, half_t, false>()
1225 #if defined(__gfx12__)
1227 #elif defined(__gfx11__)
1229 #elif defined(__gfx950__)
1237 constexpr
auto GetMfma<half_t, 16, 16, half_t, true>()
1239 #if defined(__gfx12__)
1241 #elif defined(__gfx11__)
1249 constexpr
auto GetMfma<half_t, 16, 64>()
1255 constexpr
auto GetMfma<half_t, 8, 64>()
1261 constexpr
auto GetMfma<half_t, 4, 64>()
1267 constexpr
auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
1269 #if defined(__gfx950__)
1271 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1279 constexpr
auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
1281 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1289 constexpr
auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
1291 #if defined(__gfx12__)
1293 #elif defined(__gfx11__)
1295 #elif defined(__gfx950__)
1297 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1305 constexpr
auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
1307 #if defined(__gfx12__)
1309 #elif defined(__gfx11__)
1311 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1319 constexpr
auto GetMfma<int8_t, 32, 32, int8_t, false>()
1321 #if defined(__gfx950__)
1323 #elif defined(__gfx942__)
1331 constexpr
auto GetMfma<int8_t, 32, 32, int8_t, true>()
1333 #if defined(__gfx942__) || defined(__gfx950__)
1341 constexpr
auto GetMfma<int8_t, 16, 16, int8_t, false>()
1343 #if defined(__gfx12__)
1345 #elif defined(__gfx11__)
1347 #elif defined(__gfx950__)
1349 #elif defined(__gfx942__)
1357 constexpr
auto GetMfma<int8_t, 16, 16, int8_t, true>()
1359 #if defined(__gfx12__)
1361 #elif defined(__gfx11__)
1363 #elif defined(__gfx942__) || defined(__gfx950__)
1371 constexpr
auto GetMfma<f8_t, 32, 32, f8_t, true, false>()
1377 constexpr
auto GetMfma<f8_t, 32, 32, f8_t, false, false>()
1379 #if defined(__gfx950__)
1387 constexpr
auto GetMfma<f8_t, 32, 32, f8_t, is_single_rate_mfma, true>()
1393 constexpr
auto GetMfma<bf8_t, 32, 32, f8_t, is_single_rate_mfma, true>()
1398 constexpr
auto GetMfma<f4_t, 32, 32, f4_t, is_single_rate_mfma, true>()
1403 constexpr
auto GetMfma<f4_t, 16, 16, f4_t, is_single_rate_mfma, true>()
1405 #if defined(__gfx12__)
1407 #elif defined(__gfx11__)
1415 constexpr
auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
1417 #if defined(__gfx12__)
1419 #elif defined(__gfx11__)
1427 constexpr
auto GetMfma<f8_t, 16, 16, f8_t, false, false>()
1429 #if defined(__gfx12__)
1431 #elif defined(__gfx11__)
1433 #elif defined(__gfx950__)
1441 constexpr
auto GetMfma<f8_t, 16, 16, f8_t, is_single_rate_mfma, true>()
1443 #if defined(__gfx12__)
1445 #elif defined(__gfx11__)
1453 constexpr
auto GetMfma<bf8_t, 16, 16, bf8_t, is_single_rate_mfma, true>()
1455 #if defined(__gfx12__)
1457 #elif defined(__gfx11__)
1465 constexpr
auto GetMfma<f8_t, 16, 16, bf8_t, is_single_rate_mfma, true>()
1467 #if defined(__gfx12__)
1469 #elif defined(__gfx11__)
1477 constexpr
auto GetMfma<bf8_t, 16, 16, f8_t, is_single_rate_mfma, true>()
1479 #if defined(__gfx12__)
1481 #elif defined(__gfx11__)
1489 constexpr
auto GetMfma<f6_t, 32, 32, f6_t, is_single_rate_mfma, true>()
1494 constexpr
auto GetMfma<f6_t, 16, 16, f6_t, is_single_rate_mfma, true>()
1496 #if defined(__gfx12__)
1498 #elif defined(__gfx11__)
1505 constexpr
auto GetMfma<bf6_t, 32, 32, bf6_t, is_single_rate_mfma, true>()
1510 constexpr
auto GetMfma<bf6_t, 16, 16, bf6_t, is_single_rate_mfma, true>()
1512 #if defined(__gfx12__)
1514 #elif defined(__gfx11__)
1522 constexpr
auto GetMfma<bf8_t, 32, 32, bf8_t, true, false>()
1528 constexpr
auto GetMfma<bf8_t, 32, 32, bf8_t, false, false>()
1530 #if defined(__gfx950__)
1538 constexpr
auto GetMfma<bf8_t, 16, 16, bf8_t, true, false>()
1540 #if defined(__gfx12__)
1542 #elif defined(__gfx11__)
1550 constexpr
auto GetMfma<bf8_t, 16, 16, bf8_t, false, false>()
1552 #if defined(__gfx12__)
1554 #elif defined(__gfx11__)
1556 #elif defined(__gfx950__)
1564 constexpr
auto GetMfma<f8_t, 32, 32, bf8_t, true, false>()
1570 constexpr
auto GetMfma<f8_t, 32, 32, bf8_t, false, false>()
1572 #if defined(__gfx950__)
1580 constexpr
auto GetMfma<f8_t, 16, 16, bf8_t, true, false>()
1582 #if defined(__gfx12__)
1584 #elif defined(__gfx11__)
1592 constexpr
auto GetMfma<f8_t, 16, 16, bf8_t, false, false>()
1594 #if defined(__gfx12__)
1596 #elif defined(__gfx11__)
1598 #elif defined(__gfx950__)
1606 constexpr
auto GetMfma<bf8_t, 32, 32, f8_t, true, false>()
1612 constexpr
auto GetMfma<bf8_t, 32, 32, f8_t, false, false>()
1614 #if defined(__gfx950__)
1622 constexpr
auto GetMfma<bf8_t, 16, 16, f8_t, true, false>()
1624 #if defined(__gfx12__)
1626 #elif defined(__gfx11__)
1634 constexpr
auto GetMfma<bf8_t, 16, 16, f8_t, false, false>()
1636 #if defined(__gfx12__)
1638 #elif defined(__gfx11__)
1640 #elif defined(__gfx950__)
1651 is_single_rate_mfma,
1652 is_scale_mfma>()>{};
1658 "wrong! num_regs_per_blk");
1661 "n_per_blk != num_threads_per_blk");
1662 #if defined(__gfx11__)
1663 if constexpr(MPerXdlops == 16 && NPerXdlops == 16)
1667 "m_per_blk != num_input_blks * num_regs_per_blk");
1672 "m_per_blk != num_input_blks * num_regs_per_blk");
1677 "incorrect num_output_blks");
1681 "num_regs_per_blk incorrect");
1685 "is_k_reduction wrong!");
1690 static_assert(NPerXdlops >= MPerXdlops,
"only support ABroadcast");
1703 template <
typename base_type,
1707 typename additional_type = base_type,
1708 bool TransposeC =
false,
1709 bool is_scale_mfma =
false>
1726 return MPerXdlops * NPerXdlops /
1732 static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
1734 "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1736 static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
1738 "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1739 #if defined(__HIP_DEVICE_COMPILE__)
1740 static_assert(KPack %
mfma_instr.k_per_blk == 0,
"KPack should be a multiple of k_per_blk");
1746 template <
typename CDesc_M0_N0_M1_N1_M2_N2>
1747 __host__ __device__
static constexpr
auto
1750 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1751 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1752 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1753 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1757 c_desc_m0_n0_m1_n1_m2_n2,
1782 template <
typename CDesc_M0_N0_M1_N1_M2_N2>
1784 const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1786 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1787 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1788 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1789 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1790 const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I4);
1791 const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I5);
1795 c_desc_m0_n0_m1_n1_m2_n2,
1826 template <
typename CDesc_M0_N0_M1_N1_M2_N2>
1827 __host__ __device__
static constexpr
auto
1830 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1831 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1832 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1833 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1837 c_desc_m0_n0_m1_n1_m2_n2,
1860 template <
typename CDesc_G_M0_N0_M1_N1_M2_N2>
1862 const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
1864 const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1865 const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1866 const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1867 const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1868 const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I4);
1872 c_desc_g_m0_n0_m1_n1_m2_n2,
1899 return MPerXdlops * NPerXdlops /
mfma_instr.wave_size;
1904 template <
class FloatA,
class FloatB,
class FloatC>
1905 __device__
void Run(
const FloatA& p_a_wave,
const FloatB& p_b_wave, FloatC& p_c_thread)
const
1914 "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
1917 if constexpr(!TransposeC)
1919 mfma_instr.template run<MPerXdlops, NPerXdlops>(
1920 p_a_wave[k], p_b_wave[k], p_c_thread);
1924 mfma_instr.template run<MPerXdlops, NPerXdlops>(
1925 p_b_wave[k], p_a_wave[k], p_c_thread);
1937 __device__
void Run(
const FloatA& p_a_wave,
1938 const ScaleA& a_scale_thread,
1939 const FloatB& p_b_wave,
1940 const ScaleB& b_scale_thread,
1941 FloatC& p_c_thread)
const
1944 if constexpr(!TransposeC)
1946 mfma_instr.template run<MPerXdlops, NPerXdlops, OpselA, OpselB>(
1947 p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
1951 mfma_instr.template run<MPerXdlops, NPerXdlops, OpselB, OpselA>(
1952 p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
1970 const auto blk_idx =
1971 threadidx_to_blk_idx_adaptor.CalculateBottomIndex(
make_multi_index(laneId));
1973 const auto blk_id = blk_idx[
I1];
1974 const auto blk_td = blk_idx[
I2];
1979 template <
bool SwizzleA>
1983 if constexpr(SwizzleA)
1985 laneId = ((laneId & 1) << 3) | (laneId >> 1);
1993 const auto blk_idx =
1994 threadidx_to_blk_idx_adaptor.CalculateBottomIndex(
make_multi_index(laneId));
1996 const auto blk_id = blk_idx[
I1];
1997 const auto blk_td = blk_idx[
I2];
2005 #if defined(__gfx11__)
2006 const auto blk_idx = GetGfx11InputBlkIdx<!TransposeC>();
2011 const auto blk_id = blk_idx[
I0];
2012 const auto blk_td = blk_idx[
I1];
2027 #if defined(__gfx11__)
2028 const auto blk_idx = GetGfx11InputBlkIdx<TransposeC>();
2033 const auto blk_id = blk_idx[
I0];
2034 const auto blk_td = blk_idx[
I1];
2050 const auto blk_id = blk_idx[
I0];
2051 const auto blk_td = blk_idx[
I1];
2056 return TransposeC ?
CIndex{n_offset, m_offset} :
CIndex{m_offset, n_offset};
2063 const auto blk_id = blk_idx[
I0];
2064 const auto blk_td = blk_idx[
I1];
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
MfmaInstr
Definition: xdlops_gemm.hpp:43
@ wmma_f32_16x16x16_bf16_gfx12
@ mfma_f32_32x32x64f8f6f4
@ wmma_unsupport_16x16_gfx11
@ wmma_i32_16x16x16_iu8_gfx12
@ mfma_scale_f32_32x32x64f8f6f4
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_unsupport_16x16_gfx12
@ mfma_f32_16x16x16bf16_1k
@ wmma_f32_16x16x16_f8f8_gfx12
@ mfma_scale_f32_16x16x128f8f6f4
@ mfma_f32_16x16x32bf8bf8
@ mfma_f32_16x16x128f8f6f4
@ mfma_f32_32x32x16bf8bf8
@ mfma_f32_32x32x8bf16_1k
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
typename packed_type_info< T >::element_type element_type_t
Definition: data_type.hpp:405
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:52
@ wmma_f32_16x16x16_bf16_gfx12
@ wmma_i32_16x16x16_iu8_gfx12
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_f32_16x16x16_f8f8_gfx12
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
Definition: xdlops_gemm.hpp:1126
__host__ constexpr __device__ MfmaSelector()
Definition: xdlops_gemm.hpp:1654
static constexpr bool IsABroadcast()
Definition: xdlops_gemm.hpp:1688
static constexpr index_t GetK1PerXdlops()
Definition: xdlops_gemm.hpp:1700
static constexpr auto GetMfma()
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1647
static constexpr index_t GetKPerXdlops()
Definition: xdlops_gemm.hpp:1694
Definition: sequence.hpp:43
Definition: xdlops_gemm.hpp:1711
static constexpr auto mfma_instr
Definition: xdlops_gemm.hpp:2087
__host__ constexpr __device__ XdlopsGemm()
Definition: xdlops_gemm.hpp:1730
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:2024
static __device__ auto GetBlkIdx()
Definition: xdlops_gemm.hpp:1959
__device__ static constexpr __host__ index_t GetRegSizePerXdlops()
Definition: xdlops_gemm.hpp:1897
static constexpr auto I2
Definition: xdlops_gemm.hpp:1714
static constexpr __device__ index_t GetNumBlks()
Definition: xdlops_gemm.hpp:1722
static __device__ auto GetLaneId()
Definition: xdlops_gemm.hpp:1957
static constexpr auto K0PerXdlops
Definition: xdlops_gemm.hpp:2091
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1783
static constexpr __device__ index_t GetNumXdlops()
Definition: xdlops_gemm.hpp:1724
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:2002
static constexpr bool is_single_rate_mfma
Definition: xdlops_gemm.hpp:2072
static __device__ CIndex4D GetBeginOfThreadBlk4D(index_t, index_t)
Definition: xdlops_gemm.hpp:2059
static constexpr __device__ index_t GetWaveSize()
Definition: xdlops_gemm.hpp:1902
static __device__ auto GetGfx11InputBlkIdx()
Definition: xdlops_gemm.hpp:1980
static constexpr auto I5
Definition: xdlops_gemm.hpp:1717
static constexpr auto I3
Definition: xdlops_gemm.hpp:1715
static constexpr auto I0
Definition: xdlops_gemm.hpp:1712
__device__ void Run(const FloatA &p_a_wave, const ScaleA &a_scale_thread, const FloatB &p_b_wave, const ScaleB &b_scale_thread, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:1937
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1748
__host__ static constexpr __device__ auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1861
static constexpr auto I1
Definition: xdlops_gemm.hpp:1713
static constexpr auto K1PerXdlops
Definition: xdlops_gemm.hpp:2090
static constexpr auto KPerXdlops
Definition: xdlops_gemm.hpp:2089
static constexpr auto I4
Definition: xdlops_gemm.hpp:1716
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:1905
static constexpr auto mfma
Definition: xdlops_gemm.hpp:2080
static __device__ CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
Definition: xdlops_gemm.hpp:2046
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1828
__host__ static constexpr __device__ auto GetCM0M1M2NThreadBlkLengths()
Definition: xdlops_gemm.hpp:2093
Definition: integral_constant.hpp:20
Definition: amd_xdlops.hpp:1202
Definition: amd_xdlops.hpp:303
Definition: amd_xdlops.hpp:193
Definition: amd_xdlops.hpp:70
Definition: amd_xdlops.hpp:269
Definition: amd_xdlops.hpp:1483
Definition: amd_xdlops.hpp:1609
Definition: amd_xdlops.hpp:159
Definition: amd_xdlops.hpp:1546
Definition: amd_xdlops.hpp:1420
Definition: amd_xdlops.hpp:207
Definition: amd_xdlops.hpp:56
Definition: amd_xdlops.hpp:331
Definition: amd_xdlops.hpp:249
Definition: amd_xdlops.hpp:1451
Definition: amd_xdlops.hpp:1577
Definition: amd_xdlops.hpp:139
Definition: amd_xdlops.hpp:1514
Definition: amd_xdlops.hpp:1388
Definition: amd_xdlops.hpp:15
Definition: amd_xdlops.hpp:42
Definition: amd_xdlops.hpp:317
Definition: amd_xdlops.hpp:112
Definition: amd_xdlops.hpp:481
Definition: amd_xdlops.hpp:289
Definition: amd_xdlops.hpp:179
Definition: amd_xdlops.hpp:84
Definition: amd_xdlops.hpp:221
Definition: amd_xdlops.hpp:461
Definition: amd_xdlops.hpp:364
Definition: amd_xdlops.hpp:442
Definition: amd_xdlops.hpp:403
Definition: amd_xdlops.hpp:423
Definition: amd_xdlops.hpp:383
Definition: amd_xdlops.hpp:345
Definition: amd_xdlops.hpp:886
Definition: amd_xdlops.hpp:666
Definition: amd_wmma.hpp:297
Definition: amd_wmma.hpp:50
Definition: amd_wmma.hpp:418
Definition: amd_wmma.hpp:394
Definition: amd_wmma.hpp:271
Definition: amd_wmma.hpp:25
Definition: amd_wmma.hpp:370
Definition: amd_wmma.hpp:346
Definition: amd_wmma.hpp:319
Definition: amd_wmma.hpp:121
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:869
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:447
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:315
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:182
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:425
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:733
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:821
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:293
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:777
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:689
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:337
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:160
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:491
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:381
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:711
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:799
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:271
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:755
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:667
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:116
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:138
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:469
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:227
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:845
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:403
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:249
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:205
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:359
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:645
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:535
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:579
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:623
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:557
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:601
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:513
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:938
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:901
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:980
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1044
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1102
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1092
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:970
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1034
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1082
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1072
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:997
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1061
__device__ void run(const FloatA &, const FloatB &, FloatC &) const
Definition: xdlops_gemm.hpp:1008
__device__ void run(const FloatA &, const FloatB &, FloatC &) const
Definition: xdlops_gemm.hpp:1113
Definition: xdlops_gemm.hpp:952
static constexpr index_t n_per_blk
Definition: xdlops_gemm.hpp:961
static constexpr index_t group_size
Definition: xdlops_gemm.hpp:953
static constexpr index_t m_per_blk
Definition: xdlops_gemm.hpp:960
static constexpr bool is_k_reduction
Definition: xdlops_gemm.hpp:963
static constexpr index_t num_threads_per_blk
Definition: xdlops_gemm.hpp:956
static constexpr index_t num_output_blks
Definition: xdlops_gemm.hpp:959
static constexpr index_t wave_size
Definition: xdlops_gemm.hpp:957
static constexpr index_t num_input_blks
Definition: xdlops_gemm.hpp:958
static constexpr index_t num_groups_per_blk
Definition: xdlops_gemm.hpp:954
static constexpr index_t num_regs_per_blk
Definition: xdlops_gemm.hpp:955
static constexpr index_t k_per_blk
Definition: xdlops_gemm.hpp:962
Definition: xdlops_gemm.hpp:1016
static constexpr index_t n_per_blk
Definition: xdlops_gemm.hpp:1025
static constexpr index_t group_size
Definition: xdlops_gemm.hpp:1017
static constexpr index_t num_output_blks
Definition: xdlops_gemm.hpp:1023
static constexpr index_t m_per_blk
Definition: xdlops_gemm.hpp:1024
static constexpr index_t num_threads_per_blk
Definition: xdlops_gemm.hpp:1020
static constexpr bool is_k_reduction
Definition: xdlops_gemm.hpp:1027
static constexpr index_t num_regs_per_blk
Definition: xdlops_gemm.hpp:1019
static constexpr index_t num_groups_per_blk
Definition: xdlops_gemm.hpp:1018
static constexpr index_t num_input_blks
Definition: xdlops_gemm.hpp:1022
static constexpr index_t wave_size
Definition: xdlops_gemm.hpp:1021
static constexpr index_t k_per_blk
Definition: xdlops_gemm.hpp:1026
Definition: xdlops_gemm.hpp:98
Definition: functional2.hpp:33