/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File
warp_gemm_attribute_mfma_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 // TODO: refactor warp-gemm
11 // currently there is a discrepancy for vav/vva if we need to transpose C/D
12 // e.g. if we want A:agpr, B:vgpr, we have to use vva in WGAttrCtlEnum
13 // because we swap the A/B pointer in _impl code (but that information is not known here)
// Selects how the warp-gemm MFMA wrappers in this file issue their instruction.
// Default_ matches none of the branches generated by DISPATCH_MFMA_CTRL_, so the
// wrappers take their own fallback (compiler builtin) path. Every Raw_* value
// makes DISPATCH_MFMA_CTRL_ emit inline asm whose operand constraints follow the
// per-enumerator note below (c = accumulator/destination, a / b = inputs;
// "vgpr" maps to the "v" asm constraint and "agpr" to the "a" constraint).
14 enum class WGAttrCtlEnum
15 {
16  Default_ = 0,
17  Raw_vvv = 1, // c-vgpr, a-vgpr, b-vgpr
18  Raw_vaa = 2, // c-vgpr, a-agpr, b-agpr
19  Raw_vav = 3, // c-vgpr, a-agpr, b-vgpr
20  Raw_vva = 4, // c-vgpr, a-vgpr, b-agpr
21  Raw_avv = 5, // c-agpr, a-vgpr, b-vgpr
    // The all-agpr variant below is disabled; note that its value 3 is already
    // used by Raw_vav, so it would need a new value if it is ever enabled.
22  // raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr
23 };
24 
// Emits a single MFMA instruction as inline asm.
// The expansion site must have c_vec, a_vec, b_vec and the compile-time flag
// post_nop_ in scope.
//   mfma_ : instruction mnemonic as a string literal (pasted in front of the operand list)
//   dmod_ : asm constraint string for the destination operand %0 (c_vec)
//   amod_ : asm constraint string for input %1 (a_vec)
//   bmod_ : asm constraint string for input %2 (b_vec)
//   cmod_ : asm constraint string for input %3 (c_vec again, as the accumulator source)
// With post_nop_ == true the asm text carries a trailing "; yyy" (presumably an
// assembler comment used as a marker -- confirm) and is followed by "s_nop 3";
// otherwise only the MFMA instruction is emitted.
// The macro expands to a complete if/else statement with no trailing semicolon.
// Keep comments out of the macro body: every line is joined by a backslash.
25 #define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
26  if constexpr(post_nop_) \
27  { \
28  asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n" \
29  "s_nop 3" \
30  : dmod_(c_vec) \
31  : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
32  :); \
33  } \
34  else \
35  { \
36  asm volatile(mfma_ " %0, %1, %2, %3\n" \
37  : dmod_(c_vec) \
38  : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
39  :); \
40  }
41 
// Maps a WGAttrCtlEnum value to the matching DISPATCH_MFMA_ constraint set:
//   Raw_vvv -> d:"+v" a:"v" b:"v" c:"v"
//   Raw_vaa -> d:"+v" a:"a" b:"a" c:"v"
//   Raw_vav -> d:"+v" a:"a" b:"v" c:"v"
//   Raw_vva -> d:"+v" a:"v" b:"a" c:"v"
//   Raw_avv -> d:"+a" a:"v" b:"v" c:"a"
// The expansion is an `if constexpr` / `else if constexpr` chain that ends
// WITHOUT a final else. Any other ctrl_ value (e.g. Default_) therefore emits
// nothing, and a caller that wants an exclusive fallback must write `else`
// immediately after the macro invocation; a bare `{ ... }` after the macro would
// run in addition to the selected asm branch.
42 #define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \
43  if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \
44  { \
45  DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \
46  } \
47  else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \
48  { \
49  DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \
50  } \
51  else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \
52  { \
53  DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \
54  } \
55  else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \
56  { \
57  DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \
58  } \
59  else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \
60  { \
61  DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \
62  }
63 
64 // F32
65 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
67 {
68  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
69 
70  using ADataType = float;
71  using BDataType = float;
72  using CDataType = float;
73 
77 
78  static constexpr index_t kM = 16;
79  static constexpr index_t kN = 16;
80  static constexpr index_t kK = 4;
81 
82  static constexpr index_t kAMBlock = 1;
83  static constexpr index_t kBNBlock = 1;
84 
85  static constexpr index_t kAMLane = 16;
86  static constexpr index_t kBNLane = 16;
87  static constexpr index_t kABKLane = 4;
88  static constexpr index_t kABKPerLane = 1;
89 
90  static constexpr index_t kCMLane = 4;
91  static constexpr index_t kCNLane = 16;
92  static constexpr index_t kCM0PerLane = 1;
93  static constexpr index_t kCM1PerLane = 4;
94 
95  // c_vec += a_vec * b_vec
96  template <bool post_nop_ = false>
98  const AVecType& a_vec,
99  const BVecType& b_vec,
100  bool_constant<post_nop_> = {}) const
101  {
102  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x4f32", Ctrl)
103  else
104  {
105 #if defined(__gfx9__)
106  c_vec = __builtin_amdgcn_mfma_f32_16x16x4f32(a_vec[0], b_vec[0], c_vec, 0, 0, 0);
107 #else
108  ck_tile::ignore = c_vec;
109  ck_tile::ignore = a_vec;
110  ck_tile::ignore = b_vec;
111 #endif
112  }
113  }
114 
115  // c_vec = a_vec * b_vec
     // Non-accumulating overload: the accumulator operand is CVecType{0.f} and the
     // builtin's result is returned through bit_cast<CVecType>.
     // Only element 0 of a_vec and b_vec is handed to the builtin.
     // The three trailing 0 arguments are the builtin's immediate modifier operands
     // (presumably cbsz/abid/blgp -- confirm against the LLVM AMDGPU builtin docs).
     // When __gfx9__ is not defined the inputs are only consumed via
     // ck_tile::ignore and CVecType{0.f} is returned.
116  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
117  {
118 #if defined(__gfx9__)
119  return bit_cast<CVecType>(
120  __builtin_amdgcn_mfma_f32_16x16x4f32(a_vec[0], b_vec[0], CVecType{0.f}, 0, 0, 0));
121 #else
122  ck_tile::ignore = a_vec;
123  ck_tile::ignore = b_vec;
124  return CVecType{0.f};
125 #endif
126  }
127 };
128 
129 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
131 {
132  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
133 
134  using ADataType = float;
135  using BDataType = float;
136  using CDataType = float;
137 
141 
142  static constexpr index_t kM = 32;
143  static constexpr index_t kN = 32;
144  static constexpr index_t kK = 2;
145 
146  static constexpr index_t kAMBlock = 1;
147  static constexpr index_t kBNBlock = 1;
148 
149  static constexpr index_t kAMLane = 32;
150  static constexpr index_t kBNLane = 32;
151  static constexpr index_t kABKLane = 2;
152  static constexpr index_t kABKPerLane = 1;
153 
154  static constexpr index_t kCMLane = 2;
155  static constexpr index_t kCNLane = 32;
156  static constexpr index_t kCM0PerLane = 4;
157  static constexpr index_t kCM1PerLane = 4;
158 
159  // c_vec += a_vec * b_vec
160  template <bool post_nop_ = false>
162  const AVecType& a_vec,
163  const BVecType& b_vec,
164  bool_constant<post_nop_> = {}) const
165  {
166  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x2f32", Ctrl)
167  else
168  {
169 #if defined(__gfx9__)
170  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_vec[0], b_vec[0], c_vec, 0, 0, 0);
171 #else
172  ck_tile::ignore = c_vec;
173  ck_tile::ignore = a_vec;
174  ck_tile::ignore = b_vec;
175 #endif
176  }
177  }
178 
179  // c_vec = a_vec * b_vec
     // Non-accumulating overload: calls the 32x32x2 f32 builtin with CVecType{0.f}
     // as the accumulator operand and returns the result through bit_cast<CVecType>.
     // Only element 0 of a_vec and b_vec is handed to the builtin.
     // The three trailing 0 arguments are the builtin's immediate modifier operands
     // (presumably cbsz/abid/blgp -- confirm against the LLVM AMDGPU builtin docs).
     // When __gfx9__ is not defined the inputs are only consumed via
     // ck_tile::ignore and CVecType{0.f} is returned.
180  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
181  {
182 #if defined(__gfx9__)
183  return bit_cast<CVecType>(
184  __builtin_amdgcn_mfma_f32_32x32x2f32(a_vec[0], b_vec[0], CVecType{0.f}, 0, 0, 0));
185 #else
186  ck_tile::ignore = a_vec;
187  ck_tile::ignore = b_vec;
188  return CVecType{0.f};
189 #endif
190  }
191 };
192 
193 // V_MFMA_F32_16x16x32_BF16
194 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
196 {
197  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
198  using ADataType = bf16_t;
199  using BDataType = bf16_t;
200  using CDataType = float;
201 
205 
206  static constexpr index_t kM = 16;
207  static constexpr index_t kN = 16;
208  static constexpr index_t kK = 32;
209 
210  static constexpr index_t kAMBlock = 1;
211  static constexpr index_t kBNBlock = 1;
212 
213  static constexpr index_t kAMLane = 16;
214  static constexpr index_t kBNLane = 16;
215  static constexpr index_t kABKLane = 4;
216  static constexpr index_t kABKPerLane = 8;
217 
218  static constexpr index_t kCMLane = 4;
219  static constexpr index_t kCNLane = 16;
220  static constexpr index_t kCM0PerLane = 1;
221  static constexpr index_t kCM1PerLane = 4;
222 
223  // c_vec += a_vec * b_vec
224  template <bool post_nop_ = false>
226  const AVecType& a_vec,
227  const BVecType& b_vec,
228  bool_constant<post_nop_> = {}) const
229  {
230  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32_bf16", Ctrl)
231  else
232  {
233 #if defined(__gfx950__)
234  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, c_vec, 0, 0, 0);
235 #else
236  ck_tile::ignore = c_vec;
237  ck_tile::ignore = a_vec;
238  ck_tile::ignore = b_vec;
239 #endif
240  }
241  }
242 
243  // c_vec = a_vec * b_vec
     // Non-accumulating overload: calls the 16x16x32 bf16 builtin with fp32x4_t{0.f}
     // as the accumulator operand and returns the result through bit_cast<CVecType>.
     // a_vec and b_vec are passed to the builtin whole.
     // The builtin is used only when __gfx950__ is defined; on every other target
     // the inputs are consumed via ck_tile::ignore and CVecType{0.f} is returned.
     // The three trailing 0 arguments are the builtin's immediate modifier operands
     // (presumably cbsz/abid/blgp -- confirm against the LLVM AMDGPU builtin docs).
