/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Source File
xdlops_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/math.hpp"
10 
11 namespace ck {
// Compile-time predicate: reports whether T is one of the data types accepted
// by the "scale" MFMA path in this file.
//
// Template parameter:
//   T - type to classify; it is first mapped through element_type_t<T>.
//       (presumably this yields the scalar element type when T is a vector
//       type - TODO confirm against the element_type_t definition)
//
// Returns true when the mapped type is exactly f8_ocp_t, bf8_ocp_t, f6_t,
// bf6_t or f4_t; false for every other type.
15 template <typename T>
16 static constexpr bool is_scale_mfma_data_type()
17 {
  // Classify on the mapped element type rather than on T itself.
18  using U = element_type_t<T>;
  // Plain whitelist of the five accepted element types.
19  return is_same_v<U, f8_ocp_t> || is_same_v<U, bf8_ocp_t> || is_same_v<U, f6_t> ||
20  is_same_v<U, bf6_t> || is_same_v<U, f4_t>;
21 }
22 
23 #ifndef CK_CODE_GEN_RTC
// Compile-time predicate: reports whether T is an accepted scale type for the
// "scale" MFMA path in this file.
//
// Template parameter:
//   T - candidate scale type; compared directly, with no element-type mapping.
//
// Returns true only when T is exactly e8m0_bexp_t.
//
// This function sits inside an `#ifndef CK_CODE_GEN_RTC` region, so it is not
// declared when CK_CODE_GEN_RTC is defined.
27 template <typename T>
28 static constexpr bool is_scale_mfma_scale_type()
29 {
30  return is_same_v<T, e8m0_bexp_t>;
31 }
32 #endif
33 
// Compile-time predicate: reports whether a combination of A/B data types and
// A/B scale types passes both per-type checks defined above in this file.
//
// Template parameters:
//   ADataType, BDataType           - operand types, each checked with
//                                    is_scale_mfma_data_type.
//   AScaleDataType, BScaleDataType - scale types, each checked with
//                                    is_scale_mfma_scale_type.
//
// Returns true only when all four checks are true.
//
// NOTE(review): is_scale_mfma_scale_type is declared only under
// `#ifndef CK_CODE_GEN_RTC` in this listing, while this function uses it
// unconditionally - confirm this function is never needed (or is otherwise
// guarded) in CK_CODE_GEN_RTC builds.
37 template <typename ADataType, typename BDataType, typename AScaleDataType, typename BScaleDataType>
38 static constexpr bool scale_mfma_hw_support()
39 {
40  return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
41  is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
42 }
43 
44 enum struct MfmaInstr
45 {
83  mfma_f32_16x16x8xf32, // tf32
85  // gfx11
90  // gfx12
99 };
100 
101 template <MfmaInstr instr>
102 struct mfma_type;
103 
104 template <>
106 {
107  static constexpr index_t group_size = 4;
108  static constexpr index_t num_groups_per_blk = 4;
109  static constexpr index_t num_regs_per_blk = 16;
110  static constexpr index_t num_threads_per_blk = 32;
111  static constexpr index_t wave_size = 64;
112  static constexpr index_t num_input_blks = 2;
113  static constexpr index_t num_output_blks = 2;
114  static constexpr index_t m_per_blk = 32;
115  static constexpr index_t n_per_blk = 32;
116  static constexpr index_t k_per_blk = 1;
117  static constexpr bool is_k_reduction = false;
118 
119  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
120  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
121  {
123  }
124 };
125 
126 template <>
128 {
129  static constexpr index_t group_size = 4;
130  static constexpr index_t num_groups_per_blk = 4;
131  static constexpr index_t num_regs_per_blk = 16;
132  static constexpr index_t num_threads_per_blk = 32;
133  static constexpr index_t wave_size = 64;
134  static constexpr index_t num_input_blks = 2;
135  static constexpr index_t num_output_blks = 1;
136  static constexpr index_t m_per_blk = 32;
137  static constexpr index_t n_per_blk = 32;
138  static constexpr index_t k_per_blk = 1;
139  static constexpr bool is_k_reduction = true;
140 
141  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
142  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
143  {
145  }
146 };
147 
148 template <>
150 {
151  static constexpr index_t group_size = 4;
152  static constexpr index_t num_groups_per_blk = 1;
153  static constexpr index_t num_regs_per_blk = 4;
154  static constexpr index_t num_threads_per_blk = 16;
155  static constexpr index_t wave_size = 64;
156  static constexpr index_t num_input_blks = 4;
157  static constexpr index_t num_output_blks = 1;
158  static constexpr index_t m_per_blk = 16;
159  static constexpr index_t n_per_blk = 16;
160  static constexpr index_t k_per_blk = 1;
161  static constexpr bool is_k_reduction = true;
162 
163  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
164  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
165  {
167  }
168 };
169 
170 template <>
172 {
173  static constexpr index_t group_size = 4;
174  static constexpr index_t num_groups_per_blk = 1;
175  static constexpr index_t num_regs_per_blk = 4;
176  static constexpr index_t num_threads_per_blk = 16;
177  static constexpr index_t wave_size = 64;
178  static constexpr index_t num_input_blks = 4;
179  static constexpr index_t num_output_blks = 4;
180  static constexpr index_t m_per_blk = 16;
181  static constexpr index_t n_per_blk = 16;
182  static constexpr index_t k_per_blk = 1;
183  static constexpr bool is_k_reduction = false;
184 
185  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
186  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
187  {
189  }
190 };
191 
192 // treat 4x4x1 as a single-blk 4x64 mfma
193 template <>
195 {
196  static constexpr index_t group_size = 4;
197  static constexpr index_t num_groups_per_blk = 1;
198  static constexpr index_t num_regs_per_blk = 4;
199  static constexpr index_t num_threads_per_blk = 64;
200  static constexpr index_t wave_size = 64;
201  static constexpr index_t num_input_blks = 1;
202  static constexpr index_t num_output_blks = 1;
203  static constexpr index_t m_per_blk = 4;
204  static constexpr index_t n_per_blk = 64;
205  static constexpr index_t k_per_blk = 1;
206  static constexpr bool is_k_reduction = false;
207 
208  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
209  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
210  {
212  }
213 };
214 
215 template <>
217 {
218  static constexpr index_t group_size = 4;
219  static constexpr index_t num_groups_per_blk = 4;
220  static constexpr index_t num_regs_per_blk = 16;
221  static constexpr index_t num_threads_per_blk = 32;
222  static constexpr index_t wave_size = 64;
223  static constexpr index_t num_input_blks = 2;
224  static constexpr index_t num_output_blks = 2;
225  static constexpr index_t m_per_blk = 32;
226  static constexpr index_t n_per_blk = 32;
227  static constexpr index_t k_per_blk = 4;
228  static constexpr bool is_k_reduction = false;
229 
230  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
231  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
232  {
234  }
235 };
236 
237 template <>
239 {
240  static constexpr index_t group_size = 4;
241  static constexpr index_t num_groups_per_blk = 4;
242  static constexpr index_t num_regs_per_blk = 16;
243  static constexpr index_t num_threads_per_blk = 32;
244  static constexpr index_t wave_size = 64;
245  static constexpr index_t num_input_blks = 2;
246  static constexpr index_t num_output_blks = 1;
247  static constexpr index_t m_per_blk = 32;
248  static constexpr index_t n_per_blk = 32;
249  static constexpr index_t k_per_blk = 4;
250  static constexpr bool is_k_reduction = true;
251 
252  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
253  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
254  {
256  }
257 };
258 
259 template <>
261 {
262  static constexpr index_t group_size = 4;
263  static constexpr index_t num_groups_per_blk = 4;
264  static constexpr index_t num_regs_per_blk = 16;
265  static constexpr index_t num_threads_per_blk = 32;
266  static constexpr index_t wave_size = 64;
267  static constexpr index_t num_input_blks = 2;
268  static constexpr index_t num_output_blks = 1;
269  static constexpr index_t m_per_blk = 32;
270  static constexpr index_t n_per_blk = 32;
271  static constexpr index_t k_per_blk = 8;
272  static constexpr bool is_k_reduction = true;
273 
274  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
275  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
276  {
278  }
279 };
280 
281 template <>
283 {
284  static constexpr index_t group_size = 4;
285  static constexpr index_t num_groups_per_blk = 1;
286  static constexpr index_t num_regs_per_blk = 4;
287  static constexpr index_t num_threads_per_blk = 16;
288  static constexpr index_t wave_size = 64;
289  static constexpr index_t num_input_blks = 4;
290  static constexpr index_t num_output_blks = 1;
291  static constexpr index_t m_per_blk = 16;
292  static constexpr index_t n_per_blk = 16;
293  static constexpr index_t k_per_blk = 8;
294  static constexpr bool is_k_reduction = true;
295 
296  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
297  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
298  {
300  }
301 };
302 
303 template <>
305 {
306  static constexpr index_t group_size = 4;
307  static constexpr index_t num_groups_per_blk = 1;
308  static constexpr index_t num_regs_per_blk = 4;
309  static constexpr index_t num_threads_per_blk = 16;
310  static constexpr index_t wave_size = 64;
311  static constexpr index_t num_input_blks = 4;
312  static constexpr index_t num_output_blks = 1;
313  static constexpr index_t m_per_blk = 16;
314  static constexpr index_t n_per_blk = 16;
315  static constexpr index_t k_per_blk = 4;
316  static constexpr bool is_k_reduction = true;
317 
318  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
319  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
320  {
322  }
323 };
324 
325 template <>
327 {
328  static constexpr index_t group_size = 4;
329  static constexpr index_t num_groups_per_blk = 1;
330  static constexpr index_t num_regs_per_blk = 4;
331  static constexpr index_t num_threads_per_blk = 16;
332  static constexpr index_t wave_size = 64;
333  static constexpr index_t num_input_blks = 4;
334  static constexpr index_t num_output_blks = 4;
335  static constexpr index_t m_per_blk = 16;
336  static constexpr index_t n_per_blk = 16;
337  static constexpr index_t k_per_blk = 4;
338  static constexpr bool is_k_reduction = false;
339 
340  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
341  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
342  {
344  }
345 };
346 
347 template <>
349 {
350  static constexpr index_t group_size = 4;
351  static constexpr index_t num_groups_per_blk = 1;
352  static constexpr index_t num_regs_per_blk = 4;
353  static constexpr index_t num_threads_per_blk = 64;
354  static constexpr index_t wave_size = 64;
355  static constexpr index_t num_input_blks = 1;
356  static constexpr index_t num_output_blks = 1;
357  static constexpr index_t m_per_blk = 4;
358  static constexpr index_t n_per_blk = 64;
359  static constexpr index_t k_per_blk = 4;
360  static constexpr bool is_k_reduction = false;
361 
362  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
363  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
364  {
366  }
367 };
368 
369 template <>
371 {
372  static constexpr index_t group_size = 4;
373  static constexpr index_t num_groups_per_blk = 4;
374  static constexpr index_t num_regs_per_blk = 16;
375  static constexpr index_t num_threads_per_blk = 32;
376  static constexpr index_t wave_size = 64;
377  static constexpr index_t num_input_blks = 2;
378  static constexpr index_t num_output_blks = 1;
379  static constexpr index_t m_per_blk = 32;
380  static constexpr index_t n_per_blk = 32;
381  static constexpr index_t k_per_blk = 8;
382  static constexpr bool is_k_reduction = true;
383 
384  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
385  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
386  {
388  }
389 };
390 
391 template <>
393 {
394  static constexpr index_t group_size = 4;
395  static constexpr index_t num_groups_per_blk = 4;
396  static constexpr index_t num_regs_per_blk = 16;
397  static constexpr index_t num_threads_per_blk = 32;
398  static constexpr index_t wave_size = 64;
399  static constexpr index_t num_input_blks = 2;
400  static constexpr index_t num_output_blks = 1;
401  static constexpr index_t m_per_blk = 32;
402  static constexpr index_t n_per_blk = 32;
403  static constexpr index_t k_per_blk = 4;
404  static constexpr bool is_k_reduction = true;
405 
406  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
407  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
408  {
410  }
411 };
412 
413 template <>
415 {
416  static constexpr index_t group_size = 4;
417  static constexpr index_t num_groups_per_blk = 1;
418  static constexpr index_t num_regs_per_blk = 4;
419  static constexpr index_t num_threads_per_blk = 16;
420  static constexpr index_t wave_size = 64;
421  static constexpr index_t num_input_blks = 4;
422  static constexpr index_t num_output_blks = 1;
423  static constexpr index_t m_per_blk = 16;
424  static constexpr index_t n_per_blk = 16;
425  static constexpr index_t k_per_blk = 8;
426  static constexpr bool is_k_reduction = true;
427 
428  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
429  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
430  {
432  }
433 };
434 
435 template <>
437 {
438  static constexpr index_t group_size = 4;
439  static constexpr index_t num_groups_per_blk = 1;
440  static constexpr index_t num_regs_per_blk = 4;
441  static constexpr index_t num_threads_per_blk = 16;
442  static constexpr index_t wave_size = 64;
443  static constexpr index_t num_input_blks = 4;
444  static constexpr index_t num_output_blks = 1;
445  static constexpr index_t m_per_blk = 16;
446  static constexpr index_t n_per_blk = 16;
447  static constexpr index_t k_per_blk = 4;
448  static constexpr bool is_k_reduction = true;
449 
450  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
451  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
452  {
454  }
455 };
456 
457 template <>
459 {
460  static constexpr index_t group_size = 4;
461  static constexpr index_t num_groups_per_blk = 4;
462  static constexpr index_t num_regs_per_blk = 16;
463  static constexpr index_t num_threads_per_blk = 32;
464  static constexpr index_t wave_size = 64;
465  static constexpr index_t num_input_blks = 2;
466  static constexpr index_t num_output_blks = 1;
467  static constexpr index_t m_per_blk = 32;
468  static constexpr index_t n_per_blk = 32;
469  static constexpr index_t k_per_blk = 2;
470  static constexpr bool is_k_reduction = true;
471 
472  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
473  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
474  {
476  }
477 };
478 
479 template <>
481 {
482  static constexpr index_t group_size = 4;
483  static constexpr index_t num_groups_per_blk = 1;
484  static constexpr index_t num_regs_per_blk = 4;
485  static constexpr index_t num_threads_per_blk = 16;
486  static constexpr index_t wave_size = 64;
487  static constexpr index_t num_input_blks = 4;
488  static constexpr index_t num_output_blks = 1;
489  static constexpr index_t m_per_blk = 16;
490  static constexpr index_t n_per_blk = 16;
491  static constexpr index_t k_per_blk = 2;
492  static constexpr bool is_k_reduction = true;
493 
494  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
495  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
496  {
498  }
499 };
500 
501 template <>
503 {
504  static constexpr index_t group_size = 4;
505  static constexpr index_t num_groups_per_blk = 4;
506  static constexpr index_t num_regs_per_blk = 16;
507  static constexpr index_t num_threads_per_blk = 32;
508  static constexpr index_t wave_size = 64;
509  static constexpr index_t num_input_blks = 2;
510  static constexpr index_t num_output_blks = 1;
511  static constexpr index_t m_per_blk = 32;
512  static constexpr index_t n_per_blk = 32;
513  static constexpr index_t k_per_blk = 4;
514  static constexpr bool is_k_reduction = true;
515 
516  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
517  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
518  {
520  }
521 };
522 
523 template <>
525 {
526  static constexpr index_t group_size = 4;
527  static constexpr index_t num_groups_per_blk = 1;
528  static constexpr index_t num_regs_per_blk = 4;
529  static constexpr index_t num_threads_per_blk = 16;
530  static constexpr index_t wave_size = 64;
531  static constexpr index_t num_input_blks = 4;
532  static constexpr index_t num_output_blks = 1;
533  static constexpr index_t m_per_blk = 16;
534  static constexpr index_t n_per_blk = 16;
535  static constexpr index_t k_per_blk = 4;
536  static constexpr bool is_k_reduction = true;
537 
538  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
539  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
540  {
542  }
543 };
544 
545 template <>
547 {
548  static constexpr index_t group_size = 4;
549  static constexpr index_t num_groups_per_blk = 4;
550  static constexpr index_t num_regs_per_blk = 16;
551  static constexpr index_t num_threads_per_blk = 32;
552  static constexpr index_t wave_size = 64;
553  static constexpr index_t num_input_blks = 2;
554  static constexpr index_t num_output_blks = 1;
555  static constexpr index_t m_per_blk = 32;
556  static constexpr index_t n_per_blk = 32;
557  static constexpr index_t k_per_blk = 8;
558  static constexpr bool is_k_reduction = true;
559 
560  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
561  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
562  {
564  }
565 };
566 
567 template <>
569 {
570  static constexpr index_t group_size = 4;
571  static constexpr index_t num_groups_per_blk = 1;
572  static constexpr index_t num_regs_per_blk = 4;
573  static constexpr index_t num_threads_per_blk = 16;
574  static constexpr index_t wave_size = 64;
575  static constexpr index_t num_input_blks = 4;
576  static constexpr index_t num_output_blks = 1;
577  static constexpr index_t m_per_blk = 16;
578  static constexpr index_t n_per_blk = 16;
579  static constexpr index_t k_per_blk = 8;
580  static constexpr bool is_k_reduction = true;
581 
582  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
583  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
584  {
586  }
587 };
588 
589 template <>
591 {
592  static constexpr index_t group_size = 4;
593  static constexpr index_t num_groups_per_blk = 4;
594  static constexpr index_t num_regs_per_blk = 16;
595  static constexpr index_t num_threads_per_blk = 32;
596  static constexpr index_t wave_size = 64;
597  static constexpr index_t num_input_blks = 2;
598  static constexpr index_t num_output_blks = 1;
599  static constexpr index_t m_per_blk = 32;
600  static constexpr index_t n_per_blk = 32;
601  static constexpr index_t k_per_blk = 16;
602  static constexpr bool is_k_reduction = true;
603 
604  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
605  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
606  {
608  }
609 };
610 
611 template <>
613 {
614  static constexpr index_t group_size = 4;
615  static constexpr index_t num_groups_per_blk = 1;
616  static constexpr index_t num_regs_per_blk = 4;
617  static constexpr index_t num_threads_per_blk = 16;
618  static constexpr index_t wave_size = 64;
619  static constexpr index_t num_input_blks = 4;
620  static constexpr index_t num_output_blks = 1;
621  static constexpr index_t m_per_blk = 16;
622  static constexpr index_t n_per_blk = 16;
623  static constexpr index_t k_per_blk = 16;
624  static constexpr bool is_k_reduction = true;
625 
626  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
627  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
628  {
630  }
631 };
632 
633 template <>
635 {
636  static constexpr index_t group_size = 1;
637  static constexpr index_t num_groups_per_blk = 4;
638  static constexpr index_t num_regs_per_blk = 4; // group_size * num_groups_per_blk;
639  static constexpr index_t num_threads_per_blk = 16;
640  static constexpr index_t wave_size = 64;
641  static constexpr index_t num_input_blks = 4; // wave_size / num_threads_per_blk;
642  static constexpr index_t num_output_blks = 1;
643  static constexpr index_t m_per_blk = 16;
644  static constexpr index_t n_per_blk = 16;
645  static constexpr index_t k_per_blk = 1;
646  static constexpr bool is_k_reduction = true;
647 
648  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
649  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
650  {
652  }
653 };
654 
655 template <>
657 {
658  static constexpr index_t group_size = 4;
659  static constexpr index_t num_groups_per_blk = 4;
660  static constexpr index_t num_regs_per_blk = 16;
661  static constexpr index_t num_threads_per_blk = 32;
662  static constexpr index_t wave_size = 64;
663  static constexpr index_t num_input_blks = 2;
664  static constexpr index_t num_output_blks = 1;
665  static constexpr index_t m_per_blk = 32;
666  static constexpr index_t n_per_blk = 32;
667  static constexpr index_t k_per_blk = 8;
668  static constexpr bool is_k_reduction = true;
669 
670  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
671  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
672  {
674  }
675 };
676 
677 template <>
679 {
680  static constexpr index_t group_size = 4;
681  static constexpr index_t num_groups_per_blk = 1;
682  static constexpr index_t num_regs_per_blk = 4;
683  static constexpr index_t num_threads_per_blk = 16;
684  static constexpr index_t wave_size = 64;
685  static constexpr index_t num_input_blks = 4;
686  static constexpr index_t num_output_blks = 1;
687  static constexpr index_t m_per_blk = 16;
688  static constexpr index_t n_per_blk = 16;
689  static constexpr index_t k_per_blk = 8;
690  static constexpr bool is_k_reduction = true;
691 
692  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
693  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
694  {
696  }
697 };
698 
699 template <>
701 {
702  static constexpr index_t group_size = 4;
703  static constexpr index_t num_groups_per_blk = 4;
704  static constexpr index_t num_regs_per_blk = 16;
705  static constexpr index_t num_threads_per_blk = 32;
706  static constexpr index_t wave_size = 64;
707  static constexpr index_t num_input_blks = 2;
708  static constexpr index_t num_output_blks = 1;
709  static constexpr index_t m_per_blk = 32;
710  static constexpr index_t n_per_blk = 32;
711  static constexpr index_t k_per_blk = 8;
712  static constexpr bool is_k_reduction = true;
713 
714  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
715  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
716  {
718  }
719 };
720 
721 template <>
723 {
724  static constexpr index_t group_size = 4;
725  static constexpr index_t num_groups_per_blk = 1;
726  static constexpr index_t num_regs_per_blk = 4;
727  static constexpr index_t num_threads_per_blk = 16;
728  static constexpr index_t wave_size = 64;
729  static constexpr index_t num_input_blks = 4;
730  static constexpr index_t num_output_blks = 1;
731  static constexpr index_t m_per_blk = 16;
732  static constexpr index_t n_per_blk = 16;
733  static constexpr index_t k_per_blk = 8;
734  static constexpr bool is_k_reduction = true;
735 
736  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
737  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
738  {
740  }
741 };
742 
743 template <>
745 {
746  static constexpr index_t group_size = 4;
747  static constexpr index_t num_groups_per_blk = 4;
748  static constexpr index_t num_regs_per_blk = 16;
749  static constexpr index_t num_threads_per_blk = 32;
750  static constexpr index_t wave_size = 64;
751  static constexpr index_t num_input_blks = 2;
752  static constexpr index_t num_output_blks = 1;
753  static constexpr index_t m_per_blk = 32;
754  static constexpr index_t n_per_blk = 32;
755  static constexpr index_t k_per_blk = 8;
756  static constexpr bool is_k_reduction = true;
757 
758  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
759  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
760  {
762  }
763 };
764 
765 template <>
767 {
768  static constexpr index_t group_size = 4;
769  static constexpr index_t num_groups_per_blk = 1;
770  static constexpr index_t num_regs_per_blk = 4;
771  static constexpr index_t num_threads_per_blk = 16;
772  static constexpr index_t wave_size = 64;
773  static constexpr index_t num_input_blks = 4;
774  static constexpr index_t num_output_blks = 1;
775  static constexpr index_t m_per_blk = 16;
776  static constexpr index_t n_per_blk = 16;
777  static constexpr index_t k_per_blk = 8;
778  static constexpr bool is_k_reduction = true;
779 
780  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
781  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
782  {
784  }
785 };
786 
787 template <>
789 {
790  static constexpr index_t group_size = 4;
791  static constexpr index_t num_groups_per_blk = 4;
792  static constexpr index_t num_regs_per_blk = 16;
793  static constexpr index_t num_threads_per_blk = 32;
794  static constexpr index_t wave_size = 64;
795  static constexpr index_t num_input_blks = 2;
796  static constexpr index_t num_output_blks = 1;
797  static constexpr index_t m_per_blk = 32;
798  static constexpr index_t n_per_blk = 32;
799  static constexpr index_t k_per_blk = 8;
800  static constexpr bool is_k_reduction = true;
801 
802  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
803  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
804  {
806  }
807 };
808 
809 template <>
811 {
812  static constexpr index_t group_size = 4;
813  static constexpr index_t num_groups_per_blk = 1;
814  static constexpr index_t num_regs_per_blk = 4;
815  static constexpr index_t num_threads_per_blk = 16;
816  static constexpr index_t wave_size = 64;
817  static constexpr index_t num_input_blks = 4;
818  static constexpr index_t num_output_blks = 1;
819  static constexpr index_t m_per_blk = 16;
820  static constexpr index_t n_per_blk = 16;
821  static constexpr index_t k_per_blk = 8;
822  static constexpr bool is_k_reduction = true;
823 
824  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
825  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
826  {
828  }
829 };
830 
831 template <>
833 {
834  // clang-format off
835  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
836  static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
837  static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
838  static constexpr index_t num_threads_per_blk = 32; // n_per_blk
839  static constexpr index_t wave_size = 64; // fixed
840  static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
841  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
842  static constexpr index_t m_per_blk = 32; // from the instruction
843  static constexpr index_t n_per_blk = 32; // from the instruction
844  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
845  static constexpr bool is_k_reduction = true; // ???
846  // clang-format on
847 
848  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
849  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
850  {
852  }
853 };
854 
855 template <>
857 {
858  // clang-format off
859  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
860  static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
861  static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
862  static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
863  static constexpr index_t wave_size = 64; // fixed
864  static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
865  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
866  static constexpr index_t m_per_blk = 16; // from the instruction
867  static constexpr index_t n_per_blk = 16; // from the instruction
868  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
869  static constexpr bool is_k_reduction = true; // ???
870  // clang-format on
871 
872  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
873  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
874  {
876  }
877 };
878 
879 template <>
881 {
882  // clang-format off
883  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
884  static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
885  static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
886  static constexpr index_t num_threads_per_blk = 32; // n_per_blk
887  static constexpr index_t wave_size = 64; // fixed
888  static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
889  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
890  static constexpr index_t m_per_blk = 32; // from the instruction
891  static constexpr index_t n_per_blk = 32; // from the instruction
892  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
893  static constexpr bool is_k_reduction = true; // ???
894  // clang-format on
895 
896  template <index_t MPerXdlops,
897  index_t NPerXdlops,
898  index_t OpselA,
899  index_t OpselB,
900  class FloatA,
901  class ScaleA,
902  class FloatB,
903  class ScaleB,
904  class FloatC>
905  __device__ void run(const FloatA& a,
906  const ScaleA& scale_a,
907  const FloatB& b,
908  const ScaleB& scale_b,
909  FloatC& reg_c) const
910  {
912  a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
913  }
914 };
915 
916 template <>
918 {
919  // clang-format off
920  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
921  static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
922  static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
923  static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
924  static constexpr index_t wave_size = 64; // fixed
925  static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
926  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
927  static constexpr index_t m_per_blk = 16; // from the instruction
928  static constexpr index_t n_per_blk = 16; // from the instruction
929  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
930  static constexpr bool is_k_reduction = true; // ???
931  // clang-format on
932 
933  template <index_t MPerXdlops,
934  index_t NPerXdlops,
935  index_t OpselA,
936  index_t OpselB,
937  class FloatA,
938  class ScaleA,
939  class FloatB,
940  class ScaleB,
941  class FloatC>
942  __device__ void run(const FloatA& a,
943  const ScaleA& scale_a,
944  const FloatB& b,
945  const ScaleB& scale_b,
946  FloatC& reg_c) const
947  {
948 
950  a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
951  }
952 };
953 
973 template <>
975 {
976  static constexpr index_t wave_size = 64; // fixed
977  static constexpr index_t m_per_blk = 16; // from the instruction
978  static constexpr index_t n_per_blk = 16; // from the instruction
979  static constexpr index_t num_threads_per_blk = n_per_blk; // 16
980  static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 4
981  static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 4
982  static constexpr index_t group_size = 4;
983  static constexpr index_t num_groups_per_blk = 1;
984  static constexpr index_t num_output_blks = 1;
985  static constexpr index_t k_per_blk = 2; // k_per_blk(K1PerXdlops) should be 2.
986  static constexpr bool is_k_reduction = true;
987 
988  // AB register size: 2, CD register size: 4
989  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
990  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
991  {
993  }
994 };
995 
996 template <>
998 {
999  static constexpr index_t wave_size = 64; // fixed
1000  static constexpr index_t m_per_blk = 32; // from the instruction
1001  static constexpr index_t n_per_blk = 32; // from the instruction
1002  static constexpr index_t num_threads_per_blk = n_per_blk; // 32
1003  static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 16
1004  static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 2
1005  static constexpr index_t group_size = 4; // corresponding to CD rows mapping
1006  static constexpr index_t num_groups_per_blk = 4;
1007  static constexpr index_t num_output_blks = 1;
1008  static constexpr index_t k_per_blk = 2;
1009  static constexpr bool is_k_reduction = true;
1010  // AB register size: 2, CD register size: 16
1011  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
1012  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1013  {
1015  }
1016 };
1017 
1018 // gfx11
1020 {
1021  static constexpr index_t group_size = 8;
1022  static constexpr index_t num_groups_per_blk = 1;
1023  static constexpr index_t num_regs_per_blk = 8;
1024  static constexpr index_t num_threads_per_blk = 16;
1025  static constexpr index_t wave_size = 32;
1026  static constexpr index_t num_input_blks = 1;
1027  static constexpr index_t num_output_blks = 1;
1028  static constexpr index_t m_per_blk = 16;
1029  static constexpr index_t n_per_blk = 16;
1030  static constexpr index_t k_per_blk = 16;
1031  static constexpr bool is_k_reduction = true;
1032 };
1033 
1034 template <>
1036 {
1037  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1038  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1039  {
1041  }
1042 };
1043 
1044 template <>
1046 {
1047  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1048  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1049  {
1051  }
1052 };
1053 
1054 template <>
1056 {
1057  template <index_t MPerWmma,
1058  index_t NPerWmma,
1059  class FloatA,
1060  class FloatB,
1061  class FloatC,
1062  bool neg_a = true,
1063  bool neg_b = true,
1064  bool clamp = false>
1065  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1066  {
1068  }
1069 };
1070 
1071 template <>
1073 {
1074  static constexpr index_t k_per_blk = 2;
1075  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1076  __device__ void run(const FloatA&, const FloatB&, FloatC&) const
1077  {
1078  // empty for all unsupported types.
1079  }
1080 };
1081 
1082 // gfx12
1084 {
1085  static constexpr index_t group_size = 8;
1086  static constexpr index_t num_groups_per_blk = 1;
1087  static constexpr index_t num_regs_per_blk = 8;
1088  static constexpr index_t num_threads_per_blk = 16;
1089  static constexpr index_t wave_size = 32;
1090  static constexpr index_t num_input_blks = 2;
1091  static constexpr index_t num_output_blks = 1;
1092  static constexpr index_t m_per_blk = 16;
1093  static constexpr index_t n_per_blk = 16;
1094  static constexpr index_t k_per_blk = 8;
1095  static constexpr bool is_k_reduction = true;
1096 };
1097 
1098 template <>
1100 {
1101  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1102  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1103  {
1105  }
1106 };
1107 
1108 template <>
1110 {
1111  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1112  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1113  {
1115  }
1116 };
1117 
1118 template <>
1120 {
1121  template <index_t MPerWmma,
1122  index_t NPerWmma,
1123  class FloatA,
1124  class FloatB,
1125  class FloatC,
1126  bool neg_a = true,
1127  bool neg_b = true,
1128  bool clamp = false>
1129  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1130  {
1132  a, b, reg_c);
1133  }
1134 };
1135 
1136 template <>
1138 {
1139  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1140  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1141  {
1143  }
1144 };
1145 
1146 template <>
1148 {
1149  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1150  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1151  {
1153  }
1154 };
1155 
1156 template <>
1158 {
1159  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1160  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1161  {
1163  }
1164 };
1165 
1166 template <>
1168 {
1169  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1170  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1171  {
1173  }
1174 };
1175 
1176 template <>
1178 {
1179  static constexpr index_t k_per_blk = 2;
1180  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1181  __device__ void run(const FloatA&, const FloatB&, FloatC&) const
1182  {
1183  // empty for all unsupported types.
1184  }
1185 };
1186 
1201 template <typename base_type,
1202  index_t MPerXdlops,
1203  index_t NPerXdlops,
1204  typename additional_type = base_type,
1205  bool is_single_rate_mfma = false,
1206  bool is_scale_mfma = false>
1208 {
// Primary declaration of the instruction-selection function template. It is only declared
// here; the explicit specializations that follow are keyed on
// (element type, M tile, N tile, additional element type, single-rate flag, scale flag).
// The parameters mirror the enclosing template's parameters (with a trailing underscore) and
// carry the same defaults: additional_type_ defaults to base_type_, both flags default to false.
1209  template <typename base_type_,
1210  index_t MPerXdlops_,
1211  index_t NPerXdlops_,
1212  typename additional_type_ = base_type_,
1213  bool is_single_rate_mfma_ = false,
1214  bool is_scale_mfma_ = false>
1215  static constexpr auto GetMfma();
1216 
1217  template <>
1218  constexpr auto GetMfma<double, 16, 16>()
1219  {
1220 #if defined(__gfx12__)
1222 #elif defined(__gfx11__)
1224 #else
1226 #endif
1227  }
1228 
1229  template <>
1230  constexpr auto GetMfma<float, 64, 64>()
1231  {
1233  }
1234 
1235  template <>
1236  constexpr auto GetMfma<float, 32, 64>()
1237  {
1239  }
1240 
1241  template <>
1242  constexpr auto GetMfma<float, 16, 64>()
1243  {
1245  }
1246 
1247  template <>
1248  constexpr auto GetMfma<float, 8, 64>()
1249  {
1251  }
1252 
1253  template <>
1254  constexpr auto GetMfma<float, 4, 64>()
1255  {
1257  }
1258 
1259  template <>
1260  constexpr auto GetMfma<float, 32, 32>()
1261  {
1263  }
1264 
1265  template <>
1266  constexpr auto GetMfma<float, 16, 16>()
1267  {
1268 #if defined(__gfx12__)
1270 #elif defined(__gfx11__)
1272 #else
1274 #endif
1275  }
1276 
1277  template <>
1278  constexpr auto GetMfma<tf32_t, 32, 32>()
1279  {
1280 #if defined(__gfx12__)
1282 #elif defined(__gfx11__)
1284 #elif defined(__gfx942__)
1286 #else
1288 #endif
1289  }
1290 
1291  template <>
1292  constexpr auto GetMfma<tf32_t, 16, 16>()
1293  {
1294 #if defined(__gfx12__)
1296 #elif defined(__gfx11__)
1298 #elif defined(__gfx942__)
1300 #else
1302 #endif
1303  }
1304 
1305  template <>
1306  constexpr auto GetMfma<half_t, 64, 64>()
1307  {
1309  }
1310 
1311  template <>
1312  constexpr auto GetMfma<half_t, 32, 64>()
1313  {
1315  }
1316 
1317  template <>
1318  constexpr auto GetMfma<half_t, 32, 32, half_t, false>()
1319  {
1320 #if defined(__gfx950__)
1322 #else
1324 #endif
1325  }
1326  template <>
1327  constexpr auto GetMfma<half_t, 32, 32, half_t, true>()
1328  {
1330  }
1331 
1332  template <>
1333  constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
1334  {
1335 #if defined(__gfx12__)
1337 #elif defined(__gfx11__)
1339 #elif defined(__gfx950__)
1341 #else
1343 #endif
1344  }
1345 
1346  template <>
1347  constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
1348  {
1349 #if defined(__gfx12__)
1351 #elif defined(__gfx11__)
1353 #else
1355 #endif
1356  }
1357 
1358  template <>
1359  constexpr auto GetMfma<half_t, 16, 64>()
1360  {
1362  }
1363 
1364  template <>
1365  constexpr auto GetMfma<half_t, 8, 64>()
1366  {
1368  }
1369 
1370  template <>
1371  constexpr auto GetMfma<half_t, 4, 64>()
1372  {
1374  }
1375 
1376  template <>
1377  constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
1378  {
1379 #if defined(__gfx950__)
1381 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1383 #else
1385 #endif
1386  }
1387 
1388  template <>
1389  constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
1390  {
1391 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1393 #else
1395 #endif
1396  }
1397 
1398  template <>
1399  constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
1400  {
1401 #if defined(__gfx12__)
1403 #elif defined(__gfx11__)
1405 #elif defined(__gfx950__)
1407 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1409 #else
1411 #endif
1412  }
1413 
1414  template <>
1415  constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
1416  {
1417 #if defined(__gfx12__)
1419 #elif defined(__gfx11__)
1421 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1423 #else
1425 #endif
1426  }
1427 
1428  template <>
1429  constexpr auto GetMfma<int8_t, 32, 32, int8_t, false>()
1430  {
1431 #if defined(__gfx950__)
1433 #elif defined(__gfx942__)
1435 #else
1437 #endif
1438  }
1439 
1440  template <>
1441  constexpr auto GetMfma<int8_t, 32, 32, int8_t, true>()
1442  {
1443 #if defined(__gfx942__) || defined(__gfx950__)
1445 #else
1447 #endif
1448  }
1449 
1450  template <>
1451  constexpr auto GetMfma<int8_t, 16, 16, int8_t, false>()
1452  {
1453 #if defined(__gfx12__)
1455 #elif defined(__gfx11__)
1457 #elif defined(__gfx950__)
1459 #elif defined(__gfx942__)
1461 #else
1463 #endif
1464  }
1465 
1466  template <>
1467  constexpr auto GetMfma<int8_t, 16, 16, int8_t, true>()
1468  {
1469 #if defined(__gfx12__)
1471 #elif defined(__gfx11__)
1473 #elif defined(__gfx942__) || defined(__gfx950__)
1475 #else
1477 #endif
1478  }
1479 
1480  template <>
1481  constexpr auto GetMfma<f8_t, 32, 32, f8_t, true, false>()
1482  {
1484  }
1485 
1486  template <>
1487  constexpr auto GetMfma<f8_t, 32, 32, f8_t, false, false>()
1488  {
1489 #if defined(__gfx950__)
1491 #else
1493 #endif
1494  }
1495 
1496  template <>
1497  constexpr auto GetMfma<f8_t, 32, 32, f8_t, is_single_rate_mfma, true>()
1498  {
1500  }
1501 
1502  template <>
1503  constexpr auto GetMfma<bf8_t, 32, 32, f8_t, is_single_rate_mfma, true>()
1504  {
1506  }
1507  template <>
1508  constexpr auto GetMfma<f4_t, 32, 32, f4_t, is_single_rate_mfma, true>()
1509  {
1511  }
1512  template <>
1513  constexpr auto GetMfma<f4_t, 16, 16, f4_t, is_single_rate_mfma, true>()
1514  {
1515 #if defined(__gfx12__)
1517 #elif defined(__gfx11__)
1519 #else
1521 #endif
1522  }
1523 
1524  template <>
1525  constexpr auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
1526  {
1527 #if defined(__gfx12__)
1529 #elif defined(__gfx11__)
1531 #else
1533 #endif
1534  }
1535 
1536  template <>
1537  constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, false>()
1538  {
1539 #if defined(__gfx12__)
1541 #elif defined(__gfx11__)
1543 #elif defined(__gfx950__)
1545 #else
1547 #endif
1548  }
1549 
1550  template <>
1551  constexpr auto GetMfma<f8_t, 16, 16, f8_t, is_single_rate_mfma, true>()
1552  {
1553 #if defined(__gfx12__)
1555 #elif defined(__gfx11__)
1557 #else
1559 #endif
1560  }
1561 
1562  template <>
1563  constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, is_single_rate_mfma, true>()
1564  {
1565 #if defined(__gfx12__)
1567 #elif defined(__gfx11__)
1569 #else
1571 #endif
1572  }
1573 
1574  template <>
1575  constexpr auto GetMfma<f8_t, 16, 16, bf8_t, is_single_rate_mfma, true>()
1576  {
1577 #if defined(__gfx12__)
1579 #elif defined(__gfx11__)
1581 #else
1583 #endif
1584  }
1585 
1586  template <>
1587  constexpr auto GetMfma<bf8_t, 16, 16, f8_t, is_single_rate_mfma, true>()
1588  {
1589 #if defined(__gfx12__)
1591 #elif defined(__gfx11__)
1593 #else
1595 #endif
1596  }
1597 
1598  template <>
1599  constexpr auto GetMfma<f6_t, 32, 32, f6_t, is_single_rate_mfma, true>()
1600  {
1602  }
1603  template <>
1604  constexpr auto GetMfma<f6_t, 16, 16, f6_t, is_single_rate_mfma, true>()
1605  {
1606 #if defined(__gfx12__)
1608 #elif defined(__gfx11__)
1610 #else
1612 #endif
1613  }
1614  template <>
1615  constexpr auto GetMfma<bf6_t, 32, 32, bf6_t, is_single_rate_mfma, true>()
1616  {
1618  }
1619  template <>
1620  constexpr auto GetMfma<bf6_t, 16, 16, bf6_t, is_single_rate_mfma, true>()
1621  {
1622 #if defined(__gfx12__)
1624 #elif defined(__gfx11__)
1626 #else
1628 #endif
1629  }
1630 
1631  template <>
1632  constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, true, false>()
1633  {
1635  }
1636 
1637  template <>
1638  constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, false, false>()
1639  {
1640 #if defined(__gfx950__)
1642 #else
1644 #endif
1645  }
1646 
1647  template <>
1648  constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, true, false>()
1649  {
1650 #if defined(__gfx12__)
1652 #elif defined(__gfx11__)
1654 #else
1656 #endif
1657  }
1658 
1659  template <>
1660  constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, false, false>()
1661  {
1662 #if defined(__gfx12__)
1664 #elif defined(__gfx11__)
1666 #elif defined(__gfx950__)
1668 #else
1670 #endif
1671  }
1672 
1673  template <>
1674  constexpr auto GetMfma<f8_t, 32, 32, bf8_t, true, false>()
1675  {
1677  }
1678 
1679  template <>
1680  constexpr auto GetMfma<f8_t, 32, 32, bf8_t, false, false>()
1681  {
1682 #if defined(__gfx950__)
1684 #else
1686 #endif
1687  }
1688 
1689  template <>
1690  constexpr auto GetMfma<f8_t, 16, 16, bf8_t, true, false>()
1691  {
1692 #if defined(__gfx12__)
1694 #elif defined(__gfx11__)
1696 #else
1698 #endif
1699  }
1700 
1701  template <>
1702  constexpr auto GetMfma<f8_t, 16, 16, bf8_t, false, false>()
1703  {
1704 #if defined(__gfx12__)
1706 #elif defined(__gfx11__)
1708 #elif defined(__gfx950__)
1710 #else
1712 #endif
1713  }
1714 
1715  template <>
1716  constexpr auto GetMfma<bf8_t, 32, 32, f8_t, true, false>()
1717  {
1719  }
1720 
1721  template <>
1722  constexpr auto GetMfma<bf8_t, 32, 32, f8_t, false, false>()
1723  {
1724 #if defined(__gfx950__)
1726 #else
1728 #endif
1729  }
1730 
1731  template <>
1732  constexpr auto GetMfma<bf8_t, 16, 16, f8_t, true, false>()
1733  {
1734 #if defined(__gfx12__)
1736 #elif defined(__gfx11__)
1738 #else
1740 #endif
1741  }
1742 
1743  template <>
1744  constexpr auto GetMfma<bf8_t, 16, 16, f8_t, false, false>()
1745  {
1746 #if defined(__gfx12__)
1748 #elif defined(__gfx11__)
1750 #elif defined(__gfx950__)
1752 #else
1754 #endif
1755  }
1756 
1758  MPerXdlops,
1759  NPerXdlops,
1761  is_single_rate_mfma,
1762  is_scale_mfma>()>{};
1763 
// Constructor: performs compile-time consistency checks on the fields of selected_mfma.
// The body consists solely of static_asserts, so it contains no runtime statements.
1764  __host__ __device__ constexpr MfmaSelector()
1765  {
// group_size * num_groups_per_blk must equal num_regs_per_blk.
1766  static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk ==
1767  selected_mfma.num_regs_per_blk,
1768  "wrong! num_regs_per_blk");
1769 
// num_threads_per_blk must equal n_per_blk.
1770  static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk,
1771  "n_per_blk != num_threads_per_blk");
// gfx11 builds: the m_per_blk relation is only checked for the 16x16 tile, and with an extra
// factor of 2 on the left-hand side. For any other tile size no m_per_blk check is made.
// NOTE(review): the assertion message below does not mention the factor of 2.
1772 #if defined(__gfx11__)
1773  if constexpr(MPerXdlops == 16 && NPerXdlops == 16)
1774  {
1775  static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks * 2 ==
1776  selected_mfma.m_per_blk,
1777  "m_per_blk != num_input_blks * num_regs_per_blk");
1778  }
1779 #else
// All other targets: num_regs_per_blk * num_input_blks must equal m_per_blk.
1780  static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks ==
1781  selected_mfma.m_per_blk,
1782  "m_per_blk != num_input_blks * num_regs_per_blk");
1783 #endif
1784 
// num_output_blks must be either num_input_blks or 1.
1785  static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks ||
1786  selected_mfma.num_output_blks == 1,
1787  "incorrect num_output_blks");
1788 
// num_regs_per_blk * wave_size must equal m_per_blk * n_per_blk.
1789  static_assert(selected_mfma.num_regs_per_blk * selected_mfma.wave_size ==
1790  selected_mfma.m_per_blk * selected_mfma.n_per_blk,
1791  "num_regs_per_blk incorrect");
1792 
// When is_k_reduction is false, num_input_blks must equal num_output_blks.
1793  static_assert(selected_mfma.is_k_reduction ||
1794  (selected_mfma.num_input_blks == selected_mfma.num_output_blks),
1795  "is_k_reduction wrong!");
1796  }
1797 
// Always returns true. Instantiating it with NPerXdlops < MPerXdlops fails the static_assert,
// so the function doubles as a compile-time check that NPerXdlops >= MPerXdlops.
1798  static constexpr bool IsABroadcast()
1799  {
1800  static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
1801  return true;
1802  }
1803 
// Returns selected_mfma.k_per_blk multiplied by num_input_blks when is_k_reduction is true,
// and by 1 otherwise.
1804  static constexpr index_t GetKPerXdlops()
1805  {
1806  return (selected_mfma.is_k_reduction ? selected_mfma.num_input_blks : 1) *
1807  selected_mfma.k_per_blk;
1808  }
1809 
// Returns selected_mfma.k_per_blk unchanged.
1810  static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; }
1811 };
1812 
1813 template <typename base_type,
1814  index_t MPerXdlops,
1815  index_t NPerXdlops,
1816  index_t KPack,
1817  typename additional_type = base_type,
1818  bool TransposeC = false,
1819  bool is_scale_mfma = false>
1821 {
// Compile-time index constants Number<0> .. Number<5>. Within this struct they are passed to
// GetLength() on descriptors and used to index the tuples returned by the block-index helpers.
1822  static constexpr auto I0 = Number<0>{};
1823  static constexpr auto I1 = Number<1>{};
1824  static constexpr auto I2 = Number<2>{};
1825  static constexpr auto I3 = Number<3>{};
1826  static constexpr auto I4 = Number<4>{};
1827  static constexpr auto I5 = Number<5>{};
1828 
1831 
// Returns mfma_instr.num_output_blks.
1832  __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; }
1833 
// Returns the MPerXdlops x NPerXdlops tile area divided by the area covered by one instruction,
// i.e. m_per_blk * n_per_blk * num_output_blks. The result is returned as index_t.
1834  __device__ static constexpr index_t GetNumXdlops()
1835  {
1836  return MPerXdlops * NPerXdlops /
1837  (mfma_instr.m_per_blk * mfma_instr.n_per_blk * mfma_instr.num_output_blks);
1838  }
1839 
// Constructor: compile-time validation of the template parameters. The body holds only
// static_asserts.
1840  __host__ __device__ constexpr XdlopsGemm()
1841  {
// NPerXdlops must be one of 4, 8, 16, 32, 64.
1842  static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
1843  NPerXdlops == 64,
1844  "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1845 
// MPerXdlops must be one of 4, 8, 16, 32, 64.
1846  static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
1847  MPerXdlops == 64,
1848  "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
// The KPack divisibility check is compiled only when __HIP_DEVICE_COMPILE__ is defined;
// it is not evaluated otherwise.
1849 #if defined(__HIP_DEVICE_COMPILE__)
1850  static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
1851 #endif
1852  }
1853 
1854  // XDL output supporting C = A * B
1855  // M2_N2 -> M2_M3_M4_N2
1856  template <typename CDesc_M0_N0_M1_N1_M2_N2>
1857  __host__ __device__ static constexpr auto
1858  MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1859  {
1860  const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1861  const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1862  const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1863  const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1864  constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1865 
1867  c_desc_m0_n0_m1_n1_m2_n2,
1873  Number<num_blks>{},
1874  Number<mfma_instr.group_size>{})),
1877  Sequence<1>{},
1878  Sequence<2>{},
1879  Sequence<3>{},
1880  Sequence<4>{},
1881  Sequence<5>{}),
1883  Sequence<1>{},
1884  Sequence<2>{},
1885  Sequence<3>{},
1887  Sequence<7>{}));
1888  }
1889 
1890  // XDL output supporting C = A * B
1891  // M3_N3 -> M3_M4_M5_N3
1892  template <typename CDesc_M0_N0_M1_N1_M2_N2>
1893  __host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
1894  const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1895  {
1896  const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1897  const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1898  const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1899  const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1900  const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
1901  const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
1902  constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1903 
1905  c_desc_m0_n0_m1_n1_m2_n2,
1913  Number<num_blks>{},
1914  Number<mfma_instr.group_size>{})),
1917  Sequence<1>{},
1918  Sequence<2>{},
1919  Sequence<3>{},
1920  Sequence<4>{},
1921  Sequence<5>{},
1922  Sequence<6>{},
1923  Sequence<7>{}),
1925  Sequence<1>{},
1926  Sequence<2>{},
1927  Sequence<3>{},
1928  Sequence<4>{},
1929  Sequence<5>{},
1931  Sequence<9>{}));
1932  }
1933 
1934  // transposed XDL output supporting C' = B' * A'
1935  // M2_N2 -> M2_N2_N3_N4
1936  template <typename CDesc_M0_N0_M1_N1_M2_N2>
1937  __host__ __device__ static constexpr auto
1938  MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1939  {
1940  const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1941  const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1942  const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1943  const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1944  constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1945 
1947  c_desc_m0_n0_m1_n1_m2_n2,
1954  Number<num_blks>{},
1955  Number<mfma_instr.group_size>{}))),
1957  Sequence<1>{},
1958  Sequence<2>{},
1959  Sequence<3>{},
1960  Sequence<4>{},
1961  Sequence<5>{}),
1963  Sequence<1>{},
1964  Sequence<2>{},
1965  Sequence<3>{},
1966  Sequence<4>{},
1967  Sequence<5, 6, 7>{}));
1968  }
1969 
1970  template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
1971  __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
1972  const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
1973  {
1974  const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
1975  const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
1976  const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
1977  const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
1978  const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
1979  constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1980 
1982  c_desc_g_m0_n0_m1_n1_m2_n2,
1989  mfma_instr.num_groups_per_blk, num_blks, mfma_instr.group_size)),
1990  make_pass_through_transform(mfma_instr.num_threads_per_blk)),
1992  Sequence<1>{},
1993  Sequence<2>{},
1994  Sequence<3>{},
1995  Sequence<4>{},
1996  Sequence<5>{},
1997  Sequence<6>{}),
1999  Sequence<1>{},
2000  Sequence<2>{},
2001  Sequence<3>{},
2002  Sequence<4>{},
2004  Sequence<8>{}));
2005  }
2006 
// Returns mfma_instr.num_regs_per_blk.
2007  __device__ __host__ static constexpr index_t GetRegSizePerXdlops()
2008  {
2009  return mfma_instr.num_regs_per_blk;
2010  }
2011 
// Returns mfma_instr.wave_size.
2012  __device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; }
2013 
2014  template <class FloatA, class FloatB, class FloatC>
2015  __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
2016  {
2017  static_assert(
2024  "base_type must be double, float, tf32_t, half, bfloat16, int8_t, f8_t or bf8_t!");
2025 
2026  static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
2027  if constexpr(!TransposeC)
2028  {
2029  mfma_instr.template run<MPerXdlops, NPerXdlops>(
2030  p_a_wave[k], p_b_wave[k], p_c_thread);
2031  }
2032  else
2033  {
2034  mfma_instr.template run<MPerXdlops, NPerXdlops>(
2035  p_b_wave[k], p_a_wave[k], p_c_thread);
2036  }
2037  });
2038  }
2039 
// Scaled overload of Run. Calls mfma_instr.run once for each k in
// [0, KPack / mfma_instr.k_per_blk), unrolled at compile time through static_for.
//   OpselA, OpselB  - compile-time selectors forwarded to mfma_instr.run
//                     (their meaning is defined by the instruction wrapper, not here).
//   p_a_wave/p_b_wave           - operand containers, indexed by k.
//   a_scale_thread/b_scale_thread - scale containers, indexed by k.
//   p_c_thread      - passed by mutable reference to every call.
2040  template <index_t OpselA,
2041  index_t OpselB,
2042  class FloatA,
2043  class ScaleA,
2044  class FloatB,
2045  class ScaleB,
2046  class FloatC>
2047  __device__ void Run(const FloatA& p_a_wave,
2048  const ScaleA& a_scale_thread,
2049  const FloatB& p_b_wave,
2050  const ScaleB& b_scale_thread,
2051  FloatC& p_c_thread) const
2052  {
2053  static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
2054  if constexpr(!TransposeC)
2055  {
// Regular order: A operand/scale first, then B.
2056  mfma_instr.template run<MPerXdlops, NPerXdlops, OpselA, OpselB>(
2057  p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
2058  }
2059  else
2060  {
// TransposeC: the A and B operands, their scales and the Opsel selectors are all swapped.
2061  mfma_instr.template run<MPerXdlops, NPerXdlops, OpselB, OpselA>(
2062  p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
2063  }
2064  });
2065  }
2066 
// Returns get_thread_local_1d_id() modulo mfma_instr.wave_size.
2067  __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; }
2068 
2069  __device__ static auto GetBlkIdx()
2070  {
2071  const auto laneId = GetLaneId();
2072  constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
2073 
2074  constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
2075  make_tuple(
2076  make_merge_transform(make_tuple(1, num_blks, mfma_instr.num_threads_per_blk))),
2078  make_tuple(Sequence<0>{}));
2079 
2080  const auto blk_idx =
2081  threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
2082 
2083  const auto blk_id = blk_idx[I1];
2084  const auto blk_td = blk_idx[I2];
2085 
2086  return make_tuple(blk_id, blk_td);
2087  }
2088 
2089  template <bool SwizzleA>
2090  __device__ static auto GetGfx11InputBlkIdx()
2091  {
2092  auto laneId = GetLaneId() % mfma_instr.num_threads_per_blk;
2093  if constexpr(SwizzleA)
2094  {
2095  laneId = ((laneId & 1) << 3) | (laneId >> 1);
2096  }
2097  constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
2099  make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))),
2101  make_tuple(Sequence<0>{}));
2102 
2103  const auto blk_idx =
2104  threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
2105 
2106  const auto blk_id = blk_idx[I1];
2107  const auto blk_td = blk_idx[I2];
2108 
2109  return make_tuple(blk_id, blk_td);
2110  }
2111 
// Returns a 2-element tuple for the calling thread's A operand:
//   (blk_id, blk_td) when mfma_instr.is_k_reduction is true, otherwise (0, laneId).
// The block index comes from GetGfx11InputBlkIdx<!TransposeC>() on gfx11 builds and from
// GetBlkIdx() on all other targets.
// NOTE(review): this function is declared __host__ __device__ but calls GetLaneId() and
// GetBlkIdx(), which are declared __device__ only — confirm it is never needed on the host.
2112  __host__ __device__ static auto CalculateAThreadOriginDataIndex()
2113  {
2114  const auto laneId = GetLaneId();
2115 #if defined(__gfx11__)
2116  const auto blk_idx = GetGfx11InputBlkIdx<!TransposeC>();
2117 #else
2118  const auto blk_idx = GetBlkIdx();
2119 #endif
2120 
2121  const auto blk_id = blk_idx[I0];
2122  const auto blk_td = blk_idx[I1];
2123 
2124  if constexpr(mfma_instr.is_k_reduction)
2125  {
2126  return make_tuple(blk_id, blk_td);
2127  }
2128  else
2129  {
// Without k-reduction the block index is not used; only the lane id is returned.
2130  return make_tuple(0, laneId);
2131  }
2132  }
2133 
// Returns a 2-element tuple for the calling thread's B operand:
//   (blk_id, blk_td) when mfma_instr.is_k_reduction is true, otherwise (0, laneId).
// Identical to CalculateAThreadOriginDataIndex except that on gfx11 builds the helper is
// instantiated as GetGfx11InputBlkIdx<TransposeC>() (not negated).
// NOTE(review): this function is declared __host__ __device__ but calls GetLaneId() and
// GetBlkIdx(), which are declared __device__ only — confirm it is never needed on the host.
2134  __host__ __device__ static auto CalculateBThreadOriginDataIndex()
2135  {
2136  const auto laneId = GetLaneId();
2137 #if defined(__gfx11__)
2138  const auto blk_idx = GetGfx11InputBlkIdx<TransposeC>();
2139 #else
2140  const auto blk_idx = GetBlkIdx();
2141 #endif
2142 
2143  const auto blk_id = blk_idx[I0];
2144  const auto blk_td = blk_idx[I1];
2145 
2146  if constexpr(mfma_instr.is_k_reduction)
2147  {
2148  return make_tuple(blk_id, blk_td);
2149  }
2150  else
2151  {
// Without k-reduction the block index is not used; only the lane id is returned.
2152  return make_tuple(0, laneId);
2153  }
2154  }
2155 
// Computes a two-component index for the calling thread from GetBlkIdx():
//   n_offset = blk_i    * n_per_blk + blk_td
//   m_offset = xdlops_i * m_per_blk + blk_id * group_size
// Returns CIndex{m_offset, n_offset}, or CIndex{n_offset, m_offset} when TransposeC is set.
// (CIndex is declared in a part of the file not visible in this listing.)
2156  __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
2157  {
2158  const auto blk_idx = GetBlkIdx();
2159 
2160  const auto blk_id = blk_idx[I0];
2161  const auto blk_td = blk_idx[I1];
2162 
2163  index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td;
2164  index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size;
2165 
2166  return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
2167  }
2168 
// Four-component variant. Both parameters are unnamed and ignored; the result depends only on
// GetBlkIdx(). Returns CIndex4D{I0, blk_id, I0, blk_td}, or CIndex4D{blk_td, I0, blk_id, I0}
// when TransposeC is set.
// (CIndex4D is declared in a part of the file not visible in this listing.)
2169  __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
2170  {
2171  const auto blk_idx = GetBlkIdx();
2172 
2173  const auto blk_id = blk_idx[I0];
2174  const auto blk_td = blk_idx[I1];
2175 
2176  return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
2177  }
2178 
2179  // Falls back to the single-rate instruction on gfx950 if KPack is single rate; no change on gfx942
2180  // and earlier. When base_type is f8_t or bf8_t, additional_type is always f8_t or bf8_t as well,
2181  // except for the special case A (f8_t) * B (pk_i4_t), which uses the single-rate mfma instruction.
2182  static constexpr bool is_single_rate_mfma =
2184  KPack <= 4) ||
2185  (is_same<base_type, int8_t>::value && KPack <= 8) ||
2188  ? true
2189  : false;
2190  static constexpr auto mfma = MfmaSelector<base_type,
2191  MPerXdlops,
2192  NPerXdlops,
2193  additional_type,
2195  is_scale_mfma>{};
2196 
// Instruction descriptor chosen by the selector object `mfma`.
2197  static constexpr auto mfma_instr = mfma.selected_mfma;
2198 
// K extents taken from the selector:
//   KPerXdlops  = mfma.GetKPerXdlops()
//   K1PerXdlops = mfma.GetK1PerXdlops()
//   K0PerXdlops = KPerXdlops / K1PerXdlops
2199  static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
2200  static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
2201  static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
2202 
2203  __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
2204  {
2205  return make_tuple(
2207  }
2208 };
2209 
2210 } // namespace ck
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
Definition: ck.hpp:268
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
MfmaInstr
Definition: xdlops_gemm.hpp:45
@ wmma_f32_16x16x16_bf16_gfx12
@ wmma_unsupport_16x16_gfx11
@ wmma_i32_16x16x16_iu8_gfx12
@ mfma_scale_f32_32x32x64f8f6f4
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_unsupport_16x16_gfx12
@ mfma_f32_16x16x16bf16_1k
@ wmma_f32_16x16x16_f8f8_gfx12
@ mfma_scale_f32_16x16x128f8f6f4
@ mfma_f32_16x16x128f8f6f4
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
typename packed_type_info< T >::element_type element_type_t
Definition: data_type.hpp:408
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:299
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
@ wmma_f32_16x16x16_bf16_gfx12
@ wmma_i32_16x16x16_iu8_gfx12
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_f32_16x16x16_f8f8_gfx12
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
Definition: array.hpp:14
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition: xdlops_gemm.hpp:1208
__host__ constexpr __device__ MfmaSelector()
Definition: xdlops_gemm.hpp:1764
static constexpr bool IsABroadcast()
Definition: xdlops_gemm.hpp:1798
static constexpr index_t GetK1PerXdlops()
Definition: xdlops_gemm.hpp:1810
static constexpr auto GetMfma()
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1757
static constexpr index_t GetKPerXdlops()
Definition: xdlops_gemm.hpp:1804
Definition: sequence.hpp:43
Definition: xdlops_gemm.hpp:1821
static constexpr auto mfma_instr
Definition: xdlops_gemm.hpp:2197
__host__ constexpr __device__ XdlopsGemm()
Definition: xdlops_gemm.hpp:1840
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:2134
static __device__ auto GetBlkIdx()
Definition: xdlops_gemm.hpp:2069
__device__ static constexpr __host__ index_t GetRegSizePerXdlops()
Definition: xdlops_gemm.hpp:2007
static constexpr auto I2
Definition: xdlops_gemm.hpp:1824
static constexpr __device__ index_t GetNumBlks()
Definition: xdlops_gemm.hpp:1832
static __device__ auto GetLaneId()
Definition: xdlops_gemm.hpp:2067
static constexpr auto K0PerXdlops
Definition: xdlops_gemm.hpp:2201
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1893
static constexpr __device__ index_t GetNumXdlops()
Definition: xdlops_gemm.hpp:1834
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:2112
static constexpr bool is_single_rate_mfma
Definition: xdlops_gemm.hpp:2182
static __device__ CIndex4D GetBeginOfThreadBlk4D(index_t, index_t)
Definition: xdlops_gemm.hpp:2169
static constexpr __device__ index_t GetWaveSize()
Definition: xdlops_gemm.hpp:2012
static __device__ auto GetGfx11InputBlkIdx()
Definition: xdlops_gemm.hpp:2090
static constexpr auto I5
Definition: xdlops_gemm.hpp:1827
static constexpr auto I3
Definition: xdlops_gemm.hpp:1825
static constexpr auto I0
Definition: xdlops_gemm.hpp:1822
__device__ void Run(const FloatA &p_a_wave, const ScaleA &a_scale_thread, const FloatB &p_b_wave, const ScaleB &b_scale_thread, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:2047
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1858
__host__ static constexpr __device__ auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1971
static constexpr auto I1
Definition: xdlops_gemm.hpp:1823
static constexpr auto K1PerXdlops
Definition: xdlops_gemm.hpp:2200
static constexpr auto KPerXdlops
Definition: xdlops_gemm.hpp:2199
static constexpr auto I4
Definition: xdlops_gemm.hpp:1826
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:2015
static constexpr auto mfma
Definition: xdlops_gemm.hpp:2190
static __device__ CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
Definition: xdlops_gemm.hpp:2156
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1938
__host__ static constexpr __device__ auto GetCM0M1M2NThreadBlkLengths()
Definition: xdlops_gemm.hpp:2203
Definition: integral_constant.hpp:20
Definition: amd_xdlops.hpp:1202
Definition: amd_xdlops.hpp:303
Definition: amd_xdlops.hpp:193
Definition: amd_xdlops.hpp:70
Definition: amd_xdlops.hpp:269
Definition: amd_xdlops.hpp:1483
Definition: amd_xdlops.hpp:1609
Definition: amd_xdlops.hpp:159
Definition: amd_xdlops.hpp:1546
Definition: amd_xdlops.hpp:1420
Definition: amd_xdlops.hpp:207
Definition: amd_xdlops.hpp:56
Definition: amd_xdlops.hpp:331
Definition: amd_xdlops.hpp:1641
Definition: amd_xdlops.hpp:249
Definition: amd_xdlops.hpp:1451
Definition: amd_xdlops.hpp:1577
Definition: amd_xdlops.hpp:139
Definition: amd_xdlops.hpp:1514
Definition: amd_xdlops.hpp:1388
Definition: amd_xdlops.hpp:15
Definition: amd_xdlops.hpp:42
Definition: amd_xdlops.hpp:317
Definition: amd_xdlops.hpp:112
Definition: amd_xdlops.hpp:1661
Definition: amd_xdlops.hpp:481
Definition: amd_xdlops.hpp:289
Definition: amd_xdlops.hpp:179
Definition: amd_xdlops.hpp:84
Definition: amd_xdlops.hpp:221
Definition: amd_xdlops.hpp:461
Definition: amd_xdlops.hpp:364
Definition: amd_xdlops.hpp:442
Definition: amd_xdlops.hpp:403
Definition: amd_xdlops.hpp:423
Definition: amd_xdlops.hpp:383
Definition: amd_xdlops.hpp:345
Definition: amd_xdlops.hpp:886
Definition: amd_xdlops.hpp:666
Definition: amd_wmma.hpp:50
Definition: amd_wmma.hpp:271
Definition: amd_wmma.hpp:25
Definition: amd_wmma.hpp:319
Definition: amd_wmma.hpp:121
Definition: type.hpp:177
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:873
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:451
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:319
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:186
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:429
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:737
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:825
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:297
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:781
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:693
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:341
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:164
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:495
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:990
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:385
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:715
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:803
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:275
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:759
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:671
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:120
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:142
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:473
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:231
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1012
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:849
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:407
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:253
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:209
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:363
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:649
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:539
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:583
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:627
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:561
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:605
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:517
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:942
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:905
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1048
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1112
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1170
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1160
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1038
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1102
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1150
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1140
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1065
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1129
__device__ void run(const FloatA &, const FloatB &, FloatC &) const
Definition: xdlops_gemm.hpp:1076
__device__ void run(const FloatA &, const FloatB &, FloatC &) const
Definition: xdlops_gemm.hpp:1181
Definition: xdlops_gemm.hpp:1020
static constexpr index_t n_per_blk
Definition: xdlops_gemm.hpp:1029
static constexpr index_t group_size
Definition: xdlops_gemm.hpp:1021
static constexpr index_t m_per_blk
Definition: xdlops_gemm.hpp:1028
static constexpr bool is_k_reduction
Definition: xdlops_gemm.hpp:1031
static constexpr index_t num_threads_per_blk
Definition: xdlops_gemm.hpp:1024
static constexpr index_t num_output_blks
Definition: xdlops_gemm.hpp:1027
static constexpr index_t wave_size
Definition: xdlops_gemm.hpp:1025
static constexpr index_t num_input_blks
Definition: xdlops_gemm.hpp:1026
static constexpr index_t num_groups_per_blk
Definition: xdlops_gemm.hpp:1022
static constexpr index_t num_regs_per_blk
Definition: xdlops_gemm.hpp:1023
static constexpr index_t k_per_blk
Definition: xdlops_gemm.hpp:1030
Definition: xdlops_gemm.hpp:1084
static constexpr index_t n_per_blk
Definition: xdlops_gemm.hpp:1093
static constexpr index_t group_size
Definition: xdlops_gemm.hpp:1085
static constexpr index_t num_output_blks
Definition: xdlops_gemm.hpp:1091
static constexpr index_t m_per_blk
Definition: xdlops_gemm.hpp:1092
static constexpr index_t num_threads_per_blk
Definition: xdlops_gemm.hpp:1088
static constexpr bool is_k_reduction
Definition: xdlops_gemm.hpp:1095
static constexpr index_t num_regs_per_blk
Definition: xdlops_gemm.hpp:1087
static constexpr index_t num_groups_per_blk
Definition: xdlops_gemm.hpp:1086
static constexpr index_t num_input_blks
Definition: xdlops_gemm.hpp:1090
static constexpr index_t wave_size
Definition: xdlops_gemm.hpp:1089
static constexpr index_t k_per_blk
Definition: xdlops_gemm.hpp:1094
Definition: xdlops_gemm.hpp:102
Definition: functional2.hpp:33