/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp Source File
wmma_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/math.hpp"
9 
10 namespace ck {
11 
12 enum struct WmmaInstr
13 {
14  // gfx11
21  // gfx12
29 };
30 
31 /*
32  * WMMA Wave Tile Always MxNxK = 16x16x16
33  * WAVE32
34  -----------------------------------
35  |RC0| | | | | | | | | | | | | | | | SubGroup 0
36  |RC1| | | | | | | | | | | | | | | |
37  |RC2| | | | | | | | | | | | | | | |
38  |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
39  |RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
40  |RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
41  |RC6| | | | | | | | | | | | | | | |
42  |RC7| | | | | | | | | | | | | | | |
43  -----------------------------------
44  | | | | | | | | | | | | | | | | | SubGroup 1
45  | | | | | | | | | | | | | | | | |
46  | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
47  | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
48  | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
49  | | | | | | | | | | | | | | | | |
50  | | | | | | | | | | | | | | | | |
51  | | | | | | | | | | | | | | | | |
52  -----------------------------------
53 
54 
55  * WAVE64
56  -----------------------------------
57  |RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0
58  |RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
59  |RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
60  |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
61  -----------------------------------
62  | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1
63  | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
64  | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
65  | | | | | | | | | | | | | | | | |
66  -----------------------------------
67  | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2
68  | 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4|
69  | 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7|
70  | | | | | | | | | | | | | | | | |
71  -----------------------------------
72  | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3
73  | 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6|
74  | 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3|
75  | | | | | | | | | | | | | | | | |
76  -----------------------------------
77 
78 * RC = Register for storing accumulated result
79 * T = Thread ID
80 */
81 
82 template <WmmaInstr Instr, index_t WaveSize, typename = void>
83 struct wmma_type
84 {
85 };
86 
87 // A-swizzled
88 template <index_t WaveSize>
90  WaveSize,
91  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
92 {
93  // Absolute fixing property
94  // * Data Pixel
95  static constexpr index_t m_per_wmma = 16;
96  static constexpr index_t n_per_wmma = 16;
97  static constexpr index_t k_per_wmma = 16;
98  static constexpr index_t src_a_data_size = 2;
99  static constexpr index_t src_b_data_size = 2;
100  static constexpr index_t acc_data_size = 4;
101  static constexpr index_t acc_pack_number = 1;
102  // * Thread mapping inside wave, num_thread_per_subgroups always along the N direction
103  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
104 
105  // Wave mode dependent property
106  static constexpr index_t wave_size = Number<WaveSize>{};
107  // * Fixed on gfx11, Will be wave mode dependent for future architectures
108  static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
109  static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
110  // * num_acc_vgprs_per_wave along the M direction
111  // * num_subgroups along the M direction
112  static constexpr index_t num_acc_vgprs_per_wave =
113  m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
114  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
115 
116  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
117  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
118  {
119  if constexpr(wave_size == 32)
120  {
122  }
123  else if constexpr(wave_size == 64)
124  {
126  }
127  }
128 };
129 
130 template <index_t WaveSize>
132  WaveSize,
133  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
134 {
135  // Absolute fixing property
136  static constexpr index_t m_per_wmma = 16;
137  static constexpr index_t n_per_wmma = 16;
138  static constexpr index_t k_per_wmma = 16;
139  static constexpr index_t src_a_data_size = 2;
140  static constexpr index_t src_b_data_size = 2;
141  static constexpr index_t acc_data_size = 4;
142  static constexpr index_t acc_pack_number = 1;
143  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
144 
145  // Wave mode dependent property
146  static constexpr index_t wave_size = Number<WaveSize>{};
147  static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
148  static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
149  static constexpr index_t num_acc_vgprs_per_wave =
150  m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
151  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
152 
153  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
154  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
155  {
156  if constexpr(wave_size == 32)
157  {
159  }
160  else if constexpr(wave_size == 64)
161  {
163  }
164  }
165 };
166 
167 template <index_t WaveSize>
169  WaveSize,
170  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
171 {
172  // Absolute fixing property
173  static constexpr index_t m_per_wmma = 16;
174  static constexpr index_t n_per_wmma = 16;
175  static constexpr index_t k_per_wmma = 16;
176  static constexpr index_t src_a_data_size = 2;
177  static constexpr index_t src_b_data_size = 2;
178  static constexpr index_t acc_data_size = 2;
179  static constexpr index_t acc_pack_number = 2;
180  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
181 
182  // Wave mode dependent property
183  static constexpr index_t wave_size = Number<WaveSize>{};
184  static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
185  static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
186  static constexpr index_t num_acc_vgprs_per_wave =
187  m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
188  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
189 
190  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
191  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
192  {
193  if constexpr(wave_size == 32)
194  {
196  }
197  else if constexpr(wave_size == 64)
198  {
200  }
201  }
202 };
203 template <index_t WaveSize>
205  WaveSize,
206  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
207 {
208  // Absolute fixing property
209  static constexpr index_t m_per_wmma = 16;
210  static constexpr index_t n_per_wmma = 16;
211  static constexpr index_t k_per_wmma = 16;
212  static constexpr index_t src_a_data_size = 2;
213  static constexpr index_t src_b_data_size = 2;
214  static constexpr index_t acc_data_size = 2;
215  static constexpr index_t acc_pack_number = 2;
216  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
217 
218  // Wave mode dependent property
219  static constexpr index_t wave_size = Number<WaveSize>{};
220  static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
221  static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
222  static constexpr index_t num_acc_vgprs_per_wave =
223  m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
224  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
225 
226  template <index_t MPerWmma,
227  index_t NPerWmma,
228  index_t Opsel,
229  class FloatA,
230  class FloatB,
231  class FloatC>
232  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
233  {
234  if constexpr(wave_size == 32)
235  {
237  }
238  else if constexpr(wave_size == 64)
239  {
241  }
242  }
243 };
244 
245 template <index_t WaveSize>
247  WaveSize,
248  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
249 {
250  // Absolute fixing property
251  static constexpr index_t m_per_wmma = 16;
252  static constexpr index_t n_per_wmma = 16;
253  static constexpr index_t k_per_wmma = 16;
254  static constexpr index_t src_a_data_size = 2;
255  static constexpr index_t src_b_data_size = 2;
256  static constexpr index_t acc_data_size = 4;
257  static constexpr index_t acc_pack_number = 1;
258  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
259 
260  // Wave mode dependent property
261  static constexpr index_t wave_size = Number<WaveSize>{};
262  static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
263  static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
264  static constexpr index_t num_acc_vgprs_per_wave =
265  m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
266  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
267 
268  template <index_t MPerWmma,
269  index_t NPerWmma,
270  class FloatA,
271  class FloatB,
272  class FloatC,
273  bool neg_a = false,
274  bool neg_b = false,
275  bool clamp = false>
276  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
277  {
278  if constexpr(wave_size == 32)
279  {
281  a, b, reg_c);
282  }
283  else if constexpr(wave_size == 64)
284  {
286  a, b, reg_c);
287  }
288  }
289 };
290 
291 // gfx12
292 
293 // A-swizzled
294 template <index_t WaveSize>
296  WaveSize,
297  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
298 {
299  // Absolute fixing property
300  // * Data Pixel
301  static constexpr index_t m_per_wmma = 16;
302  static constexpr index_t n_per_wmma = 16;
303  static constexpr index_t k_per_wmma = 16;
304  // static constexpr index_t src_a_data_size = 2;
305  // static constexpr index_t src_b_data_size = 2;
306  // static constexpr index_t acc_data_size = 4;
307  // * Thread mapping inside wave, num_thread_per_subgroups always along the N direction
308  static constexpr index_t acc_data_size = 4;
309  static constexpr index_t acc_pack_number = 1;
310  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
311 
312  // Wave mode dependent property
313  static constexpr index_t wave_size = Number<WaveSize>{};
314  // * Fixed for gfx11, Will be wave mode dependent on gfx12
315  // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
316  // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
317  // * num_acc_vgprs_per_wave along the M direction
318  // * num_subgroups along the M direction
319  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
320  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
321 
322  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
323  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
324  {
325  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
326  if constexpr(wave_size == 32)
327  {
329  }
330  }
331 };
332 
333 template <index_t WaveSize>
335  WaveSize,
336  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
337 {
338  // Absolute fixing property
339  static constexpr index_t m_per_wmma = 16;
340  static constexpr index_t n_per_wmma = 16;
341  static constexpr index_t k_per_wmma = 16;
342  // static constexpr index_t src_a_data_size = 2;
343  // static constexpr index_t src_b_data_size = 2;
344  static constexpr index_t acc_data_size = 4;
345  static constexpr index_t acc_pack_number = 1;
346  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
347 
348  // Wave mode dependent property
349  static constexpr index_t wave_size = Number<WaveSize>{};
350  // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
351  // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
352  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
353  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
354 
355  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
356  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
357  {
358  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
359  if constexpr(wave_size == 32)
360  {
362  }
363  }
364 };
365 
366 template <index_t WaveSize>
368  WaveSize,
369  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
370 {
371  // Absolute fixing property
372  static constexpr index_t m_per_wmma = 16;
373  static constexpr index_t n_per_wmma = 16;
374  static constexpr index_t k_per_wmma = 16;
375  // static constexpr index_t src_a_data_size = 2;
376  // static constexpr index_t src_b_data_size = 2;
377  static constexpr index_t acc_data_size = 4;
378  static constexpr index_t acc_pack_number = 1;
379  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
380 
381  // Wave mode dependent property
382  static constexpr index_t wave_size = Number<WaveSize>{};
383  // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
384  // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
385  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
386  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
387 
388  template <index_t MPerWmma,
389  index_t NPerWmma,
390  class FloatA,
391  class FloatB,
392  class FloatC,
393  bool neg_a = false,
394  bool neg_b = false,
395  bool clamp = false>
396  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
397  {
398  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
399  if constexpr(wave_size == 32)
400  {
402  a, b, reg_c);
403  }
404  }
405 };
406 
407 template <index_t WaveSize>
409  WaveSize,
410  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
411 {
412  // Absolute fixing property
413  static constexpr index_t m_per_wmma = 16;
414  static constexpr index_t n_per_wmma = 16;
415  static constexpr index_t k_per_wmma = 16;
416  static constexpr index_t acc_data_size = 4;
417  static constexpr index_t acc_pack_number = 1;
418  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
419 
420  // Wave mode dependent property
421  static constexpr index_t wave_size = Number<WaveSize>{};
422  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
423  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
424 
425  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
426  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
427  {
428  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
429  if constexpr(wave_size == 32)
430  {
431 #ifdef __gfx12__
433 #else
434  ignore = a;
435  ignore = b;
436  ignore = reg_c;
437 #endif
438  }
439  }
440 };
441 
442 template <index_t WaveSize>
444  WaveSize,
445  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
446 {
447  // Absolute fixing property
448  static constexpr index_t m_per_wmma = 16;
449  static constexpr index_t n_per_wmma = 16;
450  static constexpr index_t k_per_wmma = 16;
451  static constexpr index_t acc_data_size = 4;
452  static constexpr index_t acc_pack_number = 1;
453  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
454 
455  // Wave mode dependent property
456  static constexpr index_t wave_size = Number<WaveSize>{};
457  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
458  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
459 
460  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
461  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
462  {
463  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
464  if constexpr(wave_size == 32)
465  {
466 #ifdef __gfx12__
468 #else
469  ignore = a;
470  ignore = b;
471  ignore = reg_c;
472 #endif
473  }
474  }
475 };
476 
477 template <index_t WaveSize>
479  WaveSize,
480  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
481 {
482  // Absolute fixing property
483  static constexpr index_t m_per_wmma = 16;
484  static constexpr index_t n_per_wmma = 16;
485  static constexpr index_t k_per_wmma = 16;
486  static constexpr index_t acc_data_size = 4;
487  static constexpr index_t acc_pack_number = 1;
488  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
489 
490  // Wave mode dependent property
491  static constexpr index_t wave_size = Number<WaveSize>{};
492  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
493  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
494 
495  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
496  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
497  {
498  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
499  if constexpr(wave_size == 32)
500  {
501 #ifdef __gfx12__
503 #else
504  ignore = a;
505  ignore = b;
506  ignore = reg_c;
507 #endif
508  }
509  }
510 };
511 
512 template <index_t WaveSize>
514  WaveSize,
515  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
516 {
517  // Absolute fixing property
518  static constexpr index_t m_per_wmma = 16;
519  static constexpr index_t n_per_wmma = 16;
520  static constexpr index_t k_per_wmma = 16;
521  static constexpr index_t acc_data_size = 4;
522  static constexpr index_t acc_pack_number = 1;
523  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
524 
525  // Wave mode dependent property
526  static constexpr index_t wave_size = Number<WaveSize>{};
527  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
528  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
529 
530  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
531  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
532  {
533  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
534  if constexpr(wave_size == 32)
535  {
536 #ifdef __gfx12__
538 #else
539  ignore = a;
540  ignore = b;
541  ignore = reg_c;
542 #endif
543  }
544  }
545 };
546 
547 template <typename src_type_a,
548  typename src_type_b,
549  typename dst_type,
550  index_t MPerWmma,
551  index_t NPerWmma>
553 {
554  template <typename src_type_a_,
555  typename src_type_b_,
556  typename dst_type_,
557  index_t MPerWmma_,
558  index_t NPerWmma_>
559  static constexpr auto GetWmma();
560 
561  template <>
562  constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
563  {
564 #ifdef __gfx12__
566 #else
568 #endif
569  }
570 
571  template <>
572  constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
573  {
574 #ifdef __gfx12__
576 #else
578 #endif
579  }
580 
581  template <>
582  constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
583  {
585  }
586 
587  template <>
588  constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
589  {
591  }
592 
593  template <>
594  constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
595  {
596 #ifdef __gfx12__
598 #else
600 #endif
601  }
602 
603 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
604  template <>
605  constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
606  {
608  }
609 #endif
610 
611  template <>
612  constexpr auto GetWmma<f8_t, f8_t, float, 16, 16>()
613  {
615  }
616 
617  template <>
618  constexpr auto GetWmma<f8_t, bf8_t, float, 16, 16>()
619  {
621  }
622 
623  template <>
624  constexpr auto GetWmma<bf8_t, f8_t, float, 16, 16>()
625  {
627  }
628 
629  template <>
630  constexpr auto GetWmma<bf8_t, bf8_t, float, 16, 16>()
631  {
633  }
634 
635  // get_warp_size does not return the correct wave size; hardcode it to 32 as a workaround
636  static constexpr auto selected_wmma =
638 
639  __host__ __device__ constexpr WmmaSelector()
640  {
641  static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
642 
643  static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
644 
645  static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
646 
647  static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
648  selected_wmma.acc_data_size * selected_wmma.acc_pack_number ==
649  selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
650  "WRONG! Invalid Number of Accumulator Register");
651  }
652 };
653 
654 template <typename src_type_a,
655  typename src_type_b,
656  typename dst_type,
657  index_t MPerWmma,
658  index_t NPerWmma,
659  index_t KPack,
660  bool TransposeC = false,
661  bool AssemblyBackend = false>
662 struct WmmaGemm
663 {
664  static constexpr auto I0 = Number<0>{};
665  static constexpr auto I1 = Number<1>{};
666  static constexpr auto I2 = Number<2>{};
667  static constexpr auto I3 = Number<3>{};
668  static constexpr auto I4 = Number<4>{};
669  static constexpr auto I5 = Number<5>{};
670 
673 
674  __host__ __device__ constexpr WmmaGemm()
675  {
676  static_assert(NPerWmma == 16 && MPerWmma == 16,
677  "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
678 
679  static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma");
680  }
681 
682  // WMMA output supporting C = A * B
683  // Vector Write
684  // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
685  template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
686  __host__ __device__ static constexpr auto
688  const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
689  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
690  {
691  const auto MBlockxRepeat =
692  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
693  const auto NBlockxRepeat =
694  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
695  const auto MWave =
696  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
697  const auto NWave =
698  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
699 
701  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
702  make_tuple(
703  make_pass_through_transform(MBlockxRepeat),
706  Number<wmma_instr.num_acc_vgprs_per_wave>{})),
707  make_pass_through_transform(NBlockxRepeat),
711  Sequence<1>{},
712  Sequence<2>{},
713  Sequence<3>{},
714  Sequence<4>{},
715  Sequence<5>{}),
717  Sequence<1>{},
718  Sequence<2, 6>{},
719  Sequence<3>{},
720  Sequence<4>{},
721  Sequence<5>{}));
722  }
723 
724  // Transposed WMMA Output C' = B' * A'
725  template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
726  __host__ __device__ static constexpr auto
728  const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
729  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
730  {
731  const auto MBlockxRepeat =
732  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
733  const auto NBlockxRepeat =
734  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
735  const auto MWave =
736  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
737  const auto NWave =
738  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
739 
741  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
742  make_tuple(
743  make_pass_through_transform(MBlockxRepeat),
746  make_pass_through_transform(NBlockxRepeat),
749  Number<wmma_instr.num_acc_vgprs_per_wave>{}))),
751  Sequence<1>{},
752  Sequence<2>{},
753  Sequence<3>{},
754  Sequence<4>{},
755  Sequence<5>{}),
757  Sequence<1>{},
758  Sequence<2>{},
759  Sequence<3>{},
760  Sequence<4>{},
761  Sequence<5, 6>{}));
762  }
763 
764  __device__ static constexpr index_t GetRegSizePerWmma()
765  {
766  return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number;
767  }
768 
769  __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
770 
771  template <class FloatA, class FloatB, class FloatC>
772  __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
773  {
774  static_assert(
788 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
791 #endif
792  false,
793  "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
794  "((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!");
795  static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
796  if constexpr(!TransposeC)
797  {
798  wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
799  }
800  else
801  {
802  wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
803  }
804  });
805  }
806 
807  __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
808 
809  __device__ static auto GetSubGroupId()
810  {
811  static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups ==
812  wmma_instr.wave_size,
813  "");
814  return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
815  }
816 
817  __device__ static auto GetLaneIdUnderSubGroup()
818  {
819  return GetLaneId() % wmma_instr.num_thread_per_subgroups;
820  }
821  __device__ static auto GetSwizzledLaneIdLow()
822  {
823  return ((GetLaneIdUnderSubGroup() & 1) << 3) | (GetLaneIdUnderSubGroup() >> 1);
824  }
825 
826  __host__ __device__ static auto CalculateAThreadOriginDataIndex()
827  {
828 #ifdef __gfx12__
829  return GetLaneIdUnderSubGroup();
830 #else
831  return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
832 #endif
833  }
834 
835  __host__ __device__ static auto CalculateBThreadOriginDataIndex()
836  {
837 #ifdef __gfx12__
838  return GetLaneIdUnderSubGroup();
839 #else
840  return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
841 #endif
842  }
843 
844  __device__ static CIndex GetBeginOfThreadBlk()
845  {
846  index_t n_offset = GetLaneIdUnderSubGroup();
847  index_t m_offset = GetSubGroupId() * wmma_instr.num_acc_vgprs_per_wave;
848 
849  return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
850  }
851 
852  __device__ static CIndex3D GetBeginOfThreadBlk3D()
853  {
854  index_t n_offset = GetLaneIdUnderSubGroup();
855  index_t m_offset = GetSubGroupId();
856 
857  return TransposeC ? CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0};
858  }
859 
860  static constexpr auto wmma =
862  static constexpr auto wmma_instr = wmma.selected_wmma;
863 
864  __host__ __device__ static constexpr auto
866  {
867  return make_tuple(I1,
868  I1,
870  Number<wmma_instr.acc_pack_number>{});
871  }
872 };
873 
874 } // namespace ck
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
Definition: ck.hpp:267
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:52
WmmaInstr
Definition: wmma_gemm.hpp:13
@ wmma_f32_16x16x16_bf16_gfx12
@ wmma_i32_16x16x16_iu8_gfx12
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_f32_16x16x16_f8f8_gfx12
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
Definition: array.hpp:14
Definition: sequence.hpp:43
Definition: wmma_gemm.hpp:663
static constexpr auto I0
Definition: wmma_gemm.hpp:664
static __device__ auto GetLaneId()
Definition: wmma_gemm.hpp:807
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: wmma_gemm.hpp:772
static constexpr __device__ index_t GetWaveSize()
Definition: wmma_gemm.hpp:769
static constexpr auto wmma
Definition: wmma_gemm.hpp:860
__host__ static constexpr __device__ auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
Definition: wmma_gemm.hpp:865
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: wmma_gemm.hpp:826
static __device__ auto GetSubGroupId()
Definition: wmma_gemm.hpp:809
static __device__ auto GetSwizzledLaneIdLow()
Definition: wmma_gemm.hpp:821
static constexpr auto I3
Definition: wmma_gemm.hpp:667
static constexpr auto I5
Definition: wmma_gemm.hpp:669
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: wmma_gemm.hpp:835
__host__ static constexpr __device__ auto MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition: wmma_gemm.hpp:727
static __device__ CIndex GetBeginOfThreadBlk()
Definition: wmma_gemm.hpp:844
static constexpr auto I4
Definition: wmma_gemm.hpp:668
static constexpr __device__ index_t GetRegSizePerWmma()
Definition: wmma_gemm.hpp:764
__host__ static constexpr __device__ auto MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition: wmma_gemm.hpp:687
__host__ constexpr __device__ WmmaGemm()
Definition: wmma_gemm.hpp:674
static constexpr auto I2
Definition: wmma_gemm.hpp:666
static __device__ CIndex3D GetBeginOfThreadBlk3D()
Definition: wmma_gemm.hpp:852
static constexpr auto I1
Definition: wmma_gemm.hpp:665
static __device__ auto GetLaneIdUnderSubGroup()
Definition: wmma_gemm.hpp:817
static constexpr auto wmma_instr
Definition: wmma_gemm.hpp:862
Definition: wmma_gemm.hpp:553
static constexpr auto selected_wmma
Definition: wmma_gemm.hpp:636
__host__ constexpr __device__ WmmaSelector()
Definition: wmma_gemm.hpp:639
static constexpr auto GetWmma()
Definition: integral_constant.hpp:20
Definition: amd_wmma.hpp:96
Definition: amd_wmma.hpp:216
Definition: amd_wmma.hpp:72
Definition: amd_wmma.hpp:192
Definition: amd_wmma.hpp:50
Definition: amd_wmma.hpp:170
Definition: amd_wmma.hpp:271
Definition: amd_wmma.hpp:25
Definition: amd_wmma.hpp:149
Definition: amd_wmma.hpp:319
Definition: amd_wmma.hpp:121
Definition: amd_wmma.hpp:241
Definition: type.hpp:177
Definition: functional2.hpp:33
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:232
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:191
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:154
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:356
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:531
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:496
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:117
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:323
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:461
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:426
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:276
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:396
Definition: wmma_gemm.hpp:84