/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Source File
xdlops_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/math.hpp"
10 
11 namespace ck {
// Compile-time predicate: true when the element type of T (element_type_t<T>) is one of
// f8_ocp_t, bf8_ocp_t, f6_t, bf6_t or f4_t; false for every other type.
15 template <typename T>
16 static constexpr bool is_scale_mfma_data_type()
17 {
 // Test the element type, not T itself.
18  using U = element_type_t<T>;
19  return is_same_v<U, f8_ocp_t> || is_same_v<U, bf8_ocp_t> || is_same_v<U, f6_t> ||
20  is_same_v<U, bf6_t> || is_same_v<U, f4_t>;
21 }
22 
// Compile-time predicate: true only when T is exactly e8m0_bexp_t.
26 template <typename T>
27 static constexpr bool is_scale_mfma_scale_type()
28 {
29  return is_same_v<T, e8m0_bexp_t>;
30 }
31 
// Compile-time predicate: true when both ADataType and BDataType satisfy
// is_scale_mfma_data_type and both AScaleDataType and BScaleDataType satisfy
// is_scale_mfma_scale_type.
35 template <typename ADataType, typename BDataType, typename AScaleDataType, typename BScaleDataType>
36 static constexpr bool scale_mfma_hw_support()
37 {
38  return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
39  is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
40 }
41 
// Scoped enum used as the non-type template argument of mfma_type below.
// NOTE(review): the enumerator lines are missing from this listing; only the "gfx11" and
// "gfx12" section comments survive. Consult the upstream header for the actual values.
42 enum struct MfmaInstr
43 {
81  // gfx11
86  // gfx12
95 };
96 
// Primary template, declared but not defined here; it is parameterized by an MfmaInstr value.
97 template <MfmaInstr instr>
98 struct mfma_type;
99 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 32 x 32, k_per_blk = 1, wave_size = 64,
// num_output_blks = 2, is_k_reduction = false; group_size * num_groups_per_blk = 4 * 4 = 16 regs.
100 template <>
102 {
103  static constexpr index_t group_size = 4;
104  static constexpr index_t num_groups_per_blk = 4;
105  static constexpr index_t num_regs_per_blk = 16;
106  static constexpr index_t num_threads_per_blk = 32;
107  static constexpr index_t wave_size = 64;
108  static constexpr index_t num_input_blks = 2;
109  static constexpr index_t num_output_blks = 2;
110  static constexpr index_t m_per_blk = 32;
111  static constexpr index_t n_per_blk = 32;
112  static constexpr index_t k_per_blk = 1;
113  static constexpr bool is_k_reduction = false;
114 
115  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
116  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
117  {
 // NOTE(review): the body statement is missing from this listing; reg_c is the only
 // non-const parameter, so it is presumably the output — confirm upstream.
119  }
120 };
121 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 32 x 32, k_per_blk = 1, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 4 = 16 regs.
122 template <>
124 {
125  static constexpr index_t group_size = 4;
126  static constexpr index_t num_groups_per_blk = 4;
127  static constexpr index_t num_regs_per_blk = 16;
128  static constexpr index_t num_threads_per_blk = 32;
129  static constexpr index_t wave_size = 64;
130  static constexpr index_t num_input_blks = 2;
131  static constexpr index_t num_output_blks = 1;
132  static constexpr index_t m_per_blk = 32;
133  static constexpr index_t n_per_blk = 32;
134  static constexpr index_t k_per_blk = 1;
135  static constexpr bool is_k_reduction = true;
136 
137  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
138  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
139  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
141  }
142 };
143 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 16 x 16, k_per_blk = 1, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 1 = 4 regs.
144 template <>
146 {
147  static constexpr index_t group_size = 4;
148  static constexpr index_t num_groups_per_blk = 1;
149  static constexpr index_t num_regs_per_blk = 4;
150  static constexpr index_t num_threads_per_blk = 16;
151  static constexpr index_t wave_size = 64;
152  static constexpr index_t num_input_blks = 4;
153  static constexpr index_t num_output_blks = 1;
154  static constexpr index_t m_per_blk = 16;
155  static constexpr index_t n_per_blk = 16;
156  static constexpr index_t k_per_blk = 1;
157  static constexpr bool is_k_reduction = true;
158 
159  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
160  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
161  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
163  }
164 };
165 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 16 x 16, k_per_blk = 1, wave_size = 64,
// num_output_blks = 4, is_k_reduction = false; group_size * num_groups_per_blk = 4 * 1 = 4 regs.
166 template <>
168 {
169  static constexpr index_t group_size = 4;
170  static constexpr index_t num_groups_per_blk = 1;
171  static constexpr index_t num_regs_per_blk = 4;
172  static constexpr index_t num_threads_per_blk = 16;
173  static constexpr index_t wave_size = 64;
174  static constexpr index_t num_input_blks = 4;
175  static constexpr index_t num_output_blks = 4;
176  static constexpr index_t m_per_blk = 16;
177  static constexpr index_t n_per_blk = 16;
178  static constexpr index_t k_per_blk = 1;
179  static constexpr bool is_k_reduction = false;
180 
181  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
182  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
183  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
185  }
186 };
187 
188 // treat 4x4x1 as a single-blk 4x64 mfma
// Consistent with the note above, the constants describe one block of m_per_blk x n_per_blk =
// 4 x 64 with num_input_blks = num_output_blks = 1 and all 64 lanes in that block.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
189 template <>
191 {
192  static constexpr index_t group_size = 4;
193  static constexpr index_t num_groups_per_blk = 1;
194  static constexpr index_t num_regs_per_blk = 4;
195  static constexpr index_t num_threads_per_blk = 64;
196  static constexpr index_t wave_size = 64;
197  static constexpr index_t num_input_blks = 1;
198  static constexpr index_t num_output_blks = 1;
199  static constexpr index_t m_per_blk = 4;
200  static constexpr index_t n_per_blk = 64;
201  static constexpr index_t k_per_blk = 1;
202  static constexpr bool is_k_reduction = false;
203 
204  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
205  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
206  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
208  }
209 };
210 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 32 x 32, k_per_blk = 4, wave_size = 64,
// num_output_blks = 2, is_k_reduction = false; group_size * num_groups_per_blk = 4 * 4 = 16 regs.
211 template <>
213 {
214  static constexpr index_t group_size = 4;
215  static constexpr index_t num_groups_per_blk = 4;
216  static constexpr index_t num_regs_per_blk = 16;
217  static constexpr index_t num_threads_per_blk = 32;
218  static constexpr index_t wave_size = 64;
219  static constexpr index_t num_input_blks = 2;
220  static constexpr index_t num_output_blks = 2;
221  static constexpr index_t m_per_blk = 32;
222  static constexpr index_t n_per_blk = 32;
223  static constexpr index_t k_per_blk = 4;
224  static constexpr bool is_k_reduction = false;
225 
226  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
227  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
228  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
230  }
231 };
232 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 32 x 32, k_per_blk = 4, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 4 = 16 regs.
233 template <>
235 {
236  static constexpr index_t group_size = 4;
237  static constexpr index_t num_groups_per_blk = 4;
238  static constexpr index_t num_regs_per_blk = 16;
239  static constexpr index_t num_threads_per_blk = 32;
240  static constexpr index_t wave_size = 64;
241  static constexpr index_t num_input_blks = 2;
242  static constexpr index_t num_output_blks = 1;
243  static constexpr index_t m_per_blk = 32;
244  static constexpr index_t n_per_blk = 32;
245  static constexpr index_t k_per_blk = 4;
246  static constexpr bool is_k_reduction = true;
247 
248  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
249  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
250  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
252  }
253 };
254 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 32 x 32, k_per_blk = 8, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 4 = 16 regs.
255 template <>
257 {
258  static constexpr index_t group_size = 4;
259  static constexpr index_t num_groups_per_blk = 4;
260  static constexpr index_t num_regs_per_blk = 16;
261  static constexpr index_t num_threads_per_blk = 32;
262  static constexpr index_t wave_size = 64;
263  static constexpr index_t num_input_blks = 2;
264  static constexpr index_t num_output_blks = 1;
265  static constexpr index_t m_per_blk = 32;
266  static constexpr index_t n_per_blk = 32;
267  static constexpr index_t k_per_blk = 8;
268  static constexpr bool is_k_reduction = true;
269 
270  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
271  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
272  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
274  }
275 };
276 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 16 x 16, k_per_blk = 8, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 1 = 4 regs.
277 template <>
279 {
280  static constexpr index_t group_size = 4;
281  static constexpr index_t num_groups_per_blk = 1;
282  static constexpr index_t num_regs_per_blk = 4;
283  static constexpr index_t num_threads_per_blk = 16;
284  static constexpr index_t wave_size = 64;
285  static constexpr index_t num_input_blks = 4;
286  static constexpr index_t num_output_blks = 1;
287  static constexpr index_t m_per_blk = 16;
288  static constexpr index_t n_per_blk = 16;
289  static constexpr index_t k_per_blk = 8;
290  static constexpr bool is_k_reduction = true;
291 
292  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
293  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
294  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
296  }
297 };
298 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 16 x 16, k_per_blk = 4, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 1 = 4 regs.
299 template <>
301 {
302  static constexpr index_t group_size = 4;
303  static constexpr index_t num_groups_per_blk = 1;
304  static constexpr index_t num_regs_per_blk = 4;
305  static constexpr index_t num_threads_per_blk = 16;
306  static constexpr index_t wave_size = 64;
307  static constexpr index_t num_input_blks = 4;
308  static constexpr index_t num_output_blks = 1;
309  static constexpr index_t m_per_blk = 16;
310  static constexpr index_t n_per_blk = 16;
311  static constexpr index_t k_per_blk = 4;
312  static constexpr bool is_k_reduction = true;
313 
314  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
315  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
316  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
318  }
319 };
320 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 16 x 16, k_per_blk = 4, wave_size = 64,
// num_output_blks = 4, is_k_reduction = false; group_size * num_groups_per_blk = 4 * 1 = 4 regs.
321 template <>
323 {
324  static constexpr index_t group_size = 4;
325  static constexpr index_t num_groups_per_blk = 1;
326  static constexpr index_t num_regs_per_blk = 4;
327  static constexpr index_t num_threads_per_blk = 16;
328  static constexpr index_t wave_size = 64;
329  static constexpr index_t num_input_blks = 4;
330  static constexpr index_t num_output_blks = 4;
331  static constexpr index_t m_per_blk = 16;
332  static constexpr index_t n_per_blk = 16;
333  static constexpr index_t k_per_blk = 4;
334  static constexpr bool is_k_reduction = false;
335 
336  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
337  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
338  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
340  }
341 };
342 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: a single block (num_input_blks = num_output_blks = 1) of m_per_blk x n_per_blk =
// 4 x 64 using all 64 lanes, k_per_blk = 4, is_k_reduction = false.
343 template <>
345 {
346  static constexpr index_t group_size = 4;
347  static constexpr index_t num_groups_per_blk = 1;
348  static constexpr index_t num_regs_per_blk = 4;
349  static constexpr index_t num_threads_per_blk = 64;
350  static constexpr index_t wave_size = 64;
351  static constexpr index_t num_input_blks = 1;
352  static constexpr index_t num_output_blks = 1;
353  static constexpr index_t m_per_blk = 4;
354  static constexpr index_t n_per_blk = 64;
355  static constexpr index_t k_per_blk = 4;
356  static constexpr bool is_k_reduction = false;
357 
358  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
359  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
360  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
362  }
363 };
364 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 32 x 32, k_per_blk = 8, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 4 = 16 regs.
365 template <>
367 {
368  static constexpr index_t group_size = 4;
369  static constexpr index_t num_groups_per_blk = 4;
370  static constexpr index_t num_regs_per_blk = 16;
371  static constexpr index_t num_threads_per_blk = 32;
372  static constexpr index_t wave_size = 64;
373  static constexpr index_t num_input_blks = 2;
374  static constexpr index_t num_output_blks = 1;
375  static constexpr index_t m_per_blk = 32;
376  static constexpr index_t n_per_blk = 32;
377  static constexpr index_t k_per_blk = 8;
378  static constexpr bool is_k_reduction = true;
379 
380  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
381  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
382  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
384  }
385 };
386 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 32 x 32, k_per_blk = 4, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 4 = 16 regs.
387 template <>
389 {
390  static constexpr index_t group_size = 4;
391  static constexpr index_t num_groups_per_blk = 4;
392  static constexpr index_t num_regs_per_blk = 16;
393  static constexpr index_t num_threads_per_blk = 32;
394  static constexpr index_t wave_size = 64;
395  static constexpr index_t num_input_blks = 2;
396  static constexpr index_t num_output_blks = 1;
397  static constexpr index_t m_per_blk = 32;
398  static constexpr index_t n_per_blk = 32;
399  static constexpr index_t k_per_blk = 4;
400  static constexpr bool is_k_reduction = true;
401 
402  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
403  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
404  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
406  }
407 };
408 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 16 x 16, k_per_blk = 8, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 1 = 4 regs.
409 template <>
411 {
412  static constexpr index_t group_size = 4;
413  static constexpr index_t num_groups_per_blk = 1;
414  static constexpr index_t num_regs_per_blk = 4;
415  static constexpr index_t num_threads_per_blk = 16;
416  static constexpr index_t wave_size = 64;
417  static constexpr index_t num_input_blks = 4;
418  static constexpr index_t num_output_blks = 1;
419  static constexpr index_t m_per_blk = 16;
420  static constexpr index_t n_per_blk = 16;
421  static constexpr index_t k_per_blk = 8;
422  static constexpr bool is_k_reduction = true;
423 
424  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
425  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
426  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
428  }
429 };
430 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 16 x 16, k_per_blk = 4, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 1 = 4 regs.
431 template <>
433 {
434  static constexpr index_t group_size = 4;
435  static constexpr index_t num_groups_per_blk = 1;
436  static constexpr index_t num_regs_per_blk = 4;
437  static constexpr index_t num_threads_per_blk = 16;
438  static constexpr index_t wave_size = 64;
439  static constexpr index_t num_input_blks = 4;
440  static constexpr index_t num_output_blks = 1;
441  static constexpr index_t m_per_blk = 16;
442  static constexpr index_t n_per_blk = 16;
443  static constexpr index_t k_per_blk = 4;
444  static constexpr bool is_k_reduction = true;
445 
446  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
447  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
448  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
450  }
451 };
452 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 32 x 32, k_per_blk = 2, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 4 = 16 regs.
453 template <>
455 {
456  static constexpr index_t group_size = 4;
457  static constexpr index_t num_groups_per_blk = 4;
458  static constexpr index_t num_regs_per_blk = 16;
459  static constexpr index_t num_threads_per_blk = 32;
460  static constexpr index_t wave_size = 64;
461  static constexpr index_t num_input_blks = 2;
462  static constexpr index_t num_output_blks = 1;
463  static constexpr index_t m_per_blk = 32;
464  static constexpr index_t n_per_blk = 32;
465  static constexpr index_t k_per_blk = 2;
466  static constexpr bool is_k_reduction = true;
467 
468  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
469  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
470  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
472  }
473 };
474 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 16 x 16, k_per_blk = 2, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 1 = 4 regs.
475 template <>
477 {
478  static constexpr index_t group_size = 4;
479  static constexpr index_t num_groups_per_blk = 1;
480  static constexpr index_t num_regs_per_blk = 4;
481  static constexpr index_t num_threads_per_blk = 16;
482  static constexpr index_t wave_size = 64;
483  static constexpr index_t num_input_blks = 4;
484  static constexpr index_t num_output_blks = 1;
485  static constexpr index_t m_per_blk = 16;
486  static constexpr index_t n_per_blk = 16;
487  static constexpr index_t k_per_blk = 2;
488  static constexpr bool is_k_reduction = true;
489 
490  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
491  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
492  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
494  }
495 };
496 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 32 x 32, k_per_blk = 4, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 4 = 16 regs.
497 template <>
499 {
500  static constexpr index_t group_size = 4;
501  static constexpr index_t num_groups_per_blk = 4;
502  static constexpr index_t num_regs_per_blk = 16;
503  static constexpr index_t num_threads_per_blk = 32;
504  static constexpr index_t wave_size = 64;
505  static constexpr index_t num_input_blks = 2;
506  static constexpr index_t num_output_blks = 1;
507  static constexpr index_t m_per_blk = 32;
508  static constexpr index_t n_per_blk = 32;
509  static constexpr index_t k_per_blk = 4;
510  static constexpr bool is_k_reduction = true;
511 
512  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
513  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
514  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
516  }
517 };
518 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 16 x 16, k_per_blk = 4, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 1 = 4 regs.
519 template <>
521 {
522  static constexpr index_t group_size = 4;
523  static constexpr index_t num_groups_per_blk = 1;
524  static constexpr index_t num_regs_per_blk = 4;
525  static constexpr index_t num_threads_per_blk = 16;
526  static constexpr index_t wave_size = 64;
527  static constexpr index_t num_input_blks = 4;
528  static constexpr index_t num_output_blks = 1;
529  static constexpr index_t m_per_blk = 16;
530  static constexpr index_t n_per_blk = 16;
531  static constexpr index_t k_per_blk = 4;
532  static constexpr bool is_k_reduction = true;
533 
534  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
535  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
536  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
538  }
539 };
540 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 32 x 32, k_per_blk = 8, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 4 = 16 regs.
541 template <>
543 {
544  static constexpr index_t group_size = 4;
545  static constexpr index_t num_groups_per_blk = 4;
546  static constexpr index_t num_regs_per_blk = 16;
547  static constexpr index_t num_threads_per_blk = 32;
548  static constexpr index_t wave_size = 64;
549  static constexpr index_t num_input_blks = 2;
550  static constexpr index_t num_output_blks = 1;
551  static constexpr index_t m_per_blk = 32;
552  static constexpr index_t n_per_blk = 32;
553  static constexpr index_t k_per_blk = 8;
554  static constexpr bool is_k_reduction = true;
555 
556  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
557  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
558  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
560  }
561 };
562 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 16 x 16, k_per_blk = 8, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 1 = 4 regs.
563 template <>
565 {
566  static constexpr index_t group_size = 4;
567  static constexpr index_t num_groups_per_blk = 1;
568  static constexpr index_t num_regs_per_blk = 4;
569  static constexpr index_t num_threads_per_blk = 16;
570  static constexpr index_t wave_size = 64;
571  static constexpr index_t num_input_blks = 4;
572  static constexpr index_t num_output_blks = 1;
573  static constexpr index_t m_per_blk = 16;
574  static constexpr index_t n_per_blk = 16;
575  static constexpr index_t k_per_blk = 8;
576  static constexpr bool is_k_reduction = true;
577 
578  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
579  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
580  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
582  }
583 };
584 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 32 x 32, k_per_blk = 16, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 4 = 16 regs.
585 template <>
587 {
588  static constexpr index_t group_size = 4;
589  static constexpr index_t num_groups_per_blk = 4;
590  static constexpr index_t num_regs_per_blk = 16;
591  static constexpr index_t num_threads_per_blk = 32;
592  static constexpr index_t wave_size = 64;
593  static constexpr index_t num_input_blks = 2;
594  static constexpr index_t num_output_blks = 1;
595  static constexpr index_t m_per_blk = 32;
596  static constexpr index_t n_per_blk = 32;
597  static constexpr index_t k_per_blk = 16;
598  static constexpr bool is_k_reduction = true;
599 
600  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
601  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
602  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
604  }
605 };
606 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 16 x 16, k_per_blk = 16, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 1 = 4 regs.
607 template <>
609 {
610  static constexpr index_t group_size = 4;
611  static constexpr index_t num_groups_per_blk = 1;
612  static constexpr index_t num_regs_per_blk = 4;
613  static constexpr index_t num_threads_per_blk = 16;
614  static constexpr index_t wave_size = 64;
615  static constexpr index_t num_input_blks = 4;
616  static constexpr index_t num_output_blks = 1;
617  static constexpr index_t m_per_blk = 16;
618  static constexpr index_t n_per_blk = 16;
619  static constexpr index_t k_per_blk = 16;
620  static constexpr bool is_k_reduction = true;
621 
622  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
623  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
624  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
626  }
627 };
628 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Unlike the neighbouring 16 x 16 tables, this one uses group_size = 1 with num_groups_per_blk = 4
// (still 4 regs per block). Other visible values: k_per_blk = 1, wave_size = 64, is_k_reduction = true.
629 template <>
631 {
632  static constexpr index_t group_size = 1;
633  static constexpr index_t num_groups_per_blk = 4;
634  static constexpr index_t num_regs_per_blk = 4; // group_size * num_groups_per_blk;
635  static constexpr index_t num_threads_per_blk = 16;
636  static constexpr index_t wave_size = 64;
637  static constexpr index_t num_input_blks = 4; // wave_size / num_threads_per_blk;
638  static constexpr index_t num_output_blks = 1;
639  static constexpr index_t m_per_blk = 16;
640  static constexpr index_t n_per_blk = 16;
641  static constexpr index_t k_per_blk = 1;
642  static constexpr bool is_k_reduction = true;
643 
644  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
645  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
646  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
648  }
649 };
650 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 32 x 32, k_per_blk = 8, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 4 = 16 regs.
651 template <>
653 {
654  static constexpr index_t group_size = 4;
655  static constexpr index_t num_groups_per_blk = 4;
656  static constexpr index_t num_regs_per_blk = 16;
657  static constexpr index_t num_threads_per_blk = 32;
658  static constexpr index_t wave_size = 64;
659  static constexpr index_t num_input_blks = 2;
660  static constexpr index_t num_output_blks = 1;
661  static constexpr index_t m_per_blk = 32;
662  static constexpr index_t n_per_blk = 32;
663  static constexpr index_t k_per_blk = 8;
664  static constexpr bool is_k_reduction = true;
665 
666  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
667  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
668  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
670  }
671 };
672 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 16 x 16, k_per_blk = 8, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 1 = 4 regs.
673 template <>
675 {
676  static constexpr index_t group_size = 4;
677  static constexpr index_t num_groups_per_blk = 1;
678  static constexpr index_t num_regs_per_blk = 4;
679  static constexpr index_t num_threads_per_blk = 16;
680  static constexpr index_t wave_size = 64;
681  static constexpr index_t num_input_blks = 4;
682  static constexpr index_t num_output_blks = 1;
683  static constexpr index_t m_per_blk = 16;
684  static constexpr index_t n_per_blk = 16;
685  static constexpr index_t k_per_blk = 8;
686  static constexpr bool is_k_reduction = true;
687 
688  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
689  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
690  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
692  }
693 };
694 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 32 x 32, k_per_blk = 8, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 4 = 16 regs.
695 template <>
697 {
698  static constexpr index_t group_size = 4;
699  static constexpr index_t num_groups_per_blk = 4;
700  static constexpr index_t num_regs_per_blk = 16;
701  static constexpr index_t num_threads_per_blk = 32;
702  static constexpr index_t wave_size = 64;
703  static constexpr index_t num_input_blks = 2;
704  static constexpr index_t num_output_blks = 1;
705  static constexpr index_t m_per_blk = 32;
706  static constexpr index_t n_per_blk = 32;
707  static constexpr index_t k_per_blk = 8;
708  static constexpr bool is_k_reduction = true;
709 
710  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
711  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
712  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
714  }
715 };
716 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 16 x 16, k_per_blk = 8, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 1 = 4 regs.
717 template <>
719 {
720  static constexpr index_t group_size = 4;
721  static constexpr index_t num_groups_per_blk = 1;
722  static constexpr index_t num_regs_per_blk = 4;
723  static constexpr index_t num_threads_per_blk = 16;
724  static constexpr index_t wave_size = 64;
725  static constexpr index_t num_input_blks = 4;
726  static constexpr index_t num_output_blks = 1;
727  static constexpr index_t m_per_blk = 16;
728  static constexpr index_t n_per_blk = 16;
729  static constexpr index_t k_per_blk = 8;
730  static constexpr bool is_k_reduction = true;
731 
732  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
733  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
734  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
736  }
737 };
738 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 32 x 32, k_per_blk = 8, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 4 = 16 regs.
739 template <>
741 {
742  static constexpr index_t group_size = 4;
743  static constexpr index_t num_groups_per_blk = 4;
744  static constexpr index_t num_regs_per_blk = 16;
745  static constexpr index_t num_threads_per_blk = 32;
746  static constexpr index_t wave_size = 64;
747  static constexpr index_t num_input_blks = 2;
748  static constexpr index_t num_output_blks = 1;
749  static constexpr index_t m_per_blk = 32;
750  static constexpr index_t n_per_blk = 32;
751  static constexpr index_t k_per_blk = 8;
752  static constexpr bool is_k_reduction = true;
753 
754  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
755  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
756  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
758  }
759 };
760 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 16 x 16, k_per_blk = 8, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 1 = 4 regs.
761 template <>
763 {
764  static constexpr index_t group_size = 4;
765  static constexpr index_t num_groups_per_blk = 1;
766  static constexpr index_t num_regs_per_blk = 4;
767  static constexpr index_t num_threads_per_blk = 16;
768  static constexpr index_t wave_size = 64;
769  static constexpr index_t num_input_blks = 4;
770  static constexpr index_t num_output_blks = 1;
771  static constexpr index_t m_per_blk = 16;
772  static constexpr index_t n_per_blk = 16;
773  static constexpr index_t k_per_blk = 8;
774  static constexpr bool is_k_reduction = true;
775 
776  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
777  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
778  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
780  }
781 };
782 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 32 x 32, k_per_blk = 8, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 4 = 16 regs.
783 template <>
785 {
786  static constexpr index_t group_size = 4;
787  static constexpr index_t num_groups_per_blk = 4;
788  static constexpr index_t num_regs_per_blk = 16;
789  static constexpr index_t num_threads_per_blk = 32;
790  static constexpr index_t wave_size = 64;
791  static constexpr index_t num_input_blks = 2;
792  static constexpr index_t num_output_blks = 1;
793  static constexpr index_t m_per_blk = 32;
794  static constexpr index_t n_per_blk = 32;
795  static constexpr index_t k_per_blk = 8;
796  static constexpr bool is_k_reduction = true;
797 
798  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
799  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
800  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
802  }
803 };
804 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// Visible values: m_per_blk x n_per_blk = 16 x 16, k_per_blk = 8, wave_size = 64,
// num_output_blks = 1, is_k_reduction = true; group_size * num_groups_per_blk = 4 * 1 = 4 regs.
805 template <>
807 {
808  static constexpr index_t group_size = 4;
809  static constexpr index_t num_groups_per_blk = 1;
810  static constexpr index_t num_regs_per_blk = 4;
811  static constexpr index_t num_threads_per_blk = 16;
812  static constexpr index_t wave_size = 64;
813  static constexpr index_t num_input_blks = 4;
814  static constexpr index_t num_output_blks = 1;
815  static constexpr index_t m_per_blk = 16;
816  static constexpr index_t n_per_blk = 16;
817  static constexpr index_t k_per_blk = 8;
818  static constexpr bool is_k_reduction = true;
819 
820  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
821  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
822  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
824  }
825 };
826 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// The inline formulas hold for these values: 32 * 32 / 64 = 16 regs per block and
// 32 / 16 = 2 input blocks. The "???" markers are the original author's open questions.
827 template <>
829 {
830  // clang-format off
831  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
832  static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
833  static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
834  static constexpr index_t num_threads_per_blk = 32; // n_per_blk
835  static constexpr index_t wave_size = 64; // fixed
836  static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
837  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
838  static constexpr index_t m_per_blk = 32; // from the instruction
839  static constexpr index_t n_per_blk = 32; // from the instruction
840  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
841  static constexpr bool is_k_reduction = true; // ???
842  // clang-format on
843 
844  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
845  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
846  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
848  }
849 };
850 
// Compile-time constants for what is presumably one mfma_type specialization.
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
// The inline formulas hold for these values: 16 * 16 / 64 = 4 regs per block and
// 16 / 4 = 4 input blocks. The "???" markers are the original author's open questions.
851 template <>
853 {
854  // clang-format off
855  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
856  static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
857  static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
858  static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
859  static constexpr index_t wave_size = 64; // fixed
860  static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
861  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
862  static constexpr index_t m_per_blk = 16; // from the instruction
863  static constexpr index_t n_per_blk = 16; // from the instruction
864  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
865  static constexpr bool is_k_reduction = true; // ???
866  // clang-format on
867 
868  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
869  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
870  {
 // NOTE(review): the body statement is missing from this listing — confirm upstream.
872  }
873 };
874 
// Compile-time constants for what is presumably one mfma_type specialization whose run()
// additionally takes per-operand scale values (m_per_blk x n_per_blk = 32 x 32, k_per_blk = 32).
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
875 template <>
877 {
878  // clang-format off
879  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
880  static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
881  static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
882  static constexpr index_t num_threads_per_blk = 32; // n_per_blk
883  static constexpr index_t wave_size = 64; // fixed
884  static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
885  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
886  static constexpr index_t m_per_blk = 32; // from the instruction
887  static constexpr index_t n_per_blk = 32; // from the instruction
888  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
889  static constexpr bool is_k_reduction = true; // ???
890  // clang-format on
891 
892  template <index_t MPerXdlops,
893  index_t NPerXdlops,
894  index_t OpselA,
895  index_t OpselB,
896  class FloatA,
897  class ScaleA,
898  class FloatB,
899  class ScaleB,
900  class FloatC>
901  __device__ void run(const FloatA& a,
902  const ScaleA& scale_a,
903  const FloatB& b,
904  const ScaleB& scale_b,
905  FloatC& reg_c) const
906  {
 // NOTE(review): the line naming the callee is missing from this listing. The surviving
 // argument list passes a, b and reg_c through unchanged and hands both scales over as
 // uint32_t via bit_cast.
908  a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
909  }
910 };
911 
// Compile-time constants for what is presumably one mfma_type specialization whose run()
// additionally takes per-operand scale values (m_per_blk x n_per_blk = 16 x 16, k_per_blk = 32).
// NOTE(review): the line naming the specialized type is missing from this listing — confirm upstream.
912 template <>
914 {
915  // clang-format off
916  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
917  static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
918  static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
919  static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
920  static constexpr index_t wave_size = 64; // fixed
921  static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
922  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
923  static constexpr index_t m_per_blk = 16; // from the instruction
924  static constexpr index_t n_per_blk = 16; // from the instruction
925  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
926  static constexpr bool is_k_reduction = true; // ???
927  // clang-format on
928 
929  template <index_t MPerXdlops,
930  index_t NPerXdlops,
931  index_t OpselA,
932  index_t OpselB,
933  class FloatA,
934  class ScaleA,
935  class FloatB,
936  class ScaleB,
937  class FloatC>
938  __device__ void run(const FloatA& a,
939  const ScaleA& scale_a,
940  const FloatB& b,
941  const ScaleB& scale_b,
942  FloatC& reg_c) const
943  {
944 
 // NOTE(review): the line naming the callee is missing from this listing. The surviving
 // argument list passes a, b and reg_c through unchanged and hands both scales over as
 // uint32_t via bit_cast.
946  a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
947  }
948 };
949 
950 // gfx11
952 {
953  static constexpr index_t group_size = 8;
954  static constexpr index_t num_groups_per_blk = 1;
955  static constexpr index_t num_regs_per_blk = 8;
956  static constexpr index_t num_threads_per_blk = 16;
957  static constexpr index_t wave_size = 32;
958  static constexpr index_t num_input_blks = 1;
959  static constexpr index_t num_output_blks = 1;
960  static constexpr index_t m_per_blk = 16;
961  static constexpr index_t n_per_blk = 16;
962  static constexpr index_t k_per_blk = 16;
963  static constexpr bool is_k_reduction = true;
964 };
965 
966 template <>
968 {
969  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
970  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
971  {
973  }
974 };
975 
976 template <>
978 {
979  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
980  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
981  {
983  }
984 };
985 
986 template <>
988 {
989  template <index_t MPerWmma,
990  index_t NPerWmma,
991  class FloatA,
992  class FloatB,
993  class FloatC,
994  bool neg_a = true,
995  bool neg_b = true,
996  bool clamp = false>
997  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
998  {
1000  }
1001 };
1002 
1003 template <>
1005 {
1006  static constexpr index_t k_per_blk = 2;
// No-op `run`: accepts the A operand, B operand and accumulator but ignores all three
// (the parameters are unnamed and the body is empty). MPerWmma / NPerWmma are accepted
// as template arguments but are not used in the body.
1007  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1008  __device__ void run(const FloatA&, const FloatB&, FloatC&) const
1009  {
1010  // empty for all unsupported types.
1011  }
1012 };
1013 
1014 // gfx12
1016 {
1017  static constexpr index_t group_size = 8;
1018  static constexpr index_t num_groups_per_blk = 1;
1019  static constexpr index_t num_regs_per_blk = 8;
1020  static constexpr index_t num_threads_per_blk = 16;
1021  static constexpr index_t wave_size = 32;
1022  static constexpr index_t num_input_blks = 2;
1023  static constexpr index_t num_output_blks = 1;
1024  static constexpr index_t m_per_blk = 16;
1025  static constexpr index_t n_per_blk = 16;
1026  static constexpr index_t k_per_blk = 8;
1027  static constexpr bool is_k_reduction = true;
1028 };
1029 
1030 template <>
1032 {
1033  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1034  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1035  {
1037  }
1038 };
1039 
1040 template <>
1042 {
1043  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1044  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1045  {
1047  }
1048 };
1049 
1050 template <>
1052 {
1053  template <index_t MPerWmma,
1054  index_t NPerWmma,
1055  class FloatA,
1056  class FloatB,
1057  class FloatC,
1058  bool neg_a = true,
1059  bool neg_b = true,
1060  bool clamp = false>
1061  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1062  {
1064  a, b, reg_c);
1065  }
1066 };
1067 
1068 template <>
1070 {
1071  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1072  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1073  {
1075  }
1076 };
1077 
1078 template <>
1080 {
1081  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1082  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1083  {
1085  }
1086 };
1087 
1088 template <>
1090 {
1091  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1092  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1093  {
1095  }
1096 };
1097 
1098 template <>
1100 {
1101  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1102  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1103  {
1105  }
1106 };
1107 
1108 template <>
1110 {
1111  static constexpr index_t k_per_blk = 2;
// No-op `run`: accepts the A operand, B operand and accumulator but ignores all three
// (the parameters are unnamed and the body is empty). MPerWmma / NPerWmma are accepted
// as template arguments but are not used in the body.
1112  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1113  __device__ void run(const FloatA&, const FloatB&, FloatC&) const
1114  {
1115  // empty for all unsupported types.
1116  }
1117 };
1118 
1119 template <typename base_type,
1120  index_t MPerXdlops,
1121  index_t NPerXdlops,
1122  typename additional_type = base_type,
1123  bool is_single_rate_mfma = false,
1124  bool is_scale_mfma = false>
1126 {
// Primary template of the instruction lookup. It is only declared here; the supported
// combinations of (base type, MPerXdlops, NPerXdlops, additional type, single-rate flag,
// scale flag) are supplied by the explicit `template <>` specializations that follow.
// additional_type_ defaults to base_type_, and both boolean flags default to false.
1127  template <typename base_type_,
1128  index_t MPerXdlops_,
1129  index_t NPerXdlops_,
1130  typename additional_type_ = base_type_,
1131  bool is_single_rate_mfma_ = false,
1132  bool is_scale_mfma_ = false>
1133  static constexpr auto GetMfma();
1134 
1135  template <>
1136  constexpr auto GetMfma<double, 16, 16>()
1137  {
1138 #if defined(__gfx12__)
1140 #elif defined(__gfx11__)
1142 #else
1144 #endif
1145  }
1146 
1147  template <>
1148  constexpr auto GetMfma<float, 64, 64>()
1149  {
1151  }
1152 
1153  template <>
1154  constexpr auto GetMfma<float, 32, 64>()
1155  {
1157  }
1158 
1159  template <>
1160  constexpr auto GetMfma<float, 16, 64>()
1161  {
1163  }
1164 
1165  template <>
1166  constexpr auto GetMfma<float, 8, 64>()
1167  {
1169  }
1170 
1171  template <>
1172  constexpr auto GetMfma<float, 4, 64>()
1173  {
1175  }
1176 
1177  template <>
1178  constexpr auto GetMfma<float, 32, 32>()
1179  {
1181  }
1182 
1183  template <>
1184  constexpr auto GetMfma<float, 16, 16>()
1185  {
1186 #if defined(__gfx12__)
1188 #elif defined(__gfx11__)
1190 #else
1192 #endif
1193  }
1194 
1195  template <>
1196  constexpr auto GetMfma<half_t, 64, 64>()
1197  {
1199  }
1200 
1201  template <>
1202  constexpr auto GetMfma<half_t, 32, 64>()
1203  {
1205  }
1206 
1207  template <>
1208  constexpr auto GetMfma<half_t, 32, 32, half_t, false>()
1209  {
1210 #if defined(__gfx950__)
1212 #else
1214 #endif
1215  }
1216  template <>
1217  constexpr auto GetMfma<half_t, 32, 32, half_t, true>()
1218  {
1220  }
1221 
1222  template <>
1223  constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
1224  {
1225 #if defined(__gfx12__)
1227 #elif defined(__gfx11__)
1229 #elif defined(__gfx950__)
1231 #else
1233 #endif
1234  }
1235 
1236  template <>
1237  constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
1238  {
1239 #if defined(__gfx12__)
1241 #elif defined(__gfx11__)
1243 #else
1245 #endif
1246  }
1247 
1248  template <>
1249  constexpr auto GetMfma<half_t, 16, 64>()
1250  {
1252  }
1253 
1254  template <>
1255  constexpr auto GetMfma<half_t, 8, 64>()
1256  {
1258  }
1259 
1260  template <>
1261  constexpr auto GetMfma<half_t, 4, 64>()
1262  {
1264  }
1265 
1266  template <>
1267  constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
1268  {
1269 #if defined(__gfx950__)
1271 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1273 #else
1275 #endif
1276  }
1277 
1278  template <>
1279  constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
1280  {
1281 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1283 #else
1285 #endif
1286  }
1287 
1288  template <>
1289  constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
1290  {
1291 #if defined(__gfx12__)
1293 #elif defined(__gfx11__)
1295 #elif defined(__gfx950__)
1297 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1299 #else
1301 #endif
1302  }
1303 
1304  template <>
1305  constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
1306  {
1307 #if defined(__gfx12__)
1309 #elif defined(__gfx11__)
1311 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1313 #else
1315 #endif
1316  }
1317 
1318  template <>
1319  constexpr auto GetMfma<int8_t, 32, 32, int8_t, false>()
1320  {
1321 #if defined(__gfx950__)
1323 #elif defined(__gfx942__)
1325 #else
1327 #endif
1328  }
1329 
1330  template <>
1331  constexpr auto GetMfma<int8_t, 32, 32, int8_t, true>()
1332  {
1333 #if defined(__gfx942__) || defined(__gfx950__)
1335 #else
1337 #endif
1338  }
1339 
1340  template <>
1341  constexpr auto GetMfma<int8_t, 16, 16, int8_t, false>()
1342  {
1343 #if defined(__gfx12__)
1345 #elif defined(__gfx11__)
1347 #elif defined(__gfx950__)
1349 #elif defined(__gfx942__)
1351 #else
1353 #endif
1354  }
1355 
1356  template <>
1357  constexpr auto GetMfma<int8_t, 16, 16, int8_t, true>()
1358  {
1359 #if defined(__gfx12__)
1361 #elif defined(__gfx11__)
1363 #elif defined(__gfx942__) || defined(__gfx950__)
1365 #else
1367 #endif
1368  }
1369 
1370  template <>
1371  constexpr auto GetMfma<f8_t, 32, 32, f8_t, true, false>()
1372  {
1374  }
1375 
1376  template <>
1377  constexpr auto GetMfma<f8_t, 32, 32, f8_t, false, false>()
1378  {
1379 #if defined(__gfx950__)
1381 #else
1383 #endif
1384  }
1385 
1386  template <>
1387  constexpr auto GetMfma<f8_t, 32, 32, f8_t, is_single_rate_mfma, true>()
1388  {
1390  }
1391 
1392  template <>
1393  constexpr auto GetMfma<bf8_t, 32, 32, f8_t, is_single_rate_mfma, true>()
1394  {
1396  }
1397  template <>
1398  constexpr auto GetMfma<f4_t, 32, 32, f4_t, is_single_rate_mfma, true>()
1399  {
1401  }
1402  template <>
1403  constexpr auto GetMfma<f4_t, 16, 16, f4_t, is_single_rate_mfma, true>()
1404  {
1405 #if defined(__gfx12__)
1407 #elif defined(__gfx11__)
1409 #else
1411 #endif
1412  }
1413 
1414  template <>
1415  constexpr auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
1416  {
1417 #if defined(__gfx12__)
1419 #elif defined(__gfx11__)
1421 #else
1423 #endif
1424  }
1425 
1426  template <>
1427  constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, false>()
1428  {
1429 #if defined(__gfx12__)
1431 #elif defined(__gfx11__)
1433 #elif defined(__gfx950__)
1435 #else
1437 #endif
1438  }
1439 
1440  template <>
1441  constexpr auto GetMfma<f8_t, 16, 16, f8_t, is_single_rate_mfma, true>()
1442  {
1443 #if defined(__gfx12__)
1445 #elif defined(__gfx11__)
1447 #else
1449 #endif
1450  }
1451 
1452  template <>
1453  constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, is_single_rate_mfma, true>()
1454  {
1455 #if defined(__gfx12__)
1457 #elif defined(__gfx11__)
1459 #else
1461 #endif
1462  }
1463 
1464  template <>
1465  constexpr auto GetMfma<f8_t, 16, 16, bf8_t, is_single_rate_mfma, true>()
1466  {
1467 #if defined(__gfx12__)
1469 #elif defined(__gfx11__)
1471 #else
1473 #endif
1474  }
1475 
1476  template <>
1477  constexpr auto GetMfma<bf8_t, 16, 16, f8_t, is_single_rate_mfma, true>()
1478  {
1479 #if defined(__gfx12__)
1481 #elif defined(__gfx11__)
1483 #else
1485 #endif
1486  }
1487 
1488  template <>
1489  constexpr auto GetMfma<f6_t, 32, 32, f6_t, is_single_rate_mfma, true>()
1490  {
1492  }
1493  template <>
1494  constexpr auto GetMfma<f6_t, 16, 16, f6_t, is_single_rate_mfma, true>()
1495  {
1496 #if defined(__gfx12__)
1498 #elif defined(__gfx11__)
1500 #else
1502 #endif
1503  }
1504  template <>
1505  constexpr auto GetMfma<bf6_t, 32, 32, bf6_t, is_single_rate_mfma, true>()
1506  {
1508  }
1509  template <>
1510  constexpr auto GetMfma<bf6_t, 16, 16, bf6_t, is_single_rate_mfma, true>()
1511  {
1512 #if defined(__gfx12__)
1514 #elif defined(__gfx11__)
1516 #else
1518 #endif
1519  }
1520 
1521  template <>
1522  constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, true, false>()
1523  {
1525  }
1526 
1527  template <>
1528  constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, false, false>()
1529  {
1530 #if defined(__gfx950__)
1532 #else
1534 #endif
1535  }
1536 
1537  template <>
1538  constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, true, false>()
1539  {
1540 #if defined(__gfx12__)
1542 #elif defined(__gfx11__)
1544 #else
1546 #endif
1547  }
1548 
1549  template <>
1550  constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, false, false>()
1551  {
1552 #if defined(__gfx12__)
1554 #elif defined(__gfx11__)
1556 #elif defined(__gfx950__)
1558 #else
1560 #endif
1561  }
1562 
1563  template <>
1564  constexpr auto GetMfma<f8_t, 32, 32, bf8_t, true, false>()
1565  {
1567  }
1568 
1569  template <>
1570  constexpr auto GetMfma<f8_t, 32, 32, bf8_t, false, false>()
1571  {
1572 #if defined(__gfx950__)
1574 #else
1576 #endif
1577  }
1578 
1579  template <>
1580  constexpr auto GetMfma<f8_t, 16, 16, bf8_t, true, false>()
1581  {
1582 #if defined(__gfx12__)
1584 #elif defined(__gfx11__)
1586 #else
1588 #endif
1589  }
1590 
1591  template <>
1592  constexpr auto GetMfma<f8_t, 16, 16, bf8_t, false, false>()
1593  {
1594 #if defined(__gfx12__)
1596 #elif defined(__gfx11__)
1598 #elif defined(__gfx950__)
1600 #else
1602 #endif
1603  }
1604 
1605  template <>
1606  constexpr auto GetMfma<bf8_t, 32, 32, f8_t, true, false>()
1607  {
1609  }
1610 
1611  template <>
1612  constexpr auto GetMfma<bf8_t, 32, 32, f8_t, false, false>()
1613  {
1614 #if defined(__gfx950__)
1616 #else
1618 #endif
1619  }
1620 
1621  template <>
1622  constexpr auto GetMfma<bf8_t, 16, 16, f8_t, true, false>()
1623  {
1624 #if defined(__gfx12__)
1626 #elif defined(__gfx11__)
1628 #else
1630 #endif
1631  }
1632 
1633  template <>
1634  constexpr auto GetMfma<bf8_t, 16, 16, f8_t, false, false>()
1635  {
1636 #if defined(__gfx12__)
1638 #elif defined(__gfx11__)
1640 #elif defined(__gfx950__)
1642 #else
1644 #endif
1645  }
1646 
1648  MPerXdlops,
1649  NPerXdlops,
1651  is_single_rate_mfma,
1652  is_scale_mfma>()>{};
1653 
// Constructor: performs no runtime work. Its body consists solely of static_asserts that
// cross-check the layout constants of `selected_mfma` against each other.
1654  __host__ __device__ constexpr MfmaSelector()
1655  {
// The registers of one block must be exactly group_size * num_groups_per_blk.
1656  static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk ==
1657  selected_mfma.num_regs_per_blk,
1658  "wrong! num_regs_per_blk");
1659 
// num_threads_per_blk must equal n_per_blk.
1660  static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk,
1661  "n_per_blk != num_threads_per_blk");
// m_per_blk consistency check. When __gfx11__ is defined the check is applied only to the
// 16x16 tile and carries an extra factor of 2 on the left-hand side (the assert message does
// not mention that factor); other tile sizes are not checked on that target.
1662 #if defined(__gfx11__)
1663  if constexpr(MPerXdlops == 16 && NPerXdlops == 16)
1664  {
1665  static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks * 2 ==
1666  selected_mfma.m_per_blk,
1667  "m_per_blk != num_input_blks * num_regs_per_blk");
1668  }
1669 #else
1670  static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks ==
1671  selected_mfma.m_per_blk,
1672  "m_per_blk != num_input_blks * num_regs_per_blk");
1673 #endif
1674 
// num_output_blks must be either num_input_blks or 1.
1675  static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks ||
1676  selected_mfma.num_output_blks == 1,
1677  "incorrect num_output_blks");
1678 
// num_regs_per_blk * wave_size must equal the m_per_blk x n_per_blk element count of a block.
1679  static_assert(selected_mfma.num_regs_per_blk * selected_mfma.wave_size ==
1680  selected_mfma.m_per_blk * selected_mfma.n_per_blk,
1681  "num_regs_per_blk incorrect");
1682 
// Without k-reduction, the input and output block counts must match.
1683  static_assert(selected_mfma.is_k_reduction ||
1684  (selected_mfma.num_input_blks == selected_mfma.num_output_blks),
1685  "is_k_reduction wrong!");
1686  }
1687 
// Always returns true. Compilation fails (static_assert) unless NPerXdlops >= MPerXdlops,
// which per the assert message is the only configuration supported ("ABroadcast").
1688  static constexpr bool IsABroadcast()
1689  {
1690  static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
1691  return true;
1692  }
1693 
// Returns selected_mfma.k_per_blk, multiplied by selected_mfma.num_input_blks when
// selected_mfma.is_k_reduction is true (multiplied by 1 otherwise).
1694  static constexpr index_t GetKPerXdlops()
1695  {
1696  return (selected_mfma.is_k_reduction ? selected_mfma.num_input_blks : 1) *
1697  selected_mfma.k_per_blk;
1698  }
1699 
// Returns selected_mfma.k_per_blk unchanged.
1700  static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; }
1701 };
1702 
1703 template <typename base_type,
1704  index_t MPerXdlops,
1705  index_t NPerXdlops,
1706  index_t KPack,
1707  typename additional_type = base_type,
1708  bool TransposeC = false,
1709  bool is_scale_mfma = false>
1711 {
// Compile-time index constants Number<0> .. Number<5>. Within this struct they are passed
// to GetLength() on descriptors and used to subscript index tuples.
1712  static constexpr auto I0 = Number<0>{};
1713  static constexpr auto I1 = Number<1>{};
1714  static constexpr auto I2 = Number<2>{};
1715  static constexpr auto I3 = Number<3>{};
1716  static constexpr auto I4 = Number<4>{};
1717  static constexpr auto I5 = Number<5>{};
1718 
1721 
// Returns mfma_instr.num_output_blks.
1722  __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; }
1723 
// Returns (MPerXdlops * NPerXdlops) divided by
// (m_per_blk * n_per_blk * num_output_blks) of the selected instruction.
// NOTE(review): reads as the number of instruction issues needed to cover the
// MPerXdlops x NPerXdlops tile - confirm against callers.
1724  __device__ static constexpr index_t GetNumXdlops()
1725  {
1726  return MPerXdlops * NPerXdlops /
1727  (mfma_instr.m_per_blk * mfma_instr.n_per_blk * mfma_instr.num_output_blks);
1728  }
1729 
// Constructor: performs no runtime work; it only validates the template arguments.
//  - NPerXdlops and MPerXdlops must each be one of 4, 8, 16, 32 or 64.
//  - KPack must be a multiple of mfma_instr.k_per_blk; this check is compiled only when
//    __HIP_DEVICE_COMPILE__ is defined.
//    NOTE(review): presumably because the selected instruction depends on target-architecture
//    macros that are not defined in the host pass - confirm.
1730  __host__ __device__ constexpr XdlopsGemm()
1731  {
1732  static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
1733  NPerXdlops == 64,
1734  "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1735 
1736  static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
1737  MPerXdlops == 64,
1738  "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1739 #if defined(__HIP_DEVICE_COMPILE__)
1740  static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
1741 #endif
1742  }
1743 
1744  // XDL output supporting C = A * B
1745  // M2_N2 -> M2_M3_M4_N2
1746  template <typename CDesc_M0_N0_M1_N1_M2_N2>
1747  __host__ __device__ static constexpr auto
1748  MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1749  {
1750  const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1751  const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1752  const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1753  const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1754  constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1755 
1757  c_desc_m0_n0_m1_n1_m2_n2,
1763  Number<num_blks>{},
1764  Number<mfma_instr.group_size>{})),
1767  Sequence<1>{},
1768  Sequence<2>{},
1769  Sequence<3>{},
1770  Sequence<4>{},
1771  Sequence<5>{}),
1773  Sequence<1>{},
1774  Sequence<2>{},
1775  Sequence<3>{},
1777  Sequence<7>{}));
1778  }
1779 
1780  // XDL output supporting C = A * B
1781  // M3_N3 -> M3_M4_M5_N3
1782  template <typename CDesc_M0_N0_M1_N1_M2_N2>
1783  __host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
1784  const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1785  {
1786  const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1787  const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1788  const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1789  const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1790  const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
1791  const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
1792  constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1793 
1795  c_desc_m0_n0_m1_n1_m2_n2,
1803  Number<num_blks>{},
1804  Number<mfma_instr.group_size>{})),
1807  Sequence<1>{},
1808  Sequence<2>{},
1809  Sequence<3>{},
1810  Sequence<4>{},
1811  Sequence<5>{},
1812  Sequence<6>{},
1813  Sequence<7>{}),
1815  Sequence<1>{},
1816  Sequence<2>{},
1817  Sequence<3>{},
1818  Sequence<4>{},
1819  Sequence<5>{},
1821  Sequence<9>{}));
1822  }
1823 
1824  // transposed XDL output supporting C' = B' * A'
1825  // M2_N2 -> M2_N2_N3_N4
1826  template <typename CDesc_M0_N0_M1_N1_M2_N2>
1827  __host__ __device__ static constexpr auto
1828  MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1829  {
1830  const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1831  const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1832  const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1833  const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1834  constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1835 
1837  c_desc_m0_n0_m1_n1_m2_n2,
1844  Number<num_blks>{},
1845  Number<mfma_instr.group_size>{}))),
1847  Sequence<1>{},
1848  Sequence<2>{},
1849  Sequence<3>{},
1850  Sequence<4>{},
1851  Sequence<5>{}),
1853  Sequence<1>{},
1854  Sequence<2>{},
1855  Sequence<3>{},
1856  Sequence<4>{},
1857  Sequence<5, 6, 7>{}));
1858  }
1859 
1860  template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
1861  __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
1862  const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
1863  {
1864  const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
1865  const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
1866  const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
1867  const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
1868  const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
1869  constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1870 
1872  c_desc_g_m0_n0_m1_n1_m2_n2,
1879  mfma_instr.num_groups_per_blk, num_blks, mfma_instr.group_size)),
1880  make_pass_through_transform(mfma_instr.num_threads_per_blk)),
1882  Sequence<1>{},
1883  Sequence<2>{},
1884  Sequence<3>{},
1885  Sequence<4>{},
1886  Sequence<5>{},
1887  Sequence<6>{}),
1889  Sequence<1>{},
1890  Sequence<2>{},
1891  Sequence<3>{},
1892  Sequence<4>{},
1894  Sequence<8>{}));
1895  }
1896 
// Returns MPerXdlops * NPerXdlops divided by mfma_instr.wave_size.
// NOTE(review): reads as the number of C elements each lane holds for one
// MPerXdlops x NPerXdlops tile - confirm against callers.
1897  __device__ __host__ static constexpr index_t GetRegSizePerXdlops()
1898  {
1899  return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
1900  }
1901 
// Returns mfma_instr.wave_size.
1902  __device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; }
1903 
1904  template <class FloatA, class FloatB, class FloatC>
1905  __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
1906  {
1907  static_assert(
1914  "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
1915 
1916  static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
1917  if constexpr(!TransposeC)
1918  {
1919  mfma_instr.template run<MPerXdlops, NPerXdlops>(
1920  p_a_wave[k], p_b_wave[k], p_c_thread);
1921  }
1922  else
1923  {
1924  mfma_instr.template run<MPerXdlops, NPerXdlops>(
1925  p_b_wave[k], p_a_wave[k], p_c_thread);
1926  }
1927  });
1928  }
1929 
// Scaled overload of Run: forwards data operands together with their scale operands to the
// selected instruction's `run`.
//
// Template parameters:
//   OpselA / OpselB - forwarded to mfma_instr.run as its 3rd / 4th template arguments.
// Parameters:
//   p_a_wave, p_b_wave             - A / B operands, subscripted with the step index k.
//   a_scale_thread, b_scale_thread - scale operands, subscripted with the same k.
//   p_c_thread                     - passed by non-const reference to every run() call.
//
// The static_for is instantiated with bounds <0, KPack / mfma_instr.k_per_blk, 1>, so one
// run() call is issued per k_per_blk-sized slice of KPack (assumes static_for visits
// k = 0 .. N-1 in steps of 1 - helper defined outside this view).
1930  template <index_t OpselA,
1931  index_t OpselB,
1932  class FloatA,
1933  class ScaleA,
1934  class FloatB,
1935  class ScaleB,
1936  class FloatC>
1937  __device__ void Run(const FloatA& p_a_wave,
1938  const ScaleA& a_scale_thread,
1939  const FloatB& p_b_wave,
1940  const ScaleB& b_scale_thread,
1941  FloatC& p_c_thread) const
1942  {
1943  static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
1944  if constexpr(!TransposeC)
1945  {
// Normal order: (A, scale A, B, scale B) with <OpselA, OpselB>.
1946  mfma_instr.template run<MPerXdlops, NPerXdlops, OpselA, OpselB>(
1947  p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
1948  }
1949  else
1950  {
// TransposeC: the A and B roles are swapped - data, scales and Opsel arguments alike.
1951  mfma_instr.template run<MPerXdlops, NPerXdlops, OpselB, OpselA>(
1952  p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
1953  }
1954  });
1955  }
1956 
// Returns get_thread_local_1d_id() modulo mfma_instr.wave_size.
1957  __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; }
1958 
1959  __device__ static auto GetBlkIdx()
1960  {
1961  const auto laneId = GetLaneId();
1962  constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1963 
1964  constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
1965  make_tuple(
1966  make_merge_transform(make_tuple(1, num_blks, mfma_instr.num_threads_per_blk))),
1968  make_tuple(Sequence<0>{}));
1969 
1970  const auto blk_idx =
1971  threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
1972 
1973  const auto blk_id = blk_idx[I1];
1974  const auto blk_td = blk_idx[I2];
1975 
1976  return make_tuple(blk_id, blk_td);
1977  }
1978 
1979  template <bool SwizzleA>
1980  __device__ static auto GetGfx11InputBlkIdx()
1981  {
1982  auto laneId = GetLaneId() % mfma_instr.num_threads_per_blk;
1983  if constexpr(SwizzleA)
1984  {
1985  laneId = ((laneId & 1) << 3) | (laneId >> 1);
1986  }
1987  constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
1989  make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))),
1991  make_tuple(Sequence<0>{}));
1992 
1993  const auto blk_idx =
1994  threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
1995 
1996  const auto blk_id = blk_idx[I1];
1997  const auto blk_td = blk_idx[I2];
1998 
1999  return make_tuple(blk_id, blk_td);
2000  }
2001 
// Returns a 2-tuple giving this lane's origin index for the A operand:
//   - (blk_id, blk_td) when mfma_instr.is_k_reduction is true,
//   - (0, laneId) otherwise.
// The block index comes from GetGfx11InputBlkIdx<!TransposeC>() when __gfx11__ is defined
// and from GetBlkIdx() on every other target.
// NOTE(review): declared __host__ __device__, yet it calls GetLaneId(), which is declared
// __device__ only - confirm host-side use is never instantiated.
2002  __host__ __device__ static auto CalculateAThreadOriginDataIndex()
2003  {
2004  const auto laneId = GetLaneId();
2005 #if defined(__gfx11__)
2006  const auto blk_idx = GetGfx11InputBlkIdx<!TransposeC>();
2007 #else
2008  const auto blk_idx = GetBlkIdx();
2009 #endif
2010 
2011  const auto blk_id = blk_idx[I0];
2012  const auto blk_td = blk_idx[I1];
2013 
2014  if constexpr(mfma_instr.is_k_reduction)
2015  {
2016  return make_tuple(blk_id, blk_td);
2017  }
2018  else
2019  {
2020  return make_tuple(0, laneId);
2021  }
2022  }
2023 
// Returns a 2-tuple giving this lane's origin index for the B operand:
//   - (blk_id, blk_td) when mfma_instr.is_k_reduction is true,
//   - (0, laneId) otherwise.
// The block index comes from GetGfx11InputBlkIdx<TransposeC>() when __gfx11__ is defined
// (note: the template argument is TransposeC, not its negation) and from GetBlkIdx() on
// every other target.
// NOTE(review): declared __host__ __device__, yet it calls GetLaneId(), which is declared
// __device__ only - confirm host-side use is never instantiated.
2024  __host__ __device__ static auto CalculateBThreadOriginDataIndex()
2025  {
2026  const auto laneId = GetLaneId();
2027 #if defined(__gfx11__)
2028  const auto blk_idx = GetGfx11InputBlkIdx<TransposeC>();
2029 #else
2030  const auto blk_idx = GetBlkIdx();
2031 #endif
2032 
2033  const auto blk_id = blk_idx[I0];
2034  const auto blk_td = blk_idx[I1];
2035 
2036  if constexpr(mfma_instr.is_k_reduction)
2037  {
2038  return make_tuple(blk_id, blk_td);
2039  }
2040  else
2041  {
2042  return make_tuple(0, laneId);
2043  }
2044  }
2045 
// Computes the 2D start offset of this lane's data for a given (xdlops_i, blk_i) pair.
//   n_offset = blk_i    * n_per_blk + blk_td
//   m_offset = xdlops_i * m_per_blk + blk_id * group_size
// where (blk_id, blk_td) is the pair returned by GetBlkIdx().
// Returns {m_offset, n_offset}, or {n_offset, m_offset} when TransposeC is true.
// (CIndex is declared outside the visible lines - presumably a 2-element index type.)
2046  __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
2047  {
2048  const auto blk_idx = GetBlkIdx();
2049 
2050  const auto blk_id = blk_idx[I0];
2051  const auto blk_td = blk_idx[I1];
2052 
2053  index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td;
2054  index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size;
2055 
2056  return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
2057  }
2058 
// 4D variant of the start index. Both parameters are unnamed and ignored; the result
// depends only on the (blk_id, blk_td) pair returned by GetBlkIdx().
// Returns {I0, blk_id, I0, blk_td}, or {blk_td, I0, blk_id, I0} when TransposeC is true.
// (CIndex4D is declared outside the visible lines - presumably a 4-element index type.)
2059  __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
2060  {
2061  const auto blk_idx = GetBlkIdx();
2062 
2063  const auto blk_id = blk_idx[I0];
2064  const auto blk_td = blk_idx[I1];
2065 
2066  return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
2067  }
2068 
2069  // Falls back to the single-rate instruction on gfx950 if KPack is single rate; no change on gfx942 and older.
2070  // When base_type is either f8_t or bf8_t, additional_type will always be either f8_t or bf8_t,
2071  // except for one special case, A (f8_t) * B (pk_i4_t), which uses the single-rate mfma instruction.
2072  static constexpr bool is_single_rate_mfma =
2074  KPack <= 4) ||
2075  (is_same<base_type, int8_t>::value && KPack <= 8) ||
2078  ? true
2079  : false;
2080  static constexpr auto mfma = MfmaSelector<base_type,
2081  MPerXdlops,
2082  NPerXdlops,
2083  additional_type,
2085  is_scale_mfma>{};
2086 
// The instruction descriptor chosen by the selector object `mfma`.
2087  static constexpr auto mfma_instr = mfma.selected_mfma;
2088 
// K extents of the selected instruction:
//   KPerXdlops  = mfma.GetKPerXdlops()
//   K1PerXdlops = mfma.GetK1PerXdlops()
//   K0PerXdlops = KPerXdlops / K1PerXdlops
2089  static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
2090  static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
2091  static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
2092 
2093  __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
2094  {
2095  return make_tuple(
2097  }
2098 };
2099 
2100 } // namespace ck
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
MfmaInstr
Definition: xdlops_gemm.hpp:43
@ wmma_f32_16x16x16_bf16_gfx12
@ wmma_unsupport_16x16_gfx11
@ wmma_i32_16x16x16_iu8_gfx12
@ mfma_scale_f32_32x32x64f8f6f4
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_unsupport_16x16_gfx12
@ mfma_f32_16x16x16bf16_1k
@ wmma_f32_16x16x16_f8f8_gfx12
@ mfma_scale_f32_16x16x128f8f6f4
@ mfma_f32_16x16x128f8f6f4
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
typename packed_type_info< T >::element_type element_type_t
Definition: data_type.hpp:405
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:52
@ wmma_f32_16x16x16_bf16_gfx12
@ wmma_i32_16x16x16_iu8_gfx12
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_f32_16x16x16_f8f8_gfx12
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
Definition: array.hpp:14
Definition: xdlops_gemm.hpp:1126
__host__ constexpr __device__ MfmaSelector()
Definition: xdlops_gemm.hpp:1654
static constexpr bool IsABroadcast()
Definition: xdlops_gemm.hpp:1688
static constexpr index_t GetK1PerXdlops()
Definition: xdlops_gemm.hpp:1700
static constexpr auto GetMfma()
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1647
static constexpr index_t GetKPerXdlops()
Definition: xdlops_gemm.hpp:1694
Definition: sequence.hpp:43
Definition: xdlops_gemm.hpp:1711
static constexpr auto mfma_instr
Definition: xdlops_gemm.hpp:2087
__host__ constexpr __device__ XdlopsGemm()
Definition: xdlops_gemm.hpp:1730
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:2024
static __device__ auto GetBlkIdx()
Definition: xdlops_gemm.hpp:1959
__device__ static constexpr __host__ index_t GetRegSizePerXdlops()
Definition: xdlops_gemm.hpp:1897
static constexpr auto I2
Definition: xdlops_gemm.hpp:1714
static constexpr __device__ index_t GetNumBlks()
Definition: xdlops_gemm.hpp:1722
static __device__ auto GetLaneId()
Definition: xdlops_gemm.hpp:1957
static constexpr auto K0PerXdlops
Definition: xdlops_gemm.hpp:2091
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1783
static constexpr __device__ index_t GetNumXdlops()
Definition: xdlops_gemm.hpp:1724
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:2002
static constexpr bool is_single_rate_mfma
Definition: xdlops_gemm.hpp:2072
static __device__ CIndex4D GetBeginOfThreadBlk4D(index_t, index_t)
Definition: xdlops_gemm.hpp:2059
static constexpr __device__ index_t GetWaveSize()
Definition: xdlops_gemm.hpp:1902
static __device__ auto GetGfx11InputBlkIdx()
Definition: xdlops_gemm.hpp:1980
static constexpr auto I5
Definition: xdlops_gemm.hpp:1717
static constexpr auto I3
Definition: xdlops_gemm.hpp:1715
static constexpr auto I0
Definition: xdlops_gemm.hpp:1712
__device__ void Run(const FloatA &p_a_wave, const ScaleA &a_scale_thread, const FloatB &p_b_wave, const ScaleB &b_scale_thread, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:1937
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1748
__host__ static constexpr __device__ auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1861
static constexpr auto I1
Definition: xdlops_gemm.hpp:1713
static constexpr auto K1PerXdlops
Definition: xdlops_gemm.hpp:2090
static constexpr auto KPerXdlops
Definition: xdlops_gemm.hpp:2089
static constexpr auto I4
Definition: xdlops_gemm.hpp:1716
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:1905
static constexpr auto mfma
Definition: xdlops_gemm.hpp:2080
static __device__ CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
Definition: xdlops_gemm.hpp:2046
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1828
__host__ static constexpr __device__ auto GetCM0M1M2NThreadBlkLengths()
Definition: xdlops_gemm.hpp:2093
Definition: integral_constant.hpp:20
Definition: amd_xdlops.hpp:1202
Definition: amd_xdlops.hpp:303
Definition: amd_xdlops.hpp:193
Definition: amd_xdlops.hpp:70
Definition: amd_xdlops.hpp:269
Definition: amd_xdlops.hpp:1483
Definition: amd_xdlops.hpp:1609
Definition: amd_xdlops.hpp:159
Definition: amd_xdlops.hpp:1546
Definition: amd_xdlops.hpp:1420
Definition: amd_xdlops.hpp:207
Definition: amd_xdlops.hpp:56
Definition: amd_xdlops.hpp:331
Definition: amd_xdlops.hpp:249
Definition: amd_xdlops.hpp:1451
Definition: amd_xdlops.hpp:1577
Definition: amd_xdlops.hpp:139
Definition: amd_xdlops.hpp:1514
Definition: amd_xdlops.hpp:1388
Definition: amd_xdlops.hpp:15
Definition: amd_xdlops.hpp:42
Definition: amd_xdlops.hpp:317
Definition: amd_xdlops.hpp:112
Definition: amd_xdlops.hpp:481
Definition: amd_xdlops.hpp:289
Definition: amd_xdlops.hpp:179
Definition: amd_xdlops.hpp:84
Definition: amd_xdlops.hpp:221
Definition: amd_xdlops.hpp:461
Definition: amd_xdlops.hpp:364
Definition: amd_xdlops.hpp:442
Definition: amd_xdlops.hpp:403
Definition: amd_xdlops.hpp:423
Definition: amd_xdlops.hpp:383
Definition: amd_xdlops.hpp:345
Definition: amd_xdlops.hpp:886
Definition: amd_xdlops.hpp:666
Definition: amd_wmma.hpp:50
Definition: amd_wmma.hpp:271
Definition: amd_wmma.hpp:25
Definition: amd_wmma.hpp:319
Definition: amd_wmma.hpp:121
Definition: type.hpp:177
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:869
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:447
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:315
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:182
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:425
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:733
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:821
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:293
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:777
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:689
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:337
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:160
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:491
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:381
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:711
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:799
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:271
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:755
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:667
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:116
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:138
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:469
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:227
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:845
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:403
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:249
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:205
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:359
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:645
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:535
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:579
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:623
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:557
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:601
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:513
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:938
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:901
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:980
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1044
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1102
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1092
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:970
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1034
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1082
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1072
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:997
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1061
__device__ void run(const FloatA &, const FloatB &, FloatC &) const
Definition: xdlops_gemm.hpp:1008
__device__ void run(const FloatA &, const FloatB &, FloatC &) const
Definition: xdlops_gemm.hpp:1113
Definition: xdlops_gemm.hpp:952
static constexpr index_t n_per_blk
Definition: xdlops_gemm.hpp:961
static constexpr index_t group_size
Definition: xdlops_gemm.hpp:953
static constexpr index_t m_per_blk
Definition: xdlops_gemm.hpp:960
static constexpr bool is_k_reduction
Definition: xdlops_gemm.hpp:963
static constexpr index_t num_threads_per_blk
Definition: xdlops_gemm.hpp:956
static constexpr index_t num_output_blks
Definition: xdlops_gemm.hpp:959
static constexpr index_t wave_size
Definition: xdlops_gemm.hpp:957
static constexpr index_t num_input_blks
Definition: xdlops_gemm.hpp:958
static constexpr index_t num_groups_per_blk
Definition: xdlops_gemm.hpp:954
static constexpr index_t num_regs_per_blk
Definition: xdlops_gemm.hpp:955
static constexpr index_t k_per_blk
Definition: xdlops_gemm.hpp:962
Definition: xdlops_gemm.hpp:1016
static constexpr index_t n_per_blk
Definition: xdlops_gemm.hpp:1025
static constexpr index_t group_size
Definition: xdlops_gemm.hpp:1017
static constexpr index_t num_output_blks
Definition: xdlops_gemm.hpp:1023
static constexpr index_t m_per_blk
Definition: xdlops_gemm.hpp:1024
static constexpr index_t num_threads_per_blk
Definition: xdlops_gemm.hpp:1020
static constexpr bool is_k_reduction
Definition: xdlops_gemm.hpp:1027
static constexpr index_t num_regs_per_blk
Definition: xdlops_gemm.hpp:1019
static constexpr index_t num_groups_per_blk
Definition: xdlops_gemm.hpp:1018
static constexpr index_t num_input_blks
Definition: xdlops_gemm.hpp:1022
static constexpr index_t wave_size
Definition: xdlops_gemm.hpp:1021
static constexpr index_t k_per_blk
Definition: xdlops_gemm.hpp:1026
Definition: xdlops_gemm.hpp:98
Definition: functional2.hpp:33