/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File
warp_gemm_attribute_mfma_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
// TODO: refactor warp-gemm.
// There is currently a discrepancy for vav/vva when C/D has to be transposed:
// e.g. to get A in AGPRs and B in VGPRs, Raw_vva must be requested in WGAttrEnum,
// because the *_impl code swaps the A/B pointers and that swap is not known here.

/// Register-file placement of the MFMA operands for the raw inline-asm path.
/// The suffix reads <C><A><B>, with v = VGPR and a = AGPR.
/// Default_ has no inline-asm variant; the impl structs then call the compiler builtin.
enum class WGAttrCtlEnum : int
{
    Default_ = 0, // no inline asm, builtin path
    Raw_vvv  = 1, // C: vgpr, A: vgpr, B: vgpr
    Raw_vaa  = 2, // C: vgpr, A: agpr, B: agpr
    Raw_vav  = 3, // C: vgpr, A: agpr, B: vgpr
    Raw_vva  = 4, // C: vgpr, A: vgpr, B: agpr
    Raw_avv  = 5, // C: agpr, A: vgpr, B: vgpr
    // An all-AGPR variant (C: agpr, A: agpr, B: agpr) is not defined.
};
24 
// DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_)
// Emits the instruction `mfma_ %0, %1, %2, %3` as one inline-asm statement, where
//   %0 = c_vec (output, constraint dmod_), %1 = a_vec (amod_), %2 = b_vec (bmod_),
//   %3 = c_vec (input, cmod_).
// The *mod_ arguments are asm constraint strings supplied by DISPATCH_MFMA_CTRL_
// ("v" = VGPR, "a" = AGPR, "+" = read-write output), i.e. c_vec is updated in place.
// The expansion site must have `c_vec`, `a_vec`, `b_vec` and a compile-time bool
// `post_nop_` in scope. When post_nop_ is true, an `s_nop 3` is appended to the same
// asm statement, right after the MFMA.
// NOTE(review): "; yyy" is only an assembler comment on the MFMA line and looks like a
// leftover debug marker.
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
    if constexpr(post_nop_) \
    { \
        asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n" \
                           "s_nop 3" \
                     : dmod_(c_vec) \
                     : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
                     :); \
    } \
    else \
    { \
        asm volatile(mfma_ " %0, %1, %2, %3\n" \
                     : dmod_(c_vec) \
                     : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
                     :); \
    }
41 
// DISPATCH_MFMA_CTRL_(mfma_, ctrl_)
// Expands to an `if constexpr` / `else if constexpr` chain over the Raw_* values of
// WGAttrCtlEnum. Each branch issues `mfma_` through DISPATCH_MFMA_, using the
// constraint strings that place C/A/B in the register files named by the enumerator.
// IMPORTANT: the chain has no branch for Default_ and no trailing `else`. Every use
// site must follow the macro with its own `else { ... }` (the compiler-builtin path).
// Without that `else`, the following block runs unconditionally, i.e. in addition to
// the inline-asm MFMA for Raw_* modes.
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \
    if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \
    { \
        DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \
    } \
    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \
    { \
        DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \
    } \
    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \
    { \
        DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \
    } \
    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \
    { \
        DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \
    } \
    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \
    { \
        DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \
    }