244  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
245  {
246 #if defined(__gfx950__)
247  return bit_cast<CVecType>(
248  __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
249 #else
250  ck_tile::ignore = a_vec;
251  ck_tile::ignore = b_vec;
252  return CVecType{0.f};
253 #endif
254  }
255 };
256 // FP16
257 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
259 {
260  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
261  using ADataType = fp16_t;
262  using BDataType = fp16_t;
263  using CDataType = float;
264 
268 
269  static constexpr index_t kM = 32;
270  static constexpr index_t kN = 32;
271  static constexpr index_t kK = 8;
272 
273  static constexpr index_t kAMBlock = 1;
274  static constexpr index_t kBNBlock = 1;
275 
276  static constexpr index_t kAMLane = 32;
277  static constexpr index_t kBNLane = 32;
278  static constexpr index_t kABKLane = 2;
279  static constexpr index_t kABKPerLane = 4;
280 
281  static constexpr index_t kCMLane = 2;
282  static constexpr index_t kCNLane = 32;
283  static constexpr index_t kCM0PerLane = 4;
284  static constexpr index_t kCM1PerLane = 4;
285 
286  // c_vec += a_vec * b_vec
287  template <bool post_nop_ = false>
289  const AVecType& a_vec,
290  const BVecType& b_vec,
291  bool_constant<post_nop_> = {}) const
292  {
293  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl)
294  else
295  {
296 #if defined(__gfx9__)
297  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
298 #else
299  ck_tile::ignore = c_vec;
300  ck_tile::ignore = a_vec;
301  ck_tile::ignore = b_vec;
302 #endif
303  }
304  }
305 
306  // c_vec = a_vec * b_vec
     // Non-accumulating overload: calls the 32x32x8 f16 builtin with fp32x16_t{0.f}
     // as the accumulator operand and returns the result through bit_cast<CVecType>.
     // a_vec and b_vec are passed to the builtin whole.
     // The three trailing 0 arguments are the builtin's immediate modifier operands
     // (presumably cbsz/abid/blgp -- confirm against the LLVM AMDGPU builtin docs).
     // When __gfx9__ is not defined the inputs are only consumed via
     // ck_tile::ignore and CVecType{0.f} is returned.
307  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
308  {
309 #if defined(__gfx9__)
310  return bit_cast<CVecType>(
311  __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
312 #else
313  ck_tile::ignore = a_vec;
314  ck_tile::ignore = b_vec;
315  return CVecType{0.f};
316 #endif
317  }
318 };
319 
320 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
322 {
323  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
324  using ADataType = fp16_t;
325  using BDataType = fp16_t;
326  using CDataType = float;
327 
331 
332  static constexpr index_t kM = 16;
333  static constexpr index_t kN = 16;
334  static constexpr index_t kK = 16;
335 
336  static constexpr index_t kAMBlock = 1;
337  static constexpr index_t kBNBlock = 1;
338 
339  static constexpr index_t kAMLane = 16;
340  static constexpr index_t kBNLane = 16;
341  static constexpr index_t kABKLane = 4;
342  static constexpr index_t kABKPerLane = 4;
343 
344  static constexpr index_t kCMLane = 4;
345  static constexpr index_t kCNLane = 16;
346  static constexpr index_t kCM0PerLane = 1;
347  static constexpr index_t kCM1PerLane = 4;
348 
349  // c_vec += a_vec * b_vec
350  template <bool post_nop_ = false>
352  const AVecType& a_vec,
353  const BVecType& b_vec,
354  bool_constant<post_nop_> = {}) const
355  {
356  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
357  else
358  {
359 #if defined(__gfx9__)
360  c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
361 #else
362  ck_tile::ignore = c_vec;
363  ck_tile::ignore = a_vec;
364  ck_tile::ignore = b_vec;
365 #endif
366  }
367  }
368 
369  // c_vec = a_vec * b_vec
     // Non-accumulating overload: calls the 16x16x16 f16 builtin with fp32x4_t{0.f}
     // as the accumulator operand and returns the result through bit_cast<CVecType>.
     // a_vec and b_vec are passed to the builtin whole.
     // The three trailing 0 arguments are the builtin's immediate modifier operands
     // (presumably cbsz/abid/blgp -- confirm against the LLVM AMDGPU builtin docs).
     // When __gfx9__ is not defined the inputs are only consumed via
     // ck_tile::ignore and CVecType{0.f} is returned.
370  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
371  {
372 #if defined(__gfx9__)
373  return bit_cast<CVecType>(
374  __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
375 #else
376  ck_tile::ignore = a_vec;
377  ck_tile::ignore = b_vec;
378  return CVecType{0.f};
379 #endif
380  }
381 };
382 
383 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
385 {
386  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
387  using ADataType = fp16_t;
388  using BDataType = fp16_t;
389  using CDataType = float;
390 
394 
395  static constexpr index_t kM = 16;
396  static constexpr index_t kN = 16;
397  static constexpr index_t kK = 32;
398 
399  static constexpr index_t kAMBlock = 1;
400  static constexpr index_t kBNBlock = 1;
401 
402  static constexpr index_t kAMLane = 16;
403  static constexpr index_t kBNLane = 16;
404  static constexpr index_t kABKLane = 4;
405  static constexpr index_t kABKPerLane = 8;
406 
407  static constexpr index_t kCMLane = 4;
408  static constexpr index_t kCNLane = 16;
409  static constexpr index_t kCM0PerLane = 1;
410  static constexpr index_t kCM1PerLane = 4;
411 
412  // c_vec += a_vec * b_vec
413  template <bool post_nop_ = false>
415  const AVecType& a_vec,
416  const BVecType& b_vec,
417  bool_constant<post_nop_> = {}) const
418  {
419  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32f16", Ctrl)
420  else
421  {
422 #if defined(__gfx950__)
423  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, c_vec, 0, 0, 0);
424 #else
425  ck_tile::ignore = c_vec;
426  ck_tile::ignore = a_vec;
427  ck_tile::ignore = b_vec;
428 #endif
429  }
430  }
431 
432  // c_vec = a_vec * b_vec
     // Non-accumulating overload: calls the 16x16x32 f16 builtin with fp32x4_t{0.f}
     // as the accumulator operand and returns the result through bit_cast<CVecType>.
     // a_vec and b_vec are passed to the builtin whole.
     // The builtin is used only when __gfx950__ is defined; on every other target
     // the inputs are consumed via ck_tile::ignore and CVecType{0.f} is returned.
     // The three trailing 0 arguments are the builtin's immediate modifier operands
     // (presumably cbsz/abid/blgp -- confirm against the LLVM AMDGPU builtin docs).
433  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
434  {
435 #if defined(__gfx950__)
436  return bit_cast<CVecType>(
437  __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
438 #else
439  ck_tile::ignore = a_vec;
440  ck_tile::ignore = b_vec;
441  return CVecType{0.f};
442 #endif
443  }
444 };
445 
446 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
448 {
449  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
450  using ADataType = fp16_t;
451  using BDataType = fp16_t;
452  using CDataType = float;
453 
457 
458  static constexpr index_t kM = 4;
459  static constexpr index_t kN = 64;
460  static constexpr index_t kK = 4;
461 
462  static constexpr index_t kAMBlock = 1;
463  static constexpr index_t kBNBlock = 16;
464 
465  // we only write down single block (4 threads) thread mapping here
466  static constexpr index_t kAMLane = 4;
467  static constexpr index_t kBNLane = 4;
468  static constexpr index_t kABKLane = 1;
469  static constexpr index_t kABKPerLane = 4;
470 
471  static constexpr index_t kCMLane = 1;
472  static constexpr index_t kCNLane = 4;
473  static constexpr index_t kCM0PerLane = 1;
474  static constexpr index_t kCM1PerLane = 4;
475 
476  // c_vec += a_vec * b_vec
477  template <bool post_nop_ = false>
479  const AVecType& a_vec,
480  const BVecType& b_vec,
481  bool_constant<post_nop_> = {}) const
482  {
483  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
484  else
485  {
486 #if defined(__gfx9__)
487  c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
488 #else
489  ignore = c_vec;
490  ignore = a_vec;
491  ignore = b_vec;
492 #endif
493  }
494  }
495 
496  // c_vec = a_vec * b_vec
     // Non-accumulating overload: calls the 4x4x4 f16 builtin with fp32x4_t{0.f}
     // as the accumulator operand and returns the result through bit_cast<CVecType>.
     // a_vec and b_vec are passed to the builtin whole.
     // The three trailing 0 arguments are the builtin's immediate modifier operands
     // (presumably cbsz/abid/blgp -- confirm against the LLVM AMDGPU builtin docs).
     // When __gfx9__ is not defined the inputs are only consumed via `ignore`
     // (unqualified here, resolved inside namespace ck_tile) and CVecType{0.f}
     // is returned.
497  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
498  {
499 #if defined(__gfx9__)
500  return bit_cast<CVecType>(
501  __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
502 #else
503  ignore = a_vec;
504  ignore = b_vec;
505  return CVecType{0.f};
506 #endif
507  }
508 };
509 
510 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
512 {
513  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
514  using ADataType = fp16_t;
515  using BDataType = fp16_t;
516  using CDataType = float;
517 
521 
522  static constexpr index_t kM = 64;
523  static constexpr index_t kN = 4;
524  static constexpr index_t kK = 4;
525 
526  static constexpr index_t kAMBlock = 16;
527  static constexpr index_t kBNBlock = 1;
528 
529  // we only write down single block (4 threads) thread mapping here
530  static constexpr index_t kAMLane = 4;
531  static constexpr index_t kBNLane = 4;
532  static constexpr index_t kABKLane = 1;
533  static constexpr index_t kABKPerLane = 4;
534 
535  static constexpr index_t kCMLane = 1;
536  static constexpr index_t kCNLane = 4;
537  static constexpr index_t kCM0PerLane = 1;
538  static constexpr index_t kCM1PerLane = 4;
539 
540  // c_vec += a_vec * b_vec
541  template <bool post_nop_ = false>
543  const AVecType& a_vec,
544  const BVecType& b_vec,
545  bool_constant<post_nop_> = {}) const
546  {
547  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
548  else
549  {
550 #if defined(__gfx9__)
551  c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
552 #else
553  ignore = c_vec;
554  ignore = a_vec;
555  ignore = b_vec;
556 #endif
557  }
558  }
559 
560  // c_vec = a_vec * b_vec
     // Non-accumulating overload: calls the 4x4x4 f16 builtin with fp32x4_t{0.f}
     // as the accumulator operand and returns the result through bit_cast<CVecType>.
     // The call is identical to the one in the sibling 4x64x4 f16 wrapper; the two
     // wrappers differ only in their kM/kN and kAMBlock/kBNBlock constants.
     // The three trailing 0 arguments are the builtin's immediate modifier operands
     // (presumably cbsz/abid/blgp -- confirm against the LLVM AMDGPU builtin docs).
     // When __gfx9__ is not defined the inputs are only consumed via `ignore`
     // and CVecType{0.f} is returned.
561  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
562  {
563 #if defined(__gfx9__)
564  return bit_cast<CVecType>(
565  __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
566 #else
567  ignore = a_vec;
568  ignore = b_vec;
569  return CVecType{0.f};
570 #endif
571  }
572 };
573 
574 // Bf16
575 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
577 {
578  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
579  using ADataType = bf16_t;
580  using BDataType = bf16_t;
581  using CDataType = float;
582 
586 
587  static constexpr index_t kM = 32;
588  static constexpr index_t kN = 32;
589  static constexpr index_t kK = 8;
590 
591  static constexpr index_t kAMBlock = 1;
592  static constexpr index_t kBNBlock = 1;
593 
594  static constexpr index_t kAMLane = 32;
595  static constexpr index_t kBNLane = 32;
596  static constexpr index_t kABKLane = 2;
597  static constexpr index_t kABKPerLane = 4;
598 
599  static constexpr index_t kCMLane = 2;
600  static constexpr index_t kCNLane = 32;
601  static constexpr index_t kCM0PerLane = 4;
602  static constexpr index_t kCM1PerLane = 4;
603 
604  // c_vec += a_vec * b_vec
605  template <bool post_nop_ = false>
607  const AVecType& a_vec,
608  const BVecType& b_vec,
609  bool_constant<post_nop_> = {}) const
610  {
611  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl)
612  else
613  {
614 #if defined(__gfx90a__) || defined(__gfx94__)
615  c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
616 #elif defined(__gfx908__)
617  static_for<0, 2, 1>{}([&](auto k) {
618  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
619  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
620  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
621  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
622  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
623  c_vec,
624  0,
625  0,
626  0);
627  });
628 #else
629  ck_tile::ignore = c_vec;
630  ck_tile::ignore = a_vec;
631  ck_tile::ignore = b_vec;
632 #endif
633  }
634  }
635 
636  // c_vec = a_vec * b_vec
     // Non-accumulating overload with three compile-time paths:
     //  * __gfx90a__ or __gfx94__: one 32x32x8bf16_1k builtin call with
     //    fp32x16_t{0.f} as the accumulator operand.
     //  * __gfx908__: c_vec starts as CVecType{0.f} and is chained through the
     //    32x32x4bf16 builtin inside static_for<0, 2, 1> (presumably k = 0, 1);
     //    call k receives the k-th 2-element bf16 slice of a_vec and b_vec, each
     //    viewed as a thread_buffer of 4 elements.
     //  * otherwise: inputs are consumed via ck_tile::ignore and CVecType{0.f}
     //    is returned.
     // The three trailing 0 arguments of each builtin are its immediate modifier
     // operands (presumably cbsz/abid/blgp -- confirm against the LLVM AMDGPU docs).
637  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
638  {
639 #if defined(__gfx90a__) || defined(__gfx94__)
640  return bit_cast<CVecType>(
641  __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
642 #elif defined(__gfx908__)
643  CVecType c_vec{0.f};
644  static_for<0, 2, 1>{}([&](auto k) {
645  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
646  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
647  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
648  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
649  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
650  c_vec,
651  0,
652  0,
653  0);
654  });
655  return c_vec;
656 #else
657  ck_tile::ignore = a_vec;
658  ck_tile::ignore = b_vec;
659  return CVecType{0.f};
660 #endif
661  }
662 };
663 
664 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
666 {
667  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
668  using ADataType = bf16_t;
669  using BDataType = bf16_t;
670  using CDataType = float;
671 
675 
676  static constexpr index_t kM = 16;
677  static constexpr index_t kN = 16;
678  static constexpr index_t kK = 16;
679 
680  static constexpr index_t kAMBlock = 1;
681  static constexpr index_t kBNBlock = 1;
682 
683  static constexpr index_t kAMLane = 16;
684  static constexpr index_t kBNLane = 16;
685  static constexpr index_t kABKLane = 4;
686  static constexpr index_t kABKPerLane = 4;
687 
688  static constexpr index_t kCMLane = 4;
689  static constexpr index_t kCNLane = 16;
690  static constexpr index_t kCM0PerLane = 1;
691  static constexpr index_t kCM1PerLane = 4;
692 
693  // c_vec += a_vec * b_vec
694  template <bool post_nop_ = false>
696  const AVecType& a_vec,
697  const BVecType& b_vec,
698  bool_constant<post_nop_> = {}) const
699  {
     // Accumulating form: c_vec += a_vec * b_vec for the 16x16x16 bf16 tile.
     // For a Raw_* Ctrl, DISPATCH_MFMA_CTRL_ emits the MFMA as inline asm with the
     // requested register constraints. The builtin path below must run only when
     // none of those branches matched (Ctrl == Default_), hence the `else`, as in
     // every sibling wrapper in this file. Without it, a Raw_* Ctrl would issue the
     // asm MFMA and then accumulate a second time through the builtin.
700  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
     else
701  {
702 #if defined(__gfx90a__) || defined(__gfx94__)
     // Single-instruction path.
703  c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
704 #elif defined(__gfx908__)
     // gfx908 path: two 16x16x8bf16 calls, each fed the k-th 2-element bf16 slice
     // of a_vec / b_vec and chained through c_vec.
705  static_for<0, 2, 1>{}([&](auto k) {
706  c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
707  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
708  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
709  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
710  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
711  c_vec,
712  0,
713  0,
714  0);
715  });
716 #else
     // Other targets: nothing is computed; the operands are consumed only to
     // avoid unused-parameter diagnostics.
717  ck_tile::ignore = c_vec;
718  ck_tile::ignore = a_vec;
719  ck_tile::ignore = b_vec;
720 #endif
721  }
722  }
723 
724  // c_vec = a_vec * b_vec
     // Non-accumulating overload with three compile-time paths:
     //  * __gfx90a__ or __gfx94__: one 16x16x16bf16_1k builtin call with
     //    fp32x4_t{0.f} as the accumulator operand.
     //  * __gfx908__: c_vec starts as CVecType{0.f} and is chained through the
     //    16x16x8bf16 builtin inside static_for<0, 2, 1> (presumably k = 0, 1);
     //    call k receives the k-th 2-element bf16 slice of a_vec and b_vec, each
     //    viewed as a thread_buffer of 4 elements.
     //  * otherwise: inputs are consumed via ck_tile::ignore and CVecType{0.f}
     //    is returned.
     // The three trailing 0 arguments of each builtin are its immediate modifier
     // operands (presumably cbsz/abid/blgp -- confirm against the LLVM AMDGPU docs).
725  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
726  {
727 #if defined(__gfx90a__) || defined(__gfx94__)
728  return bit_cast<CVecType>(
729  __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
730 #elif defined(__gfx908__)
731  CVecType c_vec{0.f};
732  static_for<0, 2, 1>{}([&](auto k) {
733  c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
734  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
735  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
736  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
737  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
738  c_vec,
739  0,
740  0,
741  0);
742  });
743  return c_vec;
744 #else
745  ck_tile::ignore = a_vec;
746  ck_tile::ignore = b_vec;
747  return CVecType{0.f};
748 #endif
749  }
750 };
751 
752 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
754 {
755  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
756  using ADataType = bf16_t;
757  using BDataType = bf16_t;
758  using CDataType = float;
759 
763 
764  static constexpr index_t kM = 4;
765  static constexpr index_t kN = 64;
766  static constexpr index_t kK = 4;
767 
768  static constexpr index_t kAMBlock = 1;
769  static constexpr index_t kBNBlock = 16;
770 
771  // we only write down single block (4 threads) thread mapping here
772  static constexpr index_t kAMLane = 4;
773  static constexpr index_t kBNLane = 4;
774  static constexpr index_t kABKLane = 1;
775  static constexpr index_t kABKPerLane = 4;
776 
777  static constexpr index_t kCMLane = 1;
778  static constexpr index_t kCNLane = 4;
779  static constexpr index_t kCM0PerLane = 1;
780  static constexpr index_t kCM1PerLane = 4;
781 
782  // c_vec += a_vec * b_vec
783  template <bool post_nop_ = false>
785  const AVecType& a_vec,
786  const BVecType& b_vec,
787  bool_constant<post_nop_> = {}) const
788  {
789  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
790  else
791  {
792 #if defined(__gfx90a__) || defined(__gfx94__)
793  c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
794 #elif defined(__gfx908__)
795  static_for<0, 2, 1>{}([&](auto k) {
796  c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
797  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
798  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
799  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
800  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
801  c_vec,
802  0,
803  0,
804  0);
805  });
806 #else
807  ignore = c_vec;
808  ignore = a_vec;
809  ignore = b_vec;
810 #endif
811  }
812  }
813 
814  // c_vec = a_vec * b_vec
     // Non-accumulating overload with three compile-time paths:
     //  * __gfx90a__ or __gfx94__: one 4x4x4bf16_1k builtin call with
     //    fp32x4_t{0.f} as the accumulator operand.
     //  * __gfx908__: c_vec starts as CVecType{0.f} and is chained through the
     //    4x4x2bf16 builtin inside static_for<0, 2, 1> (presumably k = 0, 1);
     //    call k receives the k-th 2-element bf16 slice of a_vec and b_vec, each
     //    viewed as a thread_buffer of 4 elements.
     //  * otherwise: inputs are consumed via `ignore` and CVecType{0.f} is returned.
     // The three trailing 0 arguments of each builtin are its immediate modifier
     // operands (presumably cbsz/abid/blgp -- confirm against the LLVM AMDGPU docs).
815  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
816  {
817 #if defined(__gfx90a__) || defined(__gfx94__)
818  return bit_cast<CVecType>(
819  __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
820 #elif defined(__gfx908__)
821  CVecType c_vec{0.f};
822  static_for<0, 2, 1>{}([&](auto k) {
823  c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
824  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
825  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
826  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
827  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
828  c_vec,
829  0,
830  0,
831  0);
832  });
833  return c_vec;
834 #else
835  ignore = a_vec;
836  ignore = b_vec;
837  return CVecType{0.f};
838 #endif
839  }
840 };
841 
842 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
844 {
845  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
846  using ADataType = bf16_t;
847  using BDataType = bf16_t;
848  using CDataType = float;
849 
853 
854  static constexpr index_t kM = 64;
855  static constexpr index_t kN = 4;
856  static constexpr index_t kK = 4;
857 
858  static constexpr index_t kAMBlock = 16;
859  static constexpr index_t kBNBlock = 1;
860 
861  // we only write down single block (4 threads) thread mapping here
862  static constexpr index_t kAMLane = 4;
863  static constexpr index_t kBNLane = 4;
864  static constexpr index_t kABKLane = 1;
865  static constexpr index_t kABKPerLane = 4;
866 
867  static constexpr index_t kCMLane = 1;
868  static constexpr index_t kCNLane = 4;
869  static constexpr index_t kCM0PerLane = 1;
870  static constexpr index_t kCM1PerLane = 4;
871 
872  // c_vec += a_vec * b_vec
873  template <bool post_nop_ = false>
875  const AVecType& a_vec,
876  const BVecType& b_vec,
877  bool_constant<post_nop_> = {}) const
878  {
879  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
880  else
881  {
882 #if defined(__gfx90a__) || defined(__gfx94__)
883  c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
884 #elif defined(__gfx908__)
885  static_for<0, 2, 1>{}([&](auto k) {
886  c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
887  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
888  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
889  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
890  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
891  c_vec,
892  0,
893  0,
894  0);
895  });
896 #else
897  ignore = c_vec;
898  ignore = a_vec;
899  ignore = b_vec;
900 #endif
901  }
902  }
903 
904  // c_vec = a_vec * b_vec
     // Non-accumulating overload; the builtin calls are identical to those of the
     // sibling 4x64x4 bf16 wrapper (the two differ only in their kM/kN and
     // kAMBlock/kBNBlock constants). Three compile-time paths:
     //  * __gfx90a__ or __gfx94__: one 4x4x4bf16_1k builtin call with
     //    fp32x4_t{0.f} as the accumulator operand.
     //  * __gfx908__: c_vec starts as CVecType{0.f} and is chained through the
     //    4x4x2bf16 builtin inside static_for<0, 2, 1> (presumably k = 0, 1);
     //    call k receives the k-th 2-element bf16 slice of a_vec and b_vec.
     //  * otherwise: inputs are consumed via `ignore` and CVecType{0.f} is returned.
     // The three trailing 0 arguments of each builtin are its immediate modifier
     // operands (presumably cbsz/abid/blgp -- confirm against the LLVM AMDGPU docs).
905  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
906  {
907 #if defined(__gfx90a__) || defined(__gfx94__)
908  return bit_cast<CVecType>(
909  __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
910 #elif defined(__gfx908__)
911  CVecType c_vec{0.f};
912  static_for<0, 2, 1>{}([&](auto k) {
913  c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
914  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
915  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
916  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
917  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
918  c_vec,
919  0,
920  0,
921  0);
922  });
923  return c_vec;
924 #else
925  ignore = a_vec;
926  ignore = b_vec;
927  return CVecType{0.f};
928 #endif
929  }
930 };
931 
932 // gfx950
933 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
935 {
936  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
937  using ADataType = fp16_t;
938  using BDataType = fp16_t;
939  using CDataType = float;
940 
944 
945  static constexpr index_t kM = 32;
946  static constexpr index_t kN = 32;
947  static constexpr index_t kK = 16;
948 
949  static constexpr index_t kAMBlock = 1;
950  static constexpr index_t kBNBlock = 1;
951 
952  static constexpr index_t kAMLane = 32;
953  static constexpr index_t kBNLane = 32;
954  static constexpr index_t kABKLane = 2;
955  static constexpr index_t kABKPerLane = 8;
956 
957  static constexpr index_t kCMLane = 2;
958  static constexpr index_t kCNLane = 32;
959  static constexpr index_t kCM0PerLane = 4;
960  static constexpr index_t kCM1PerLane = 4;
961 
962  // c_vec += a_vec * b_vec
963  template <bool post_nop_ = false>
965  const AVecType& a_vec,
966  const BVecType& b_vec,
967  bool_constant<post_nop_> = {}) const
968  {
969  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_f16", Ctrl)
970  else
971  {
972 #if defined(__gfx950__)
973  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, c_vec, 0, 0, 0);
974 #elif defined(__gfx90a__) || defined(__gfx94__)
975  static_for<0, 2, 1>{}([&](auto k) {
976  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(
977  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
978  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
979  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
980  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
981  c_vec,
982  0,
983  0,
984  0);
985  });
986 #elif defined(__gfx908__)
987  static_for<0, 4, 1>{}([&](auto k) {
988  c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16(
989  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
990  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
991  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
992  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
993  c_vec,
994  0,
995  0,
996  0);
997  });
998 #else
999  ck_tile::ignore = c_vec;
1000  ck_tile::ignore = a_vec;
1001  ck_tile::ignore = b_vec;
1002 #endif
1003  }
1004  }
1005 
1006  // c_vec = a_vec * b_vec
      // Non-accumulating overload with four compile-time paths:
      //  * __gfx950__: one 32x32x16_f16 builtin call with fp32x16_t{0.f} as the
      //    accumulator operand; the result is returned directly (no bit_cast,
      //    unlike the other wrappers in this file).
      //  * __gfx90a__ or __gfx94__: c_vec starts as CVecType{0.f} and is chained
      //    through the 32x32x8f16 builtin inside static_for<0, 2, 1>; call k
      //    receives the k-th 4-element fp16 slice of a_vec and b_vec, each viewed
      //    as a thread_buffer of 8 elements.
      //  * __gfx908__: same scheme with the 32x32x4f16 builtin inside
      //    static_for<0, 4, 1> and 2-element fp16 slices.
      //    NOTE(review): this path views the operands as thread_buffer<..., 4>
      //    while the loop indexes four 2-element slices (k = 0..3) and this
      //    wrapper declares kABKPerLane = 8; the buffer size looks like it should
      //    be 8 -- confirm, and confirm the builtin's operand width as well.
      //  * otherwise: inputs are consumed via ck_tile::ignore and CVecType{0.f}
      //    is returned.
      // The three trailing 0 arguments of each builtin are its immediate modifier
      // operands (presumably cbsz/abid/blgp -- confirm against the LLVM AMDGPU docs).