63 
64 // F32
65 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
67 {
68  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
69 
70  using ADataType = float;
71  using BDataType = float;
72  using CDataType = float;
73 
77 
78  static constexpr index_t kM = 16;
79  static constexpr index_t kN = 16;
80  static constexpr index_t kK = 4;
81 
82  static constexpr index_t kAMBlock = 1;
83  static constexpr index_t kBNBlock = 1;
84 
85  static constexpr index_t kAMLane = 16;
86  static constexpr index_t kBNLane = 16;
87  static constexpr index_t kABKLane = 4;
88  static constexpr index_t kABKPerLane = 1;
89 
90  static constexpr index_t kCMLane = 4;
91  static constexpr index_t kCNLane = 16;
92  static constexpr index_t kCM0PerLane = 1;
93  static constexpr index_t kCM1PerLane = 4;
94 
95  // c_vec += a_vec * b_vec
96  template <bool post_nop_ = false>
98  const AVecType& a_vec,
99  const BVecType& b_vec,
100  bool_constant<post_nop_> = {}) const
101  {
102  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x4f32", Ctrl)
103  else
104  {
105 #if defined(__gfx9__)
106  c_vec = __builtin_amdgcn_mfma_f32_16x16x4f32(a_vec[0], b_vec[0], c_vec, 0, 0, 0);
107 #else
108  ck_tile::ignore = c_vec;
109  ck_tile::ignore = a_vec;
110  ck_tile::ignore = b_vec;
111 #endif
112  }
113  }
114 
115  // c_vec = a_vec * b_vec
116  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
117  {
118 #if defined(__gfx9__)
119  return bit_cast<CVecType>(
120  __builtin_amdgcn_mfma_f32_16x16x4f32(a_vec[0], b_vec[0], CVecType{0.f}, 0, 0, 0));
121 #else
122  ck_tile::ignore = a_vec;
123  ck_tile::ignore = b_vec;
124  return CVecType{0.f};
125 #endif
126  }
127 };
128 
129 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
131 {
132  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
133 
134  using ADataType = float;
135  using BDataType = float;
136  using CDataType = float;
137 
141 
142  static constexpr index_t kM = 32;
143  static constexpr index_t kN = 32;
144  static constexpr index_t kK = 2;
145 
146  static constexpr index_t kAMBlock = 1;
147  static constexpr index_t kBNBlock = 1;
148 
149  static constexpr index_t kAMLane = 32;
150  static constexpr index_t kBNLane = 32;
151  static constexpr index_t kABKLane = 2;
152  static constexpr index_t kABKPerLane = 1;
153 
154  static constexpr index_t kCMLane = 2;
155  static constexpr index_t kCNLane = 32;
156  static constexpr index_t kCM0PerLane = 4;
157  static constexpr index_t kCM1PerLane = 4;
158 
159  // c_vec += a_vec * b_vec
160  template <bool post_nop_ = false>
162  const AVecType& a_vec,
163  const BVecType& b_vec,
164  bool_constant<post_nop_> = {}) const
165  {
166  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x2f32", Ctrl)
167  else
168  {
169 #if defined(__gfx9__)
170  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_vec[0], b_vec[0], c_vec, 0, 0, 0);
171 #else
172  ck_tile::ignore = c_vec;
173  ck_tile::ignore = a_vec;
174  ck_tile::ignore = b_vec;
175 #endif
176  }
177  }
178 
179  // c_vec = a_vec * b_vec
180  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
181  {
182 #if defined(__gfx9__)
183  return bit_cast<CVecType>(
184  __builtin_amdgcn_mfma_f32_32x32x2f32(a_vec[0], b_vec[0], CVecType{0.f}, 0, 0, 0));
185 #else
186  ck_tile::ignore = a_vec;
187  ck_tile::ignore = b_vec;
188  return CVecType{0.f};
189 #endif
190  }
191 };
192 
193 // V_MFMA_F32_16x16x32_BF16
194 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
196 {
197  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
198  using ADataType = bf16_t;
199  using BDataType = bf16_t;
200  using CDataType = float;
201 
205 
206  static constexpr index_t kM = 16;
207  static constexpr index_t kN = 16;
208  static constexpr index_t kK = 32;
209 
210  static constexpr index_t kAMBlock = 1;
211  static constexpr index_t kBNBlock = 1;
212 
213  static constexpr index_t kAMLane = 16;
214  static constexpr index_t kBNLane = 16;
215  static constexpr index_t kABKLane = 4;
216  static constexpr index_t kABKPerLane = 8;
217 
218  static constexpr index_t kCMLane = 4;
219  static constexpr index_t kCNLane = 16;
220  static constexpr index_t kCM0PerLane = 1;
221  static constexpr index_t kCM1PerLane = 4;
222 
223  // c_vec += a_vec * b_vec
224  template <bool post_nop_ = false>
226  const AVecType& a_vec,
227  const BVecType& b_vec,
228  bool_constant<post_nop_> = {}) const
229  {
230  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32_bf16", Ctrl)
231  else
232  {
233 #if defined(__gfx950__)
234  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, c_vec, 0, 0, 0);
235 #else
236  ck_tile::ignore = c_vec;
237  ck_tile::ignore = a_vec;
238  ck_tile::ignore = b_vec;
239 #endif
240  }
241  }
242 
243  // c_vec = a_vec * b_vec
244  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
245  {
246 #if defined(__gfx950__)
247  return bit_cast<CVecType>(
248  __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
249 #else
250  ck_tile::ignore = a_vec;
251  ck_tile::ignore = b_vec;
252  return CVecType{0.f};
253 #endif
254  }
255 };
256 // FP16
257 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
259 {
260  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
261  using ADataType = fp16_t;
262  using BDataType = fp16_t;
263  using CDataType = float;
264 
268 
269  static constexpr index_t kM = 32;
270  static constexpr index_t kN = 32;
271  static constexpr index_t kK = 8;
272 
273  static constexpr index_t kAMBlock = 1;
274  static constexpr index_t kBNBlock = 1;
275 
276  static constexpr index_t kAMLane = 32;
277  static constexpr index_t kBNLane = 32;
278  static constexpr index_t kABKLane = 2;
279  static constexpr index_t kABKPerLane = 4;
280 
281  static constexpr index_t kCMLane = 2;
282  static constexpr index_t kCNLane = 32;
283  static constexpr index_t kCM0PerLane = 4;
284  static constexpr index_t kCM1PerLane = 4;
285 
286  // c_vec += a_vec * b_vec
287  template <bool post_nop_ = false>
289  const AVecType& a_vec,
290  const BVecType& b_vec,
291  bool_constant<post_nop_> = {}) const
292  {
293  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl)
294  else
295  {
296 #if defined(__gfx9__)
297  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
298 #else
299  ck_tile::ignore = c_vec;
300  ck_tile::ignore = a_vec;
301  ck_tile::ignore = b_vec;
302 #endif
303  }
304  }
305 
306  // c_vec = a_vec * b_vec
307  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
308  {
309 #if defined(__gfx9__)
310  return bit_cast<CVecType>(
311  __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
312 #else
313  ck_tile::ignore = a_vec;
314  ck_tile::ignore = b_vec;
315  return CVecType{0.f};
316 #endif
317  }
318 };
319 
320 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
322 {
323  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
324  using ADataType = fp16_t;
325  using BDataType = fp16_t;
326  using CDataType = float;
327 
331 
332  static constexpr index_t kM = 16;
333  static constexpr index_t kN = 16;
334  static constexpr index_t kK = 16;
335 
336  static constexpr index_t kAMBlock = 1;
337  static constexpr index_t kBNBlock = 1;
338 
339  static constexpr index_t kAMLane = 16;
340  static constexpr index_t kBNLane = 16;
341  static constexpr index_t kABKLane = 4;
342  static constexpr index_t kABKPerLane = 4;
343 
344  static constexpr index_t kCMLane = 4;
345  static constexpr index_t kCNLane = 16;
346  static constexpr index_t kCM0PerLane = 1;
347  static constexpr index_t kCM1PerLane = 4;
348 
349  // c_vec += a_vec * b_vec
350  template <bool post_nop_ = false>
352  const AVecType& a_vec,
353  const BVecType& b_vec,
354  bool_constant<post_nop_> = {}) const
355  {
356  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
357  else
358  {
359 #if defined(__gfx9__)
360  c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
361 #else
362  ck_tile::ignore = c_vec;
363  ck_tile::ignore = a_vec;
364  ck_tile::ignore = b_vec;
365 #endif
366  }
367  }
368 
369  // c_vec = a_vec * b_vec
370  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
371  {
372 #if defined(__gfx9__)
373  return bit_cast<CVecType>(
374  __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
375 #else
376  ck_tile::ignore = a_vec;
377  ck_tile::ignore = b_vec;
378  return CVecType{0.f};
379 #endif
380  }
381 };
382 
383 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
385 {
386  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
387  using ADataType = fp16_t;
388  using BDataType = fp16_t;
389  using CDataType = float;
390 
394 
395  static constexpr index_t kM = 16;
396  static constexpr index_t kN = 16;
397  static constexpr index_t kK = 32;
398 
399  static constexpr index_t kAMBlock = 1;
400  static constexpr index_t kBNBlock = 1;
401 
402  static constexpr index_t kAMLane = 16;
403  static constexpr index_t kBNLane = 16;
404  static constexpr index_t kABKLane = 4;
405  static constexpr index_t kABKPerLane = 8;
406 
407  static constexpr index_t kCMLane = 4;
408  static constexpr index_t kCNLane = 16;
409  static constexpr index_t kCM0PerLane = 1;
410  static constexpr index_t kCM1PerLane = 4;
411 
412  // c_vec += a_vec * b_vec
413  template <bool post_nop_ = false>
415  const AVecType& a_vec,
416  const BVecType& b_vec,
417  bool_constant<post_nop_> = {}) const
418  {
419  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32f16", Ctrl)
420  else
421  {
422 #if defined(__gfx950__)
423  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, c_vec, 0, 0, 0);
424 #else
425  ck_tile::ignore = c_vec;
426  ck_tile::ignore = a_vec;
427  ck_tile::ignore = b_vec;
428 #endif
429  }
430  }
431 
432  // c_vec = a_vec * b_vec
433  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
434  {
435 #if defined(__gfx950__)
436  return bit_cast<CVecType>(
437  __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
438 #else
439  ck_tile::ignore = a_vec;
440  ck_tile::ignore = b_vec;
441  return CVecType{0.f};
442 #endif
443  }
444 };
445 
446 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
448 {
449  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
450  using ADataType = fp16_t;
451  using BDataType = fp16_t;
452  using CDataType = float;
453 
457 
458  static constexpr index_t kM = 4;
459  static constexpr index_t kN = 64;
460  static constexpr index_t kK = 4;
461 
462  static constexpr index_t kAMBlock = 1;
463  static constexpr index_t kBNBlock = 16;
464 
465  // we only write down single block (4 threads) thread mapping here
466  static constexpr index_t kAMLane = 4;
467  static constexpr index_t kBNLane = 4;
468  static constexpr index_t kABKLane = 1;
469  static constexpr index_t kABKPerLane = 4;
470 
471  static constexpr index_t kCMLane = 1;
472  static constexpr index_t kCNLane = 4;
473  static constexpr index_t kCM0PerLane = 1;
474  static constexpr index_t kCM1PerLane = 4;
475 
476  // c_vec += a_vec * b_vec
477  template <bool post_nop_ = false>
479  const AVecType& a_vec,
480  const BVecType& b_vec,
481  bool_constant<post_nop_> = {}) const
482  {
483  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
484  else
485  {
486 #if defined(__gfx9__)
487  c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
488 #else
489  ignore = c_vec;
490  ignore = a_vec;
491  ignore = b_vec;
492 #endif
493  }
494  }
495 
496  // c_vec = a_vec * b_vec
497  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
498  {
499 #if defined(__gfx9__)
500  return bit_cast<CVecType>(
501  __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
502 #else
503  ignore = a_vec;
504  ignore = b_vec;
505  return CVecType{0.f};
506 #endif
507  }
508 };
509 
510 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
512 {
513  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
514  using ADataType = fp16_t;
515  using BDataType = fp16_t;
516  using CDataType = float;
517 
521 
522  static constexpr index_t kM = 64;
523  static constexpr index_t kN = 4;
524  static constexpr index_t kK = 4;
525 
526  static constexpr index_t kAMBlock = 16;
527  static constexpr index_t kBNBlock = 1;
528 
529  // we only write down single block (4 threads) thread mapping here
530  static constexpr index_t kAMLane = 4;
531  static constexpr index_t kBNLane = 4;
532  static constexpr index_t kABKLane = 1;
533  static constexpr index_t kABKPerLane = 4;
534 
535  static constexpr index_t kCMLane = 1;
536  static constexpr index_t kCNLane = 4;
537  static constexpr index_t kCM0PerLane = 1;
538  static constexpr index_t kCM1PerLane = 4;
539 
540  // c_vec += a_vec * b_vec
541  template <bool post_nop_ = false>
543  const AVecType& a_vec,
544  const BVecType& b_vec,
545  bool_constant<post_nop_> = {}) const
546  {
547  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
548  else
549  {
550 #if defined(__gfx9__)
551  c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
552 #else
553  ignore = c_vec;
554  ignore = a_vec;
555  ignore = b_vec;
556 #endif
557  }
558  }
559 
560  // c_vec = a_vec * b_vec
561  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
562  {
563 #if defined(__gfx9__)
564  return bit_cast<CVecType>(
565  __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
566 #else
567  ignore = a_vec;
568  ignore = b_vec;
569  return CVecType{0.f};
570 #endif
571  }
572 };
573 
574 // Bf16
575 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
577 {
578  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
579  using ADataType = bf16_t;
580  using BDataType = bf16_t;
581  using CDataType = float;
582 
586 
587  static constexpr index_t kM = 32;
588  static constexpr index_t kN = 32;
589  static constexpr index_t kK = 8;
590 
591  static constexpr index_t kAMBlock = 1;
592  static constexpr index_t kBNBlock = 1;
593 
594  static constexpr index_t kAMLane = 32;
595  static constexpr index_t kBNLane = 32;
596  static constexpr index_t kABKLane = 2;
597  static constexpr index_t kABKPerLane = 4;
598 
599  static constexpr index_t kCMLane = 2;
600  static constexpr index_t kCNLane = 32;
601  static constexpr index_t kCM0PerLane = 4;
602  static constexpr index_t kCM1PerLane = 4;
603 
604  // c_vec += a_vec * b_vec
605  template <bool post_nop_ = false>
607  const AVecType& a_vec,
608  const BVecType& b_vec,
609  bool_constant<post_nop_> = {}) const
    {
        // Raw_* modes are issued as inline asm by the macro's if-constexpr chain;
        // Default_ falls through to the per-target code in the `else` below.
        DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl)
        else
        {
#if defined(__gfx90a__) || defined(__gfx94__)
            // gfx90a / gfx94x: a single K=8 bf16 MFMA via the *_1k builtin.
            c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__)
            // gfx908: the 4 bf16 values per lane are viewed as two bf16x2 slices.
            // Each slice feeds one 32x32x4 bf16 MFMA, both accumulating into c_vec
            // (total K = 8).
            static_for<0, 2, 1>{}([&](auto k) {
                c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
                    reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    c_vec,
                    0,
                    0,
                    0);
            });
#else
            // Other targets: consume the operands, leave c_vec unchanged.
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
    // Same per-target selection as the accumulating overload, starting from zeros.
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
#if defined(__gfx90a__) || defined(__gfx94__)
        return bit_cast<CVecType>(
            __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
#elif defined(__gfx908__)
        CVecType c_vec{0.f};
        // Two K=4 steps over the bf16x2 slices, as in the accumulating overload.
        static_for<0, 2, 1>{}([&](auto k) {
            c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
                reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
                    .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
                    .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                c_vec,
                0,
                0,
                0);
        });
        return c_vec;
#else
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
    }
662 };
663 
664 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
666 {
667  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
668  using ADataType = bf16_t;
669  using BDataType = bf16_t;
670  using CDataType = float;
671 
675 
676  static constexpr index_t kM = 16;
677  static constexpr index_t kN = 16;
678  static constexpr index_t kK = 16;
679 
680  static constexpr index_t kAMBlock = 1;
681  static constexpr index_t kBNBlock = 1;
682 
683  static constexpr index_t kAMLane = 16;
684  static constexpr index_t kBNLane = 16;
685  static constexpr index_t kABKLane = 4;
686  static constexpr index_t kABKPerLane = 4;
687 
688  static constexpr index_t kCMLane = 4;
689  static constexpr index_t kCNLane = 16;
690  static constexpr index_t kCM0PerLane = 1;
691  static constexpr index_t kCM1PerLane = 4;
692 
693  // c_vec += a_vec * b_vec
694  template <bool post_nop_ = false>
696  const AVecType& a_vec,
697  const BVecType& b_vec,
698  bool_constant<post_nop_> = {}) const
699  {
700  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
701  {
702 #if defined(__gfx90a__) || defined(__gfx94__)
703  c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
704 #elif defined(__gfx908__)
705  static_for<0, 2, 1>{}([&](auto k) {
706  c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
707  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
708  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
709  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
710  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
711  c_vec,
712  0,
713  0,
714  0);
715  });
716 #else
717  ck_tile::ignore = c_vec;
718  ck_tile::ignore = a_vec;
719  ck_tile::ignore = b_vec;
720 #endif
721  }
722  }
723 
724  // c_vec = a_vec * b_vec
725  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
726  {
727 #if defined(__gfx90a__) || defined(__gfx94__)
728  return bit_cast<CVecType>(
729  __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
730 #elif defined(__gfx908__)
731  CVecType c_vec{0.f};
732  static_for<0, 2, 1>{}([&](auto k) {
733  c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
734  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
735  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
736  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
737  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
738  c_vec,
739  0,
740  0,
741  0);
742  });
743  return c_vec;
744 #else
745  ck_tile::ignore = a_vec;
746  ck_tile::ignore = b_vec;
747  return CVecType{0.f};
748 #endif
749  }
750 };
751 
752 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
754 {
755  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
756  using ADataType = bf16_t;
757  using BDataType = bf16_t;
758  using CDataType = float;
759 
763 
764  static constexpr index_t kM = 4;
765  static constexpr index_t kN = 64;
766  static constexpr index_t kK = 4;
767 
768  static constexpr index_t kAMBlock = 1;
769  static constexpr index_t kBNBlock = 16;
770 
771  // we only write down single block (4 threads) thread mapping here
772  static constexpr index_t kAMLane = 4;
773  static constexpr index_t kBNLane = 4;
774  static constexpr index_t kABKLane = 1;
775  static constexpr index_t kABKPerLane = 4;
776 
777  static constexpr index_t kCMLane = 1;
778  static constexpr index_t kCNLane = 4;
779  static constexpr index_t kCM0PerLane = 1;
780  static constexpr index_t kCM1PerLane = 4;
781 
782  // c_vec += a_vec * b_vec
783  template <bool post_nop_ = false>
785  const AVecType& a_vec,
786  const BVecType& b_vec,
787  bool_constant<post_nop_> = {}) const
    {
        // Raw_* modes are issued as inline asm by the macro's if-constexpr chain;
        // Default_ falls through to the per-target code in the `else` below.
        DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
        else
        {
#if defined(__gfx90a__) || defined(__gfx94__)
            // gfx90a / gfx94x: a single K=4 bf16 MFMA via the *_1k builtin.
            c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__)
            // gfx908: the 4 bf16 values per lane are viewed as two bf16x2 slices.
            // Each slice feeds one 4x4x2 bf16 MFMA, both accumulating into c_vec
            // (total K = 4).
            static_for<0, 2, 1>{}([&](auto k) {
                c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
                    reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    c_vec,
                    0,
                    0,
                    0);
            });
#else
            // Other targets: consume the operands, leave c_vec unchanged.
            ignore = c_vec;
            ignore = a_vec;
            ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
    // Same per-target selection as the accumulating overload, starting from zeros.
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
#if defined(__gfx90a__) || defined(__gfx94__)
        return bit_cast<CVecType>(
            __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#elif defined(__gfx908__)
        CVecType c_vec{0.f};
        static_for<0, 2, 1>{}([&](auto k) {
            c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
                reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
                    .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
                    .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                c_vec,
                0,
                0,
                0);
        });
        return c_vec;
#else
        ignore = a_vec;
        ignore = b_vec;
        return CVecType{0.f};
#endif
    }
840 };
841 
842 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
844 {
845  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
846  using ADataType = bf16_t;
847  using BDataType = bf16_t;
848  using CDataType = float;
849 
853 
854  static constexpr index_t kM = 64;
855  static constexpr index_t kN = 4;
856  static constexpr index_t kK = 4;
857 
858  static constexpr index_t kAMBlock = 16;
859  static constexpr index_t kBNBlock = 1;
860 
861  // we only write down single block (4 threads) thread mapping here
862  static constexpr index_t kAMLane = 4;
863  static constexpr index_t kBNLane = 4;
864  static constexpr index_t kABKLane = 1;
865  static constexpr index_t kABKPerLane = 4;
866 
867  static constexpr index_t kCMLane = 1;
868  static constexpr index_t kCNLane = 4;
869  static constexpr index_t kCM0PerLane = 1;
870  static constexpr index_t kCM1PerLane = 4;
871 
872  // c_vec += a_vec * b_vec
873  template <bool post_nop_ = false>
875  const AVecType& a_vec,
876  const BVecType& b_vec,
877  bool_constant<post_nop_> = {}) const
    {
        // Raw_* modes are issued as inline asm by the macro's if-constexpr chain;
        // Default_ falls through to the per-target code in the `else` below.
        DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
        else
        {
#if defined(__gfx90a__) || defined(__gfx94__)
            // gfx90a / gfx94x: a single K=4 bf16 MFMA via the *_1k builtin.
            c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__)
            // gfx908: the 4 bf16 values per lane are viewed as two bf16x2 slices.
            // Each slice feeds one 4x4x2 bf16 MFMA, both accumulating into c_vec
            // (total K = 4).
            static_for<0, 2, 1>{}([&](auto k) {
                c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
                    reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    c_vec,
                    0,
                    0,
                    0);
            });
#else
            // Other targets: consume the operands, leave c_vec unchanged.
            ignore = c_vec;
            ignore = a_vec;
            ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
    // Same per-target selection as the accumulating overload, starting from zeros.
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
#if defined(__gfx90a__) || defined(__gfx94__)
        return bit_cast<CVecType>(
            __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#elif defined(__gfx908__)
        CVecType c_vec{0.f};
        static_for<0, 2, 1>{}([&](auto k) {
            c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
                reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
                    .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
                    .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                c_vec,
                0,
                0,
                0);
        });
        return c_vec;
#else
        ignore = a_vec;
        ignore = b_vec;
        return CVecType{0.f};
#endif
    }
930 };
931 
932 // gfx950
933 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
935 {
936  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
937  using ADataType = fp16_t;
938  using BDataType = fp16_t;
939  using CDataType = float;
940 
944 
945  static constexpr index_t kM = 32;
946  static constexpr index_t kN = 32;
947  static constexpr index_t kK = 16;
948 
949  static constexpr index_t kAMBlock = 1;
950  static constexpr index_t kBNBlock = 1;
951 
952  static constexpr index_t kAMLane = 32;
953  static constexpr index_t kBNLane = 32;
954  static constexpr index_t kABKLane = 2;
955  static constexpr index_t kABKPerLane = 8;
956 
957  static constexpr index_t kCMLane = 2;
958  static constexpr index_t kCNLane = 32;
959  static constexpr index_t kCM0PerLane = 4;
960  static constexpr index_t kCM1PerLane = 4;
961 
962  // c_vec += a_vec * b_vec
963  template <bool post_nop_ = false>
965  const AVecType& a_vec,
966  const BVecType& b_vec,
967  bool_constant<post_nop_> = {}) const
968  {
969  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_f16", Ctrl)
970  else
971  {
972 #if defined(__gfx950__)
973  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, c_vec, 0, 0, 0);
974 #elif defined(__gfx90a__) || defined(__gfx94__)
975  static_for<0, 2, 1>{}([&](auto k) {
976  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(
977  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
978  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
979  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
980  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
981  c_vec,
982  0,
983  0,
984  0);
985  });
986 #elif defined(__gfx908__)
987  static_for<0, 4, 1>{}([&](auto k) {
988  c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16(
989  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
990  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
991  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
992  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
993  c_vec,
994  0,
995  0,
996  0);
997  });
998 #else
999  ck_tile::ignore = c_vec;
1000  ck_tile::ignore = a_vec;
1001  ck_tile::ignore = b_vec;
1002 #endif
1003  }
1004  }
1005 
1006  // c_vec = a_vec * b_vec
1007  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1008  {
1009 #if defined(__gfx950__)
1010  return __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0);
1011 #elif defined(__gfx90a__) || defined(__gfx94__)
1012  CVecType c_vec{0.f};
1013  static_for<0, 2, 1>{}([&](auto k) {
1014  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(
1015  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1016  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
1017  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1018  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
1019  c_vec,
1020  0,
1021  0,
1022  0);
1023  });
1024  return c_vec;
1025 #elif defined(__gfx908__)
1026  CVecType c_vec{0.f};
1027  static_for<0, 4, 1>{}([&](auto k) {
1028  c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16(
1029  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
1030  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
1031  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
1032  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
1033  c_vec,
1034  0,
1035  0,
1036  0);
1037  });
1038  return c_vec;
1039 #else
1040  ck_tile::ignore = a_vec;
1041  ck_tile::ignore = b_vec;
1042  return CVecType{0.f};
1043 #endif
1044  }
1045 };
1046 
1047 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1049 {
1050  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1053  using CDataType = float;
1054 
1058 
1059  static constexpr index_t kM = 32;
1060  static constexpr index_t kN = 32;
1061  static constexpr index_t kK = 16;
1062 
1063  static constexpr index_t kAMBlock = 1;
1064  static constexpr index_t kBNBlock = 1;
1065 
1066  static constexpr index_t kAMLane = 32;
1067  static constexpr index_t kBNLane = 32;
1068  static constexpr index_t kABKLane = 2;
1069  static constexpr index_t kABKPerLane = 8;
1070 
1071  static constexpr index_t kCMLane = 2;
1072  static constexpr index_t kCNLane = 32;
1073  static constexpr index_t kCM0PerLane = 4;
1074  static constexpr index_t kCM1PerLane = 4;
1075 
1076  // c_vec += a_vec * b_vec
1077  template <bool post_nop_ = false>
1079  const AVecType& a_vec,
1080  const BVecType& b_vec,
1081  bool_constant<post_nop_> = {}) const
1082  {
1083  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_bf16", Ctrl)
1084  else
1085  {
1086 #if defined(__gfx950__)
1087  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, c_vec, 0, 0, 0);
1088 #elif defined(__gfx90a__) || defined(__gfx94__)
1089  static_for<0, 2, 1>{}([&](auto k) {
1090  c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
1091  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1092  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1093  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1094  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1095  c_vec,
1096  0,
1097  0,
1098  0);
1099  });
1100 #elif defined(__gfx908__)
1101  static_for<0, 4, 1>{}([&](auto k) {
1102  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
1103  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
1104  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1105  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
1106  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1107  c_vec,
1108  0,
1109  0,
1110  0);
1111  });
1112 #else
1113  ck_tile::ignore = c_vec;
1114  ck_tile::ignore = a_vec;
1115  ck_tile::ignore = b_vec;
1116 #endif
1117  }
1118  }
1119 
1120  // c_vec = a_vec * b_vec
1121  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1122  {
1123 #if defined(__gfx950__)
1124  return __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0);
1125 #elif defined(__gfx90a__) || defined(__gfx94__)
1126  CVecType c_vec{0.f};
1127  static_for<0, 2, 1>{}([&](auto k) {
1128  c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
1129  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1130  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1131  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1132  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1133  c_vec,
1134  0,
1135  0,
1136  0);
1137  });
1138  return c_vec;
1139 #elif defined(__gfx908__)
1140  CVecType c_vec{0.f};
1141  static_for<0, 4, 1>{}([&](auto k) {
1142  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
1143  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
1144  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1145  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
1146  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1147  c_vec,
1148  0,
1149  0,
1150  0);
1151  });
1152  return c_vec;
1153 #else
1154  ck_tile::ignore = a_vec;
1155  ck_tile::ignore = b_vec;
1156  return CVecType{0.f};
1157 #endif
1158  }
1159 };
1160 
1161 // FP8
1162 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1164 {
1165  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1166  using ADataType = AType_;
1167  using BDataType = BType_;
1168  using CDataType = float;
1169 
1173 
1174  static constexpr index_t kM = 16;
1175  static constexpr index_t kN = 16;
1176  static constexpr index_t kK = 32;
1177 
1178  static constexpr index_t kAMBlock = 1;
1179  static constexpr index_t kBNBlock = 1;
1180 
1181  static constexpr index_t kAMLane = 16;
1182  static constexpr index_t kBNLane = 16;
1183  static constexpr index_t kABKLane = 4;
1184  static constexpr index_t kABKPerLane = 8;
1185 
1186  static constexpr index_t kCMLane = 4;
1187  static constexpr index_t kCNLane = 16;
1188  static constexpr index_t kCM0PerLane = 1;
1189  static constexpr index_t kCM1PerLane = 4;
1190 
1191  // c_vec += a_vec * b_vec
1192  template <bool post_nop_ = false>
1194  const AVecType& a_vec,
1195  const BVecType& b_vec,
1196  bool_constant<post_nop_> = {}) const
1197  {
1198  if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
1199  {
1200  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1201  {
1202  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "v", "v")
1203  }
1204  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1205  {
1206  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "v", "v")
1207  }
1208  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1209  {
1210  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "v", "v")
1211  }
1212  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1213  {
1214  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "v", "v")
1215  }
1216  }
1217  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
1218  {
1219  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1220  {
1221  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "a", "v")
1222  }
1223  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1224  {
1225  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "a", "v")
1226  }
1227  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1228  {
1229  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "a", "v")
1230  }
1231  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1232  {
1233  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "a", "v")
1234  }
1235  }
1236  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
1237  {
1238  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1239  {
1240  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "v", "v")
1241  }
1242  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1243  {
1244  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "v", "v")
1245  }
1246  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1247  {
1248  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "v", "v")
1249  }
1250  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1251  {
1252  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "v", "v")
1253  }
1254  }
1255  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
1256  {
1257  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1258  {
1259  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "a", "v")
1260  }
1261  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1262  {
1263  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "a", "v")
1264  }
1265  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1266  {
1267  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "a", "v")
1268  }
1269  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1270  {
1271  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "a", "v")
1272  }
1273  }
1274  else
1275  {
1276 #if defined(__gfx94__) or defined(__gfx95__)
1277  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1278  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1279  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1280  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1281  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1282  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1283  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1284  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1285  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1286  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1287  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1288  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1289 #else
1290  ck_tile::ignore = c_vec;
1291  ck_tile::ignore = a_vec;
1292  ck_tile::ignore = b_vec;
1293 #endif
1294  }
1295  }
1296 
1297  // c_vec = a_vec * b_vec
1298  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1299  {
1300 #if defined(__gfx94__) or defined(__gfx95__)
1301  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1302  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1303  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1304  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1305  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1306  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1307  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1308  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1309  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1310  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1311  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1312  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1313 #else
1314  ck_tile::ignore = a_vec;
1315  ck_tile::ignore = b_vec;
1316  return CVecType{0.f};
1317 #endif
1318  }
1319 };
1320 
1321 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1323 {
1324  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1325  using ADataType = AType_;
1326  using BDataType = BType_;
1327  using CDataType = float;
1328 
1332 
1333  static constexpr index_t kM = 32;
1334  static constexpr index_t kN = 32;
1335  static constexpr index_t kK = 16;
1336 
1337  static constexpr index_t kAMBlock = 1;
1338  static constexpr index_t kBNBlock = 1;
1339 
1340  static constexpr index_t kAMLane = 32;
1341  static constexpr index_t kBNLane = 32;
1342  static constexpr index_t kABKLane = 2;
1343  static constexpr index_t kABKPerLane = 8;
1344 
1345  static constexpr index_t kCMLane = 2;
1346  static constexpr index_t kCNLane = 32;
1347  static constexpr index_t kCM0PerLane = 4;
1348  static constexpr index_t kCM1PerLane = 4;
1349 
1350  // c_vec += a_vec * b_vec
1351  template <bool post_nop_ = false>
1353  const AVecType& a_vec,
1354  const BVecType& b_vec,
1355  bool_constant<post_nop_> = {}) const
1356  {
1357  if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
1358  {
1359  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1360  {
1361  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "v", "v")
1362  }
1363  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1364  {
1365  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "v", "v")
1366  }
1367  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1368  {
1369  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "v", "v")
1370  }
1371  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1372  {
1373  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "v", "v")
1374  }
1375  }
1376  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
1377  {
1378  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1379  {
1380  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "a", "v")
1381  }
1382  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1383  {
1384  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "a", "v")
1385  }
1386  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1387  {
1388  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "a", "v")
1389  }
1390  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1391  {
1392  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "a", "v")
1393  }
1394  }
1395  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
1396  {
1397  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1398  {
1399  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "v", "v")
1400  }
1401  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1402  {
1403  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "v", "v")
1404  }
1405  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1406  {
1407  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "v", "v")
1408  }
1409  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1410  {
1411  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "v", "v")
1412  }
1413  }
1414  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
1415  {
1416  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1417  {
1418  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "a", "v")
1419  }
1420  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1421  {
1422  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "a", "v")
1423  }
1424  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1425  {
1426  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "a", "v")
1427  }
1428  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1429  {
1430  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "a", "v")
1431  }
1432  }
1433  else
1434  {
1435 #if defined(__gfx94__) or defined(__gfx95__)
1436  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1437  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1438  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1439  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1440  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1441  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1442  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1443  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1444  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1445  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1446  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1447  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1448 #elif defined(__gfx908__) || defined(__gfx90a__)
1449  static_for<0, 8, 1>{}([&](auto k) {
1450  float a_f32 =
1451  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1452  .template get_as<ADataType>()[number<k>{}]);
1453  float b_f32 =
1454  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1455  .template get_as<BDataType>()[number<k>{}]);
1456 
1457  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1458  });
1459 #else
1460  ck_tile::ignore = c_vec;
1461  ck_tile::ignore = a_vec;
1462  ck_tile::ignore = b_vec;
1463 #endif
1464  }
1465  }
1466 
1467  // c_vec = a_vec * b_vec
1468  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1469  {
1470 #if defined(__gfx94__) or defined(__gfx95__)
1471  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1472  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1473  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1474  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1475  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1476  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1477  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1478  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1479  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1480  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1481  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1482  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1483 #elif defined(__gfx908__) || defined(__gfx90a__)
1484  CVecType c_vec{0.f};
1485  static_for<0, 8, 1>{}([&](auto k) {
1486  float a_f32 =
1487  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1488  .template get_as<ADataType>()[number<k>{}]);
1489  float b_f32 =
1490  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1491  .template get_as<BDataType>()[number<k>{}]);
1492 
1493  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1494  });
1495  return c_vec;
1496 #else
1497  ck_tile::ignore = a_vec;
1498  ck_tile::ignore = b_vec;
1499  return CVecType{0.f};
1500 #endif
1501  }
1502 };
1503 
1504 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1507 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1510 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1513 
1514 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1517 
1518 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1521 
1522 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1525 
1526 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1528 {
1529  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1530  using ADataType = AType_;
1531  using BDataType = BType_;
1532  using CDataType = float;
1533 
1537 
1538  static constexpr index_t kM = 16;
1539  static constexpr index_t kN = 16;
1540  static constexpr index_t kK = 128;
1541 
1542  static constexpr index_t kAMBlock = 1;
1543  static constexpr index_t kBNBlock = 1;
1544 
1545  static constexpr index_t kAMLane = 16;
1546  static constexpr index_t kBNLane = 16;
1547  static constexpr index_t kABKLane = 4;
1548  static constexpr index_t kABKPerLane = 32;
1549 
1550  static constexpr index_t kCMLane = 4;
1551  static constexpr index_t kCNLane = 16;
1552  static constexpr index_t kCM0PerLane = 1;
1553  static constexpr index_t kCM1PerLane = 4;
1554 
1555  // c_vec += a_vec * b_vec
1556  template <bool post_nop_ = false>
1558  const AVecType& a_vec,
1559  const BVecType& b_vec,
1560  bool_constant<post_nop_> = {}) const
1561  {
1562  //__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
1563  // opsel, scale_b)
1564 #if defined(__gfx950__)
1565  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1566  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1567  a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
1568  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1569  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1570  a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
1571  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1572  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1573  a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
1574  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1575  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1576  a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
1577 #else
1578  ck_tile::ignore = c_vec;
1579  ck_tile::ignore = a_vec;
1580  ck_tile::ignore = b_vec;
1581 #endif
1582  }
1583 
1584  // c_vec = a_vec * b_vec
1585  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1586  {
1587 #if defined(__gfx950__)
1588  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1589  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1590  a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
1591  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1592  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1593  a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
1594  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1595  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1596  a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
1597  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1598  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1599  a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
1600 #else
1601  ck_tile::ignore = a_vec;
1602  ck_tile::ignore = b_vec;
1603  return CVecType{0.f};
1604 #endif
1605  }
1606 };
1607 
1608 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1611 
1612 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1615 
1616 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1619 
1620 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1623 
1624 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1626 {
1627  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1628  using ADataType = AType_;
1629  using BDataType = BType_;
1630  using CDataType = float;
1631 
1635 
1636  static constexpr index_t kM = 32;
1637  static constexpr index_t kN = 32;
1638  static constexpr index_t kK = 64;
1639 
1640  static constexpr index_t kAMBlock = 1;
1641  static constexpr index_t kBNBlock = 1;
1642 
1643  static constexpr index_t kAMLane = 32;
1644  static constexpr index_t kBNLane = 32;
1645  static constexpr index_t kABKLane = 2;
1646  static constexpr index_t kABKPerLane = 32;
1647 
1648  static constexpr index_t kCMLane = 2;
1649  static constexpr index_t kCNLane = 32;
1650  static constexpr index_t kCM0PerLane = 4;
1651  static constexpr index_t kCM1PerLane = 4;
1652 
1653  // c_vec += a_vec * b_vec
1654  template <bool post_nop_ = false>
1656  const AVecType& a_vec,
1657  const BVecType& b_vec,
1658  bool_constant<post_nop_> = {}) const
1659  {
1660  //__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
1661  // opsel, scale_b)
1662 #if defined(__gfx950__)
1663  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1664  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1665  a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
1666  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1667  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1668  a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
1669  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1670  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1671  a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
1672  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1673  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1674  a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
1675 #else
1676  ck_tile::ignore = c_vec;
1677  ck_tile::ignore = a_vec;
1678  ck_tile::ignore = b_vec;
1679 #endif
1680  }
1681 
1682  // c_vec = a_vec * b_vec
1683  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1684  {
1685 #if defined(__gfx950__)
1686  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1687  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1688  a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
1689  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1690  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1691  a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
1692  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1693  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1694  a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
1695  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1696  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1697  a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
1698 #else
1699  ck_tile::ignore = a_vec;
1700  ck_tile::ignore = b_vec;
1701  return CVecType{0.f};
1702 #endif
1703  }
1704 };
1705 
1706 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1709 
1710 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1713 
1714 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1717 
1718 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1721 
1722 // int8
1723 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1725 {
1726  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1730 
1734 
1735  static constexpr index_t kM = 32;
1736  static constexpr index_t kN = 32;
1737  static constexpr index_t kK = 16;
1738 
1739  static constexpr index_t kAMBlock = 1;
1740  static constexpr index_t kBNBlock = 1;
1741 
1742  static constexpr index_t kAMLane = 32;
1743  static constexpr index_t kBNLane = 32;
1744  static constexpr index_t kABKLane = 2;
1745  static constexpr index_t kABKPerLane = 8;
1746 
1747  static constexpr index_t kCMLane = 2;
1748  static constexpr index_t kCNLane = 32;
1749  static constexpr index_t kCM0PerLane = 4;
1750  static constexpr index_t kCM1PerLane = 4;
1751 
1752  // c_vec += a_vec * b_vec
1753  template <bool post_nop_ = false>
1755  const AVecType& a_vec,
1756  const BVecType& b_vec,
1757  bool_constant<post_nop_> = {}) const
1758  {
1759  DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
1760  else
1761  {
1762 #if defined(__gfx94__) or defined(__gfx95__)
1763  c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
1764  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1765 #elif defined(__gfx908__) || defined(__gfx90a__)
1766  static_for<0, 8, 1>{}([&](auto k) {
1767  float a_f32 =
1768  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1769  .template get_as<ADataType>()[number<k>{}]);
1770  float b_f32 =
1771  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1772  .template get_as<BDataType>()[number<k>{}]);
1773 
1774  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1775  });
1776 #else
1777  ck_tile::ignore = c_vec;
1778  ck_tile::ignore = a_vec;
1779  ck_tile::ignore = b_vec;
1780 #endif
1781  }
1782  }
1783 
1784  // c_vec = a_vec * b_vec
1785  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1786  {
1787  CVecType c_vec{0};
1788  operator()(c_vec, a_vec, b_vec);
1789  return c_vec;
1790  }
1791 };
1792 
1793 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1795 {
1796  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1800 
1804 
1805  static constexpr index_t kM = 16;
1806  static constexpr index_t kN = 16;
1807  static constexpr index_t kK = 32;
1808 
1809  static constexpr index_t kAMBlock = 1;
1810  static constexpr index_t kBNBlock = 1;
1811 
1812  static constexpr index_t kAMLane = 16;
1813  static constexpr index_t kBNLane = 16;
1814  static constexpr index_t kABKLane = 4;
1815  static constexpr index_t kABKPerLane = 8;
1816 
1817  static constexpr index_t kCMLane = 4;
1818  static constexpr index_t kCNLane = 16;
1819  static constexpr index_t kCM0PerLane = 1;
1820  static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
1821 
1822  // c_vec += a_vec * b_vec
1823  template <bool post_nop_ = false>
1825  const AVecType& a_vec,
1826  const BVecType& b_vec,
1827  bool_constant<post_nop_> = {}) const
1828  {
1829  DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x32_i8", Ctrl)
1830  else
1831  {
1832 #if defined(__gfx94__) or defined(__gfx95__)
1833  c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
1834  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1835 #else
1836  ck_tile::ignore = c_vec;
1837  ck_tile::ignore = a_vec;
1838  ck_tile::ignore = b_vec;
1839 #endif
1840  }
1841  }
1842 
1843  // c_vec = a_vec * b_vec
1844  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1845  {
1846  CVecType c_vec{0};
1847  operator()(c_vec, a_vec, b_vec);
1848  return c_vec;
1849  }
1850 };
1851 
1852 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1854 {
1855  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1859 
1863 
1864  static constexpr index_t kM = 16;
1865  static constexpr index_t kN = 16;
1866  static constexpr index_t kK = 64;
1867 
1868  static constexpr index_t kAMBlock = 1;
1869  static constexpr index_t kBNBlock = 1;
1870 
1871  static constexpr index_t kAMLane = 16;
1872  static constexpr index_t kBNLane = 16;
1873  static constexpr index_t kABKLane = 4;
1874  static constexpr index_t kABKPerLane = 16;
1875 
1876  static constexpr index_t kCMLane = 4;
1877  static constexpr index_t kCNLane = 16;
1878  static constexpr index_t kCM0PerLane = 1;
1879  static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
1880 
1881  // c_vec += a_vec * b_vec
1882  template <bool post_nop_ = false>
1884  const AVecType& a_vec,
1885  const BVecType& b_vec,
1886  bool_constant<post_nop_> = {}) const
1887  {
1888  DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x64_i8", Ctrl)
1889  else
1890  {
1891 #if defined(__gfx95__)
1892  c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
1893  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1894 #else
1895  ck_tile::ignore = c_vec;
1896  ck_tile::ignore = a_vec;
1897  ck_tile::ignore = b_vec;
1898 #endif
1899  }
1900  }
1901 
1902  // c_vec = a_vec * b_vec
1903  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1904  {
1905  CVecType c_vec{0};
1906  operator()(c_vec, a_vec, b_vec);
1907  return c_vec;
1908  }
1909 };
1910 
1911 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1913 {
1914  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1918 
1922 
1923  static constexpr index_t kM = 32;
1924  static constexpr index_t kN = 32;
1925  static constexpr index_t kK = 32;
1926 
1927  static constexpr index_t kAMBlock = 1;
1928  static constexpr index_t kBNBlock = 1;
1929 
1930  static constexpr index_t kAMLane = 32;
1931  static constexpr index_t kBNLane = 32;
1932  static constexpr index_t kABKLane = 2;
1933  static constexpr index_t kABKPerLane = 16;
1934 
1935  static constexpr index_t kCMLane = 2;
1936  static constexpr index_t kCNLane = 32;
1937  static constexpr index_t kCM0PerLane = 4;
1938  static constexpr index_t kCM1PerLane = 4;
1939 
1940  // c_vec += a_vec * b_vec
1941  template <bool post_nop_ = false>
1943  const AVecType& a_vec,
1944  const BVecType& b_vec,
1945  bool_constant<post_nop_> = {}) const
1946  {
1947  DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x32_i8", Ctrl)
1948  else
1949  {
1950 #if defined(__gfx95__)
1951  c_vec = __builtin_amdgcn_mfma_i32_32x32x32_i8(
1952  a_vec, bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1953 #else
1954  ck_tile::ignore = c_vec;
1955  ck_tile::ignore = a_vec;
1956  ck_tile::ignore = b_vec;
1957 #endif
1958  }
1959  }
1960 
1961  // c_vec = a_vec * b_vec
1962  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1963  {
1964  CVecType c_vec{0};
1965  operator()(c_vec, a_vec, b_vec);
1966  return c_vec;
1967  }
1968 };
1969 
// Both dispatch helper macros are private to this header; undefine them so they do not
// leak into translation units that include it.
#undef DISPATCH_MFMA_CTRL_
#undef DISPATCH_MFMA_
1971 
1972 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
WGAttrCtlEnum
Definition: warp_gemm_attribute_mfma_impl.hpp:15
_Float16 fp16_t
Definition: half.hpp:110
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:83
int32_t int32_t
Definition: integer.hpp:10
float fp32x16_t
Definition: vector_type.hpp:119
float fp32x4_t
Definition: vector_type.hpp:117
Definition: warp_gemm_attribute_mfma_impl.hpp:1528
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1538
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1552
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1543
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1529
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1536
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1547
ext_vector_t< ADataType, 32 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1534
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1550
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1532
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1545
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1531
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1553
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1540
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1548
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1542
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1551
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1546
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1539
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1557
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1585
ext_vector_t< BDataType, 32 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1535
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1530
Definition: warp_gemm_attribute_mfma_impl.hpp:1164
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1178
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1179
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1188
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1174
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1166
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1172
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1170
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1165
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1186
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1183
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1193
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1189
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1168
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1298
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1182
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1167
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1187
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1184
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1176
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1181
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1175
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1171
Definition: warp_gemm_attribute_mfma_impl.hpp:1323
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1345
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1324
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1329
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1326
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1348
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1331
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1343
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1327
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1337
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1346
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1352
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1468
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1340
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1333
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1325
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1330
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1335
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1334
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1347
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1338
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1342
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1341
Definition: warp_gemm_attribute_mfma_impl.hpp:1626
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1683
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1637
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1648
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1629
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1646
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1636
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1644
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1643
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1634
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1630
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1655
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1650
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1641
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1640
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1649
ext_vector_t< BDataType, 32 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1633
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1627
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1628
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1645
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1651
ext_vector_t< ADataType, 32 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1632
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1638
Definition: warp_gemm_attribute_mfma_impl.hpp:1795
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1844
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1803
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1802
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1801
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1812
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1824
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1796
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1806
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1810
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1797
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1813
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1815
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1814
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1809
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1817
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1819
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1805
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1820
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1818
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1799
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1807
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1798
Definition: warp_gemm_attribute_mfma_impl.hpp:1854
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1858
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1879
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1866
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1865
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1872
ext_vector_t< ADataType, 16 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1860
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1874
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1878
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1864
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1871
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1883
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1877
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1869
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1873
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1855
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1857
ext_vector_t< BDataType, 16 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1861
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1903
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1856
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1876
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1862
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1868
Definition: warp_gemm_attribute_mfma_impl.hpp:1725
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1727
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1754
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1747
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1731
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1732
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1749
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1729
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1726
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1785
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1728
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1742
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1736
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1740
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1737
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1750
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1733
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1748
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1745
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1739
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1743
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1744
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1735
Definition: warp_gemm_attribute_mfma_impl.hpp:1913
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1923
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1915
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1937
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1932
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1942
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1916
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1936
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1917
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1931
ext_vector_t< BDataType, 16 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1920
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1928
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1930
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1935
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1925
ext_vector_t< ADataType, 16 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1919
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1921
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1962
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1924
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1938
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1933
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1914
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1927
Definition: warp_gemm_attribute_mfma_impl.hpp:666
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:674
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:684
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:673
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:695
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:688
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:725
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:678
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:683
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:667
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:689
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:690
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:685
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:680
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:670
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:691
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:677
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:669
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:681
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:668
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:672
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:686
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:676
Definition: warp_gemm_attribute_mfma_impl.hpp:196
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:219
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:199
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:204
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:207
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:210
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:220
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:215
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:206
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:208
ext_vector_t< bf16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:202
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:197
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:216
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:200
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:214
ext_vector_t< bf16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:203
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:218
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:198
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:225
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:221
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:244
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:211
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:213
Definition: warp_gemm_attribute_mfma_impl.hpp:1049
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1072
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1078
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1053
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1050
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1052
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1063
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1064
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1067
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1074
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1121
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1068
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1073
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1066
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1059
ext_vector_t< bf16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1055
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1060
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1071
ext_vector_t< bf16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1056
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1051
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1069
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1057
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1061
Definition: warp_gemm_attribute_mfma_impl.hpp:577
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:584
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:580
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:581
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:592
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:585
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:583
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:637
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:588
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:601
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:587
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:600
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:591
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:596
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:595
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:594
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:589
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:599
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:579
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:578
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:606
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:602
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:597
Definition: warp_gemm_attribute_mfma_impl.hpp:754
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:775
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:756
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:778
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:777
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:779
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:815
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:757
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:766
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:755
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:773
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:784
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:760
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:765
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:764
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:774
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:769
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:761
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:780
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:762
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:772
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:768
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:758
Definition: warp_gemm_attribute_mfma_impl.hpp:844
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:847
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:867
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:862
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:905
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:845
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:868
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:855
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:846
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:858
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:869
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:859
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:863
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:856
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:848
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:864
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:870
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:865
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:850
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:874
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:851
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:854
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:852
Definition: warp_gemm_attribute_mfma_impl.hpp:322
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:323
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:339
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:326
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:345
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:336
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:325
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:344
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:347
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:328
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:337
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:333
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:324
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:341
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:340
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:330
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:346
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:342
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:370
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:351
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:332
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:334
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:329
Definition: warp_gemm_attribute_mfma_impl.hpp:385
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:396
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:408
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:410
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:387
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:407
ext_vector_t< fp16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:391
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:404
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:405
ext_vector_t< fp16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:392
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:395
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:393
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:433
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:388
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:389
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:402
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:409
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:400
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:399
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:386
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:403
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:397
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:414
Definition: warp_gemm_attribute_mfma_impl.hpp:935
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:936
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:937
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:949
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:957
ext_vector_t< fp16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:941
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:958
ext_vector_t< fp16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:942
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:945
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:939
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:964
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:946
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:955
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1007
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:954
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:953
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:960
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:950
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:947
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:952
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:959
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:943
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:938
Definition: warp_gemm_attribute_mfma_impl.hpp:259
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:282
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:276
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:283
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:281
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:288
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:284
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:263
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:271
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:262
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:277
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:267
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:274
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:260
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:269
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:266
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:270
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:307
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:273
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:265
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:261
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:278
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:279
Definition: warp_gemm_attribute_mfma_impl.hpp:448
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:459
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:471
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:450
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:454
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:497
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:455
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:463
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:452
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:478
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:456
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:466
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:467
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:472
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:451
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:460
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:468
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:458
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:473
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:449
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:469
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:474
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:462
Definition: warp_gemm_attribute_mfma_impl.hpp:512
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:533
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:516
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:522
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:530
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:515
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:513
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:527
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:542
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:524
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:526
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:520
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:523
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:561
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:535
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:518
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:531
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:536
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:537
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:532
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:514
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:519
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:538
Definition: warp_gemm_attribute_mfma_impl.hpp:67
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:68
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:87
ext_vector_t< ADataType, 1 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:74
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:80
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:78
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:97
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:90
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:93
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:85
float ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:70
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:76
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:83
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:82
float BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:71
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:92
ext_vector_t< BDataType, 1 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:75
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:88
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:72
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:116
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:91
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:79
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:86
Definition: warp_gemm_attribute_mfma_impl.hpp:131
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:156
float BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:135
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:142
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:157
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:155
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:132
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:161
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:136
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:149
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:140
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:154
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:150
ext_vector_t< ADataType, 1 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:138
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:180
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:152
ext_vector_t< BDataType, 1 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:139
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:143
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:146
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:151
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:147
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:144
float ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:134
Definition: integral_constant.hpp:13
Definition: functional.hpp:43
Definition: debug.hpp:67
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_)
Definition: warp_gemm_attribute_mfma_impl.hpp:25
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_)
Definition: warp_gemm_attribute_mfma_impl.hpp:42