1007  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1008  {
1009 #if defined(__gfx950__)
1010  return __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0);
1011 #elif defined(__gfx90a__) || defined(__gfx94__)
1012  CVecType c_vec{0.f};
1013  static_for<0, 2, 1>{}([&](auto k) {
1014  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(
1015  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1016  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
1017  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1018  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
1019  c_vec,
1020  0,
1021  0,
1022  0);
1023  });
1024  return c_vec;
1025 #elif defined(__gfx908__)
1026  CVecType c_vec{0.f};
1027  static_for<0, 4, 1>{}([&](auto k) {
1028  c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16(
1029  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
1030  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
1031  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
1032  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
1033  c_vec,
1034  0,
1035  0,
1036  0);
1037  });
1038  return c_vec;
1039 #else
1040  ck_tile::ignore = a_vec;
1041  ck_tile::ignore = b_vec;
1042  return CVecType{0.f};
1043 #endif
1044  }
1045 };
1046 
1047 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1049 {
1050  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1053  using CDataType = float;
1054 
1058 
1059  static constexpr index_t kM = 32;
1060  static constexpr index_t kN = 32;
1061  static constexpr index_t kK = 16;
1062 
1063  static constexpr index_t kAMBlock = 1;
1064  static constexpr index_t kBNBlock = 1;
1065 
1066  static constexpr index_t kAMLane = 32;
1067  static constexpr index_t kBNLane = 32;
1068  static constexpr index_t kABKLane = 2;
1069  static constexpr index_t kABKPerLane = 8;
1070 
1071  static constexpr index_t kCMLane = 2;
1072  static constexpr index_t kCNLane = 32;
1073  static constexpr index_t kCM0PerLane = 4;
1074  static constexpr index_t kCM1PerLane = 4;
1075 
1076  // c_vec += a_vec * b_vec
1077  template <bool post_nop_ = false>
1079  const AVecType& a_vec,
1080  const BVecType& b_vec,
1081  bool_constant<post_nop_> = {}) const
1082  {
1083  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_bf16", Ctrl)
1084  else
1085  {
1086 #if defined(__gfx950__)
1087  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, c_vec, 0, 0, 0);
1088 #elif defined(__gfx90a__) || defined(__gfx94__)
1089  static_for<0, 2, 1>{}([&](auto k) {
1090  c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
1091  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1092  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1093  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1094  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1095  c_vec,
1096  0,
1097  0,
1098  0);
1099  });
1100 #elif defined(__gfx908__)
1101  static_for<0, 4, 1>{}([&](auto k) {
1102  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
1103  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
1104  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1105  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
1106  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1107  c_vec,
1108  0,
1109  0,
1110  0);
1111  });
1112 #else
1113  ck_tile::ignore = c_vec;
1114  ck_tile::ignore = a_vec;
1115  ck_tile::ignore = b_vec;
1116 #endif
1117  }
1118  }
1119 
1120  // c_vec = a_vec * b_vec
1121  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1122  {
1123 #if defined(__gfx950__)
1124  return __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0);
1125 #elif defined(__gfx90a__) || defined(__gfx94__)
1126  CVecType c_vec{0.f};
1127  static_for<0, 2, 1>{}([&](auto k) {
1128  c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
1129  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1130  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1131  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1132  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1133  c_vec,
1134  0,
1135  0,
1136  0);
1137  });
1138  return c_vec;
1139 #elif defined(__gfx908__)
1140  CVecType c_vec{0.f};
1141  static_for<0, 4, 1>{}([&](auto k) {
1142  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
1143  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
1144  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1145  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
1146  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1147  c_vec,
1148  0,
1149  0,
1150  0);
1151  });
1152  return c_vec;
1153 #else
1154  ck_tile::ignore = a_vec;
1155  ck_tile::ignore = b_vec;
1156  return CVecType{0.f};
1157 #endif
1158  }
1159 };
1160 
1161 // FP8
1162 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1164 {
1165  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1166  using ADataType = AType_;
1167  using BDataType = BType_;
1168  using CDataType = float;
1169 
1173 
1174  static constexpr index_t kM = 16;
1175  static constexpr index_t kN = 16;
1176  static constexpr index_t kK = 32;
1177 
1178  static constexpr index_t kAMBlock = 1;
1179  static constexpr index_t kBNBlock = 1;
1180 
1181  static constexpr index_t kAMLane = 16;
1182  static constexpr index_t kBNLane = 16;
1183  static constexpr index_t kABKLane = 4;
1184  static constexpr index_t kABKPerLane = 8;
1185 
1186  static constexpr index_t kCMLane = 4;
1187  static constexpr index_t kCNLane = 16;
1188  static constexpr index_t kCM0PerLane = 1;
1189  static constexpr index_t kCM1PerLane = 4;
1190 
1191  // c_vec += a_vec * b_vec
1192  template <bool post_nop_ = false>
1194  const AVecType& a_vec,
1195  const BVecType& b_vec,
1196  bool_constant<post_nop_> = {}) const
1197  {
1198  if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
1199  {
1200  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1201  {
1202  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "v", "v")
1203  }
1204  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1205  {
1206  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "v", "v")
1207  }
1208  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1209  {
1210  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "v", "v")
1211  }
1212  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1213  {
1214  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "v", "v")
1215  }
1216  }
1217  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
1218  {
1219  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1220  {
1221  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "a", "v")
1222  }
1223  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1224  {
1225  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "a", "v")
1226  }
1227  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1228  {
1229  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "a", "v")
1230  }
1231  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1232  {
1233  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "a", "v")
1234  }
1235  }
1236  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
1237  {
1238  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1239  {
1240  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "v", "v")
1241  }
1242  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1243  {
1244  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "v", "v")
1245  }
1246  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1247  {
1248  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "v", "v")
1249  }
1250  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1251  {
1252  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "v", "v")
1253  }
1254  }
1255  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
1256  {
1257  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1258  {
1259  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "a", "v")
1260  }
1261  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1262  {
1263  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "a", "v")
1264  }
1265  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1266  {
1267  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "a", "v")
1268  }
1269  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1270  {
1271  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "a", "v")
1272  }
1273  }
1274  else
1275  {
1276 #if defined(__gfx94__) or defined(__gfx95__)
1277  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1278  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1279  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1280  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1281  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1282  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1283  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1284  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1285  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1286  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1287  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1288  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1289 #else
1290  ck_tile::ignore = c_vec;
1291  ck_tile::ignore = a_vec;
1292  ck_tile::ignore = b_vec;
1293 #endif
1294  }
1295  }
1296 
1297  // c_vec = a_vec * b_vec
1298  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1299  {
1300 #if defined(__gfx94__) or defined(__gfx95__)
1301  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1302  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1303  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1304  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1305  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1306  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1307  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1308  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1309  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1310  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1311  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1312  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1313 #else
1314  ck_tile::ignore = a_vec;
1315  ck_tile::ignore = b_vec;
1316  return CVecType{0.f};
1317 #endif
1318  }
1319 };
1320 
1321 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1323 {
1324  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1325  using ADataType = AType_;
1326  using BDataType = BType_;
1327  using CDataType = float;
1328 
1332 
1333  static constexpr index_t kM = 32;
1334  static constexpr index_t kN = 32;
1335  static constexpr index_t kK = 16;
1336 
1337  static constexpr index_t kAMBlock = 1;
1338  static constexpr index_t kBNBlock = 1;
1339 
1340  static constexpr index_t kAMLane = 32;
1341  static constexpr index_t kBNLane = 32;
1342  static constexpr index_t kABKLane = 2;
1343  static constexpr index_t kABKPerLane = 8;
1344 
1345  static constexpr index_t kCMLane = 2;
1346  static constexpr index_t kCNLane = 32;
1347  static constexpr index_t kCM0PerLane = 4;
1348  static constexpr index_t kCM1PerLane = 4;
1349 
1350  // c_vec += a_vec * b_vec
1351  template <bool post_nop_ = false>
1353  const AVecType& a_vec,
1354  const BVecType& b_vec,
1355  bool_constant<post_nop_> = {}) const
1356  {
1357  if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
1358  {
1359  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1360  {
1361  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "v", "v")
1362  }
1363  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1364  {
1365  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "v", "v")
1366  }
1367  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1368  {
1369  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "v", "v")
1370  }
1371  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1372  {
1373  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "v", "v")
1374  }
1375  }
1376  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
1377  {
1378  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1379  {
1380  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "a", "v")
1381  }
1382  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1383  {
1384  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "a", "v")
1385  }
1386  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1387  {
1388  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "a", "v")
1389  }
1390  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1391  {
1392  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "a", "v")
1393  }
1394  }
1395  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
1396  {
1397  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1398  {
1399  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "v", "v")
1400  }
1401  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1402  {
1403  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "v", "v")
1404  }
1405  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1406  {
1407  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "v", "v")
1408  }
1409  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1410  {
1411  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "v", "v")
1412  }
1413  }
1414  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
1415  {
1416  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1417  {
1418  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "a", "v")
1419  }
1420  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1421  {
1422  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "a", "v")
1423  }
1424  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1425  {
1426  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "a", "v")
1427  }
1428  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1429  {
1430  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "a", "v")
1431  }
1432  }
1433  else
1434  {
1435 #if defined(__gfx94__) or defined(__gfx95__)
1436  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1437  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1438  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1439  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1440  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1441  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1442  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1443  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1444  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1445  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1446  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1447  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1448 #elif defined(__gfx908__) || defined(__gfx90a__)
1449  static_for<0, 8, 1>{}([&](auto k) {
1450  float a_f32 =
1451  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1452  .template get_as<ADataType>()[number<k>{}]);
1453  float b_f32 =
1454  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1455  .template get_as<BDataType>()[number<k>{}]);
1456 
1457  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1458  });
1459 #else
1460  ck_tile::ignore = c_vec;
1461  ck_tile::ignore = a_vec;
1462  ck_tile::ignore = b_vec;
1463 #endif
1464  }
1465  }
1466 
1467  // c_vec = a_vec * b_vec
1468  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1469  {
1470 #if defined(__gfx94__) or defined(__gfx95__)
1471  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1472  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1473  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1474  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1475  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1476  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1477  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1478  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1479  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1480  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1481  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1482  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1483 #elif defined(__gfx908__) || defined(__gfx90a__)
1484  CVecType c_vec{0.f};
1485  static_for<0, 8, 1>{}([&](auto k) {
1486  float a_f32 =
1487  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1488  .template get_as<ADataType>()[number<k>{}]);
1489  float b_f32 =
1490  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1491  .template get_as<BDataType>()[number<k>{}]);
1492 
1493  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1494  });
1495  return c_vec;
1496 #else
1497  ck_tile::ignore = a_vec;
1498  ck_tile::ignore = b_vec;
1499  return CVecType{0.f};
1500 #endif
1501  }
1502 };
1503 
1504 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1507 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1510 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1513 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1516 
1517 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1520 
1521 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1524 
1525 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1528 
// Warp-level MFMA attribute for a 16x16x128 (kM x kN x kK) tile with 8-bit float operands,
// accumulating into float. AType_ / BType_ select fp8_t or bf8_t independently.
//
// Lane layout constants: A and B each use 16 M/N lanes and kABKLane = 4 K-groups holding
// kABKPerLane = 32 elements per lane; C uses kCMLane x kCNLane = 4 x 16 lanes with
// kCM0PerLane * kCM1PerLane = 4 values per lane.
//
// Both operators lower to __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4 on __gfx950__ only.
// The 4th and 5th builtin arguments (cbsz for A, blgp for B, per the argument-order comment
// below) are 0 when the operand type is fp8_t and 1 when it is bf8_t; the opsel and scale
// arguments are all passed as 0. On every other target the arguments are ignored and the
// value-returning operator yields a zero-initialised CVecType.
// Ctrl is stored but not consulted by either operator in this struct.
1529 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1531 {
1532  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1533  using ADataType = AType_;
1534  using BDataType = BType_;
1535  using CDataType = float;
1536 
1540 
1541  static constexpr index_t kM = 16;
1542  static constexpr index_t kN = 16;
1543  static constexpr index_t kK = 128;
1544 
1545  static constexpr index_t kAMBlock = 1;
1546  static constexpr index_t kBNBlock = 1;
1547 
1548  static constexpr index_t kAMLane = 16;
1549  static constexpr index_t kBNLane = 16;
1550  static constexpr index_t kABKLane = 4;
1551  static constexpr index_t kABKPerLane = 32;
1552 
1553  static constexpr index_t kCMLane = 4;
1554  static constexpr index_t kCNLane = 16;
1555  static constexpr index_t kCM0PerLane = 1;
1556  static constexpr index_t kCM1PerLane = 4;
1557 
1558  // c_vec += a_vec * b_vec
1559  template <bool post_nop_ = false>
1561  const AVecType& a_vec,
1562  const BVecType& b_vec,
1563  bool_constant<post_nop_> = {}) const
1564  {
1565  //__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
1566  // opsel, scale_b)
1567 #if defined(__gfx950__)
1568  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1569  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1570  a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
1571  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1572  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1573  a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
1574  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1575  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1576  a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
1577  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1578  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1579  a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
1580 #else
1581  ck_tile::ignore = c_vec;
1582  ck_tile::ignore = a_vec;
1583  ck_tile::ignore = b_vec;
1584 #endif
1585  }
1586 
1587  // c_vec = a_vec * b_vec
1588  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1589  {
1590 #if defined(__gfx950__)
1591  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1592  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1593  a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
1594  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1595  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1596  a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
1597  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1598  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1599  a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
1600  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1601  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1602  a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
1603 #else
1604  ck_tile::ignore = a_vec;
1605  ck_tile::ignore = b_vec;
1606  return CVecType{0.f};
1607 #endif
1608  }
1609 };
1610 
1611 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1614 
1615 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1618 
1619 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1622 
1623 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1626 
// Warp-level scaled MFMA attribute for a 16x16x128 (kM x kN x kK) tile accumulating into float.
//
// Lane layout constants: A and B each use 16 M/N lanes and kABKLane = 4 K-groups holding
// kABKPerLane = 32 elements per lane; C uses kCMLane x kCNLane = 4 x 16 lanes with
// kCM0PerLane * kCM1PerLane = 4 values per lane.
//
// Both operators take per-operand scale words (a_scale, b_scale) and compile-time opselA /
// opselB selectors, and lower to __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4 on
// __gfx950__ only. A and B are bit_cast to int32x4_t and placed in the low four dwords of an
// int32x8_t whose upper four dwords are zero. The format selectors cbsz and blgp are both 4
// (presumably the 4-bit float format, given 32 elements occupy 128 bits -- confirm against
// the ISA documentation). On every other target all arguments are ignored and the
// value-returning operator yields a zero-initialised CVecType.
// Ctrl is stored but not consulted by either operator in this struct.
1627 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1629 {
1630  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1633  using CDataType = float;
1634 
1638 
1639  static constexpr index_t kM = 16;
1640  static constexpr index_t kN = 16;
1641  static constexpr index_t kK = 128;
1642 
1643  static constexpr index_t kAMBlock = 1;
1644  static constexpr index_t kBNBlock = 1;
1645 
1646  static constexpr index_t kAMLane = 16;
1647  static constexpr index_t kBNLane = 16;
1648  static constexpr index_t kABKLane = 4;
1649  static constexpr index_t kABKPerLane = 32;
1650 
1651  static constexpr index_t kCMLane = 4;
1652  static constexpr index_t kCNLane = 16;
1653  static constexpr index_t kCM0PerLane = 1;
1654  static constexpr index_t kCM1PerLane = 4;
1655 
1656  // c_vec += a_vec * b_vec
1657  template <index_t opselA, index_t opselB, bool post_nop_ = false>
1659  const AVecType& a_vec,
1660  const int32_t& a_scale,
1661  const BVecType& b_vec,
1662  const int32_t& b_scale,
1663  bool_constant<post_nop_> = {}) const
1664  {
1665  //__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
1666  // opsel, scale_b)
1667 #if defined(__gfx950__)
1668  auto arg_a = bit_cast<int32x4_t>(a_vec);
1669  auto arg_b = bit_cast<int32x4_t>(b_vec);
  // The builtin is fed 8-dword operands; only the low 4 dwords carry data here.
1670  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1671  int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1672  int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1673  c_vec,
1674  4,
1675  4,
1676  opselA,
1677  a_scale,
1678  opselB,
1679  b_scale);
1680 #else
1681  ck_tile::ignore = c_vec;
1682  ck_tile::ignore = a_vec;
1683  ck_tile::ignore = b_vec;
1684  ck_tile::ignore = a_scale;
1685  ck_tile::ignore = b_scale;
1686 #endif
1687  }
1688 
1689  // c_vec = a_vec * b_vec
1690  template <index_t opselA, index_t opselB>
1692  const int32_t& a_scale,
1693  const BVecType& b_vec,
1694  const int32_t& b_scale) const
1695  {
1696 #if defined(__gfx950__)
1697  auto arg_a = bit_cast<int32x4_t>(a_vec);
1698  auto arg_b = bit_cast<int32x4_t>(b_vec);
1699  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1700  int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1701  int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1702  CVecType{0.f},
1703  4,
1704  4,
1705  opselA,
1706  a_scale,
1707  opselB,
1708  b_scale));
1709 #else
1710  ck_tile::ignore = a_vec;
1711  ck_tile::ignore = b_vec;
1712  ck_tile::ignore = a_scale;
1713  ck_tile::ignore = b_scale;
1714  return CVecType{0.f};
1715 #endif
1716  }
1717 };
1718 
// Warp-level MFMA attribute for a 32x32x64 (kM x kN x kK) tile with 8-bit float operands,
// accumulating into float. AType_ / BType_ select fp8_t or bf8_t independently.
//
// Lane layout constants: A and B each use 32 M/N lanes and kABKLane = 2 K-groups holding
// kABKPerLane = 32 elements per lane; C uses kCMLane x kCNLane = 2 x 32 lanes with
// kCM0PerLane * kCM1PerLane = 16 values per lane.
//
// Both operators lower to __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4 on __gfx950__ only.
// The 4th and 5th builtin arguments (cbsz for A, blgp for B, per the argument-order comment
// below) are 0 when the operand type is fp8_t and 1 when it is bf8_t; the opsel and scale
// arguments are all passed as 0. On every other target the arguments are ignored and the
// value-returning operator yields a zero-initialised CVecType.
// Ctrl is stored but not consulted by either operator in this struct.
1719 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1721 {
1722  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1723  using ADataType = AType_;
1724  using BDataType = BType_;
1725  using CDataType = float;
1726 
1730 
1731  static constexpr index_t kM = 32;
1732  static constexpr index_t kN = 32;
1733  static constexpr index_t kK = 64;
1734 
1735  static constexpr index_t kAMBlock = 1;
1736  static constexpr index_t kBNBlock = 1;
1737 
1738  static constexpr index_t kAMLane = 32;
1739  static constexpr index_t kBNLane = 32;
1740  static constexpr index_t kABKLane = 2;
1741  static constexpr index_t kABKPerLane = 32;
1742 
1743  static constexpr index_t kCMLane = 2;
1744  static constexpr index_t kCNLane = 32;
1745  static constexpr index_t kCM0PerLane = 4;
1746  static constexpr index_t kCM1PerLane = 4;
1747 
1748  // c_vec += a_vec * b_vec
1749  template <bool post_nop_ = false>
1751  const AVecType& a_vec,
1752  const BVecType& b_vec,
1753  bool_constant<post_nop_> = {}) const
1754  {
1755  //__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
1756  // opsel, scale_b)
1757 #if defined(__gfx950__)
1758  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1759  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1760  a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
1761  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1762  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1763  a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
1764  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1765  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1766  a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
1767  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1768  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1769  a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
1770 #else
1771  ck_tile::ignore = c_vec;
1772  ck_tile::ignore = a_vec;
1773  ck_tile::ignore = b_vec;
1774 #endif
1775  }
1776 
1777  // c_vec = a_vec * b_vec
1778  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1779  {
1780 #if defined(__gfx950__)
1781  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1782  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1783  a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
1784  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1785  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1786  a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
1787  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1788  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1789  a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
1790  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1791  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1792  a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
1793 #else
1794  ck_tile::ignore = a_vec;
1795  ck_tile::ignore = b_vec;
1796  return CVecType{0.f};
1797 #endif
1798  }
1799 };
1800 
1801 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1804 
1805 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1808 
1809 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1812 
1813 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1816 
1817 // int8
// Warp-level MFMA attribute for a 32x32x16 (kM x kN x kK) tile using the i8 -> i32 MFMA
// instruction (mnemonic v_mfma_i32_32x32x16_i8).
//
// Lane layout constants: A and B each use 32 M/N lanes and kABKLane = 2 K-groups holding
// kABKPerLane = 8 elements per lane; C uses kCMLane x kCNLane = 2 x 32 lanes with
// kCM0PerLane * kCM1PerLane = 16 values per lane.
//
// Accumulating operator (c_vec += a_vec * b_vec):
//  * raw Ctrl modes handled by DISPATCH_MFMA_CTRL_ : inline asm;
//  * otherwise on __gfx94__ / __gfx95__ : __builtin_amdgcn_mfma_i32_32x32x16_i8 with A and B
//    bit_cast to int64_t;
//  * otherwise on __gfx908__ / __gfx90a__ : element-wise conversion to float followed by
//    __builtin_amdgcn_mfma_f32_32x32x2f32;
//  * any other target : the arguments are ignored.
// The value-returning operator zero-initialises an accumulator and forwards to the
// accumulating operator.
1818 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1820 {
1821  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1825 
1829 
1830  static constexpr index_t kM = 32;
1831  static constexpr index_t kN = 32;
1832  static constexpr index_t kK = 16;
1833 
1834  static constexpr index_t kAMBlock = 1;
1835  static constexpr index_t kBNBlock = 1;
1836 
1837  static constexpr index_t kAMLane = 32;
1838  static constexpr index_t kBNLane = 32;
1839  static constexpr index_t kABKLane = 2;
1840  static constexpr index_t kABKPerLane = 8;
1841 
1842  static constexpr index_t kCMLane = 2;
1843  static constexpr index_t kCNLane = 32;
1844  static constexpr index_t kCM0PerLane = 4;
1845  static constexpr index_t kCM1PerLane = 4;
1846 
1847  // c_vec += a_vec * b_vec
1848  template <bool post_nop_ = false>
1850  const AVecType& a_vec,
1851  const BVecType& b_vec,
1852  bool_constant<post_nop_> = {}) const
1853  {
1854  DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
1855  else
1856  {
1857 #if defined(__gfx94__) or defined(__gfx95__)
1858  c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
1859  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1860 #elif defined(__gfx908__) || defined(__gfx90a__)
  // NOTE(review): this fallback accumulates through a float MFMA builtin while the primary
  // path uses the i32 builtin -- confirm that CVecType is compatible with both.
1861  static_for<0, 8, 1>{}([&](auto k) {
1862  float a_f32 =
1863  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1864  .template get_as<ADataType>()[number<k>{}]);
1865  float b_f32 =
1866  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1867  .template get_as<BDataType>()[number<k>{}]);
1868 
1869  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1870  });
1871 #else
1872  ck_tile::ignore = c_vec;
1873  ck_tile::ignore = a_vec;
1874  ck_tile::ignore = b_vec;
1875 #endif
1876  }
1877  }
1878 
1879  // c_vec = a_vec * b_vec
1880  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1881  {
1882  CVecType c_vec{0};
1883  operator()(c_vec, a_vec, b_vec);
1884  return c_vec;
1885  }
1886 };
1887 
// Warp-level MFMA attribute for a 16x16x32 (kM x kN x kK) tile using the i8 -> i32 MFMA
// instruction (mnemonic v_mfma_i32_16x16x32_i8).
//
// Lane layout constants: A and B each use 16 M/N lanes and kABKLane = 4 K-groups holding
// kABKPerLane = 8 elements per lane; C uses kCMLane x kCNLane = 4 x 16 lanes with
// kCM0PerLane * kCM1PerLane = 4 values per lane.
//
// Accumulating operator (c_vec += a_vec * b_vec):
//  * raw Ctrl modes handled by DISPATCH_MFMA_CTRL_ : inline asm;
//  * otherwise on __gfx94__ / __gfx95__ : __builtin_amdgcn_mfma_i32_16x16x32_i8 with A and B
//    bit_cast to int64_t;
//  * any other target : the arguments are ignored.
// The value-returning operator zero-initialises an accumulator and forwards to the
// accumulating operator.
1888 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1890 {
1891  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1895 
1899 
1900  static constexpr index_t kM = 16;
1901  static constexpr index_t kN = 16;
1902  static constexpr index_t kK = 32;
1903 
1904  static constexpr index_t kAMBlock = 1;
1905  static constexpr index_t kBNBlock = 1;
1906 
1907  static constexpr index_t kAMLane = 16;
1908  static constexpr index_t kBNLane = 16;
1909  static constexpr index_t kABKLane = 4;
1910  static constexpr index_t kABKPerLane = 8;
1911 
1912  static constexpr index_t kCMLane = 4;
1913  static constexpr index_t kCNLane = 16;
1914  static constexpr index_t kCM0PerLane = 1;
1915  static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
1916 
1917  // c_vec += a_vec * b_vec
1918  template <bool post_nop_ = false>
1920  const AVecType& a_vec,
1921  const BVecType& b_vec,
1922  bool_constant<post_nop_> = {}) const
1923  {
1924  DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x32_i8", Ctrl)
1925  else
1926  {
1927 #if defined(__gfx94__) or defined(__gfx95__)
1928  c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
1929  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1930 #else
1931  ck_tile::ignore = c_vec;
1932  ck_tile::ignore = a_vec;
1933  ck_tile::ignore = b_vec;
1934 #endif
1935  }
1936  }
1937 
1938  // c_vec = a_vec * b_vec
1939  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1940  {
1941  CVecType c_vec{0};
1942  operator()(c_vec, a_vec, b_vec);
1943  return c_vec;
1944  }
1945 };
1946 
// Warp-level MFMA attribute for a 16x16x64 (kM x kN x kK) tile using the i8 -> i32 MFMA
// instruction (mnemonic v_mfma_i32_16x16x64_i8).
//
// Lane layout constants: A and B each use 16 M/N lanes and kABKLane = 4 K-groups holding
// kABKPerLane = 16 elements per lane; C uses kCMLane x kCNLane = 4 x 16 lanes with
// kCM0PerLane * kCM1PerLane = 4 values per lane.
//
// Accumulating operator (c_vec += a_vec * b_vec):
//  * raw Ctrl modes handled by DISPATCH_MFMA_CTRL_ : inline asm;
//  * otherwise on __gfx95__ : __builtin_amdgcn_mfma_i32_16x16x64_i8 with A and B
//    reinterpreted as int32x4_t;
//  * any other target : the arguments are ignored.
// The value-returning operator zero-initialises an accumulator and forwards to the
// accumulating operator.
1947 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1949 {
1950  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1954 
1958 
1959  static constexpr index_t kM = 16;
1960  static constexpr index_t kN = 16;
1961  static constexpr index_t kK = 64;
1962 
1963  static constexpr index_t kAMBlock = 1;
1964  static constexpr index_t kBNBlock = 1;
1965 
1966  static constexpr index_t kAMLane = 16;
1967  static constexpr index_t kBNLane = 16;
1968  static constexpr index_t kABKLane = 4;
1969  static constexpr index_t kABKPerLane = 16;
1970 
1971  static constexpr index_t kCMLane = 4;
1972  static constexpr index_t kCNLane = 16;
1973  static constexpr index_t kCM0PerLane = 1;
1974  static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
1975 
1976  // c_vec += a_vec * b_vec
1977  template <bool post_nop_ = false>
1979  const AVecType& a_vec,
1980  const BVecType& b_vec,
1981  bool_constant<post_nop_> = {}) const
1982  {
1983  DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x64_i8", Ctrl)
1984  else
1985  {
1986 #if defined(__gfx95__)
  // Each lane holds kABKPerLane = 16 A/B elements; assuming one-byte (int8) elements that
  // is 128 bits, so the operands are reinterpreted as int32x4_t rather than a 64-bit
  // integer (a size-preserving cast, matching the 4-dword operand the builtin takes).
1987  c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
1988  bit_cast<int32x4_t>(a_vec), bit_cast<int32x4_t>(b_vec), c_vec, 0, 0, 0);
1989 #else
1990  ck_tile::ignore = c_vec;
1991  ck_tile::ignore = a_vec;
1992  ck_tile::ignore = b_vec;
1993 #endif
1994  }
1995  }
1996 
1997  // c_vec = a_vec * b_vec
1998  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1999  {
2000  CVecType c_vec{0};
2001  operator()(c_vec, a_vec, b_vec);
2002  return c_vec;
2003  }
2004 };
2005 
// Warp-level MFMA attribute for a 32x32x32 (kM x kN x kK) tile using the i8 -> i32 MFMA
// instruction (mnemonic v_mfma_i32_32x32x32_i8).
//
// Lane layout constants: A and B each use 32 M/N lanes and kABKLane = 2 K-groups holding
// kABKPerLane = 16 elements per lane; C uses kCMLane x kCNLane = 2 x 32 lanes with
// kCM0PerLane * kCM1PerLane = 16 values per lane.
//
// Accumulating operator (c_vec += a_vec * b_vec):
//  * raw Ctrl modes handled by DISPATCH_MFMA_CTRL_ : inline asm;
//  * otherwise on __gfx95__ : __builtin_amdgcn_mfma_i32_32x32x32_i8 with A and B
//    reinterpreted as int32x4_t;
//  * any other target : the arguments are ignored.
// The value-returning operator zero-initialises an accumulator and forwards to the
// accumulating operator.
2006 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
2008 {
2009  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
2013 
2017 
2018  static constexpr index_t kM = 32;
2019  static constexpr index_t kN = 32;
2020  static constexpr index_t kK = 32;
2021 
2022  static constexpr index_t kAMBlock = 1;
2023  static constexpr index_t kBNBlock = 1;
2024 
2025  static constexpr index_t kAMLane = 32;
2026  static constexpr index_t kBNLane = 32;
2027  static constexpr index_t kABKLane = 2;
2028  static constexpr index_t kABKPerLane = 16;
2029 
2030  static constexpr index_t kCMLane = 2;
2031  static constexpr index_t kCNLane = 32;
2032  static constexpr index_t kCM0PerLane = 4;
2033  static constexpr index_t kCM1PerLane = 4;
2034 
2035  // c_vec += a_vec * b_vec
2036  template <bool post_nop_ = false>
2038  const AVecType& a_vec,
2039  const BVecType& b_vec,
2040  bool_constant<post_nop_> = {}) const
2041  {
2042  DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x32_i8", Ctrl)
2043  else
2044  {
2045 #if defined(__gfx95__)
  // Each lane holds kABKPerLane = 16 A/B elements; assuming one-byte (int8) elements that
  // is 128 bits, so both operands are reinterpreted the same way, as int32x4_t (a
  // size-preserving cast, matching the 4-dword operand the builtin takes).
2046  c_vec = __builtin_amdgcn_mfma_i32_32x32x32_i8(
2047  bit_cast<int32x4_t>(a_vec), bit_cast<int32x4_t>(b_vec), c_vec, 0, 0, 0);
2048 #else
2049  ck_tile::ignore = c_vec;
2050  ck_tile::ignore = a_vec;
2051  ck_tile::ignore = b_vec;
2052 #endif
2053  }
2054  }
2055 
2056  // c_vec = a_vec * b_vec
2057  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
2058  {
2059  CVecType c_vec{0};
2060  operator()(c_vec, a_vec, b_vec);
2061  return c_vec;
2062  }
2063 };
2064 
// Both dispatch helper macros are private to this header; undefine them together so neither
// leaks into translation units that include it (DISPATCH_MFMA_CTRL_ expands to
// DISPATCH_MFMA_ and is unusable once that one is gone).
2065 #undef DISPATCH_MFMA_
#undef DISPATCH_MFMA_CTRL_
2066 
2067 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
WGAttrCtlEnum
Definition: warp_gemm_attribute_mfma_impl.hpp:15
int32_t int32x8_t
Definition: vector_type.hpp:156
_Float16 fp16_t
Definition: half.hpp:110
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
pk_float4_e2m1_t pk_fp4_t
Definition: pk_fp4.hpp:151
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:84
int32_t int32_t
Definition: integer.hpp:10
float fp32x16_t
Definition: vector_type.hpp:130
float fp32x4_t
Definition: vector_type.hpp:128
Definition: warp_gemm_attribute_mfma_impl.hpp:1531
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1541
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1555
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1546
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1532
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1539
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1550
ext_vector_t< ADataType, 32 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1537
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1553
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1535
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1548
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1534
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1556
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1543
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1551
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1545
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1554
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1549
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1542
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1560
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1588
ext_vector_t< BDataType, 32 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1538
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1533
Definition: warp_gemm_attribute_mfma_impl.hpp:1164
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1178
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1179
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1188
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1174
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1166
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1172
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1170
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1165
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1186
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1183
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1193
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1189
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1168
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1298
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1182
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1167
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1187
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1184
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1176
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1181
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1175
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1171
Definition: warp_gemm_attribute_mfma_impl.hpp:1323
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1345
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1324
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1329
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1326
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1348
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1331
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1343
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1327
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1337
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1346
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1352
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1468
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1340
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1333
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1325
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1330
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1335
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1334
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1347
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1338
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1342
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1341
Definition: warp_gemm_attribute_mfma_impl.hpp:1721
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1778
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1732
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1743
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1724
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1741
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1731
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1739
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1738
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1729
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1725
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1750
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1745
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1736
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1735
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1744
ext_vector_t< BDataType, 32 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1728
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1722
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1723
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1740
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1746
ext_vector_t< ADataType, 32 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1727
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1733
Definition: warp_gemm_attribute_mfma_impl.hpp:1890
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1939
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1898
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1897
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1896
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1907
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1919
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1891
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1901
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1905
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1892
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1908
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1910
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1909
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1904
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1912
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1914
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1900
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1915
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1913
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1894
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1902
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1893
Definition: warp_gemm_attribute_mfma_impl.hpp:1949
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1953
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1974
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1961
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1960
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1967
ext_vector_t< ADataType, 16 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1955
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1969
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1973
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1959
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1966
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1978
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1972
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1964
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1968
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1950
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1952
ext_vector_t< BDataType, 16 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1956
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1998
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1951
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1971
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1957
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1963
Definition: warp_gemm_attribute_mfma_impl.hpp:1820
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1822
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1849
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1842
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1826
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1827
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1844
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1824
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1821
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1880
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1823
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1837
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1831
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1835
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1832
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1845
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1828
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1843
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1840
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1834
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1838
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1839
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1830
Definition: warp_gemm_attribute_mfma_impl.hpp:2008
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:2018
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:2010
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:2032
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:2027
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:2037
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:2011
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:2031
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:2012
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:2026
ext_vector_t< BDataType, 16 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:2015
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:2023
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:2025
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:2030
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:2020
ext_vector_t< ADataType, 16 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:2014
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:2016
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:2057
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:2019
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:2033
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:2028
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:2009
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:2022
Definition: warp_gemm_attribute_mfma_impl.hpp:666
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:674
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:684
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:673
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:695
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:688
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:725
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:678
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:683
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:667
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:689
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:690
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:685
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:680
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:670
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:691
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:677
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:669
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:681
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:668
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:672
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:686
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:676
Definition: warp_gemm_attribute_mfma_impl.hpp:196
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:219
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:199
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:204
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:207
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:210
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:220
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:215
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:206
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:208
ext_vector_t< bf16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:202
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:197
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:216
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:200
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:214
ext_vector_t< bf16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:203
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:218
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:198
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:225
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:221
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:244
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:211
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:213
Definition: warp_gemm_attribute_mfma_impl.hpp:1049
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1072
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1078
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1053
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1050
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1052
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1063
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1064
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1067
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1074
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1121
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1068
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1073
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1066
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1059
ext_vector_t< bf16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1055
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1060
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1071
ext_vector_t< bf16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1056
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1051
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1069
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1057
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1061
Definition: warp_gemm_attribute_mfma_impl.hpp:577
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:584
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:580
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:581
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:592
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:585
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:583
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:637
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:588
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:601
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:587
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:600
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:591
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:596
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:595
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:594
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:589
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:599
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:579
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:578
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:606
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:602
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:597
Definition: warp_gemm_attribute_mfma_impl.hpp:754
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:775
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:756
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:778
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:777
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:779
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:815
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:757
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:766
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:755
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:773
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:784
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:760
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:765
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:764
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:774
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:769
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:761
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:780
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:762
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:772
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:768
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:758
Definition: warp_gemm_attribute_mfma_impl.hpp:844
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:847
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:867
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:862
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:905
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:845
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:868
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:855
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:846
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:858
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:869
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:859
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:863
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:856
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:848
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:864
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:870
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:865
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:850
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:874
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:851
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:854
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:852
Definition: warp_gemm_attribute_mfma_impl.hpp:322
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:323
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:339
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:326
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:345
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:336
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:325
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:344
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:347
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:328
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:337
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:333
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:324
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:341
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:340
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:330
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:346
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:342
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:370
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:351
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:332
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:334
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:329
Definition: warp_gemm_attribute_mfma_impl.hpp:385
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:396
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:408
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:410
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:387
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:407
ext_vector_t< fp16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:391
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:404
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:405
ext_vector_t< fp16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:392
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:395
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:393
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:433
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:388
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:389
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:402
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:409
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:400
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:399
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:386
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:403
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:397
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:414
Definition: warp_gemm_attribute_mfma_impl.hpp:935
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:936
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:937
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:949
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:957
ext_vector_t< fp16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:941
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:958
ext_vector_t< fp16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:942
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:945
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:939
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:964
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:946
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:955
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1007
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:954
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:953
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:960
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:950
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:947
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:952
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:959
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:943
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:938
Definition: warp_gemm_attribute_mfma_impl.hpp:259
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:282
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:276
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:283
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:281
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:288
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:284
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:263
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:271
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:262
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:277
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:267
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:274
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:260
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:269
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:266
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:270
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:307
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:273
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:265
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:261
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:278
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:279
Definition: warp_gemm_attribute_mfma_impl.hpp:448
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:459
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:471
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:450
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:454
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:497
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:455
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:463
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:452
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:478
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:456
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:466
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:467
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:472
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:451
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:460
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:468
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:458
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:473
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:449
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:469
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:474
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:462
Definition: warp_gemm_attribute_mfma_impl.hpp:512
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:533
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:516
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:522
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:530
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:515
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:513
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:527
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:542
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:524
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:526
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:520
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:523
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:561
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:535
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:518
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:531
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:536
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:537
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:532
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:514
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:519
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:538
Definition: warp_gemm_attribute_mfma_impl.hpp:67
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:68
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:87
ext_vector_t< ADataType, 1 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:74
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:80
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:78
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:97
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:90
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:93
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:85
float ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:70
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:76
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:83
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:82
float BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:71
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:92
ext_vector_t< BDataType, 1 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:75
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:88
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:72
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:116
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:91
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:79
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:86
Definition: warp_gemm_attribute_mfma_impl.hpp:131
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:156
float BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:135
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:142
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:157
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:155
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:132
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:161
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:136
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:149
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:140
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:154
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:150
ext_vector_t< ADataType, 1 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:138
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:180
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:152
ext_vector_t< BDataType, 1 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:139
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:143
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:146
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:151
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:147
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:144
float ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:134
Definition: warp_gemm_attribute_mfma_impl.hpp:1629
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1648
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1641
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const int32_t &a_scale, const BVecType &b_vec, const int32_t &b_scale) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1691
ext_vector_t< BDataType, 16 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1636
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1633
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1637
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1640
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1649
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1639
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1652
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1651
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1647
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const int32_t &a_scale, const BVecType &b_vec, const int32_t &b_scale, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1658
ext_vector_t< ADataType, 16 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1635
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1653
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1644
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1646
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1630
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1643
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1654
Definition: integral_constant.hpp:13
Definition: pk_fp4.hpp:76
Definition: functional.hpp:43
Definition: debug.hpp:67
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_)
Definition: warp_gemm_attribute_mfma_impl.hpp:25
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_)
Definition: warp_gemm_attribute_mfma_impl.hpp:42