/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File
warp_gemm_attribute_mfma_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 // TODO: refactor warp-gemm
11 // currently there is a discrepancy for vav/vva when C/D needs to be transposed,
12 // e.g. if we want A:agpr, B:vgpr, we have to use vva in WGAttrCtlEnum,
13 // because the _impl code swaps the A/B pointers (and that is not known here)
// Selects how the MFMA attribute structs in this file issue their instruction.
// Default_ means "no inline asm": the structs fall through to __builtin_amdgcn_mfma_*.
// Each Raw_xyz value makes DISPATCH_MFMA_CTRL_ emit the instruction as inline asm; the
// three letters give the register class of C, A and B in that order (v = VGPR, a = AGPR).
14 enum class WGAttrCtlEnum
15 {
16  Default_ = 0, // use the compiler builtin path
17  Raw_vvv = 1, // c-vgpr, a-vgpr, b-vgpr
18  Raw_vaa = 2, // c-vgpr, a-agpr, b-agpr
19  Raw_vav = 3, // c-vgpr, a-agpr, b-vgpr
20  Raw_vva = 4, // c-vgpr, a-vgpr, b-agpr
21  Raw_avv = 5, // c-agpr, a-vgpr, b-vgpr
22  // raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr (NOTE(review): unused; value 3 is already Raw_vav)
23 };
24 
// DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_)
// Emits `mfma_ %0, %1, %2, %3` as inline asm. %0 is c_vec with the read-write
// constraint dmod_ ("+v"/"+a"); %1, %2 and %3 are a_vec, b_vec and c_vec with the
// input constraints amod_, bmod_ and cmod_ ("v" = VGPR, "a" = AGPR).
// The expansion scope must provide `c_vec`, `a_vec`, `b_vec` and a constexpr bool
// `post_nop_`. When post_nop_ is true, an `s_nop 3` is emitted right after the MFMA;
// the "; yyy" text is presumably an assembler comment marker.
// Comments cannot be placed inside the macro body: a `//` before a line-continuation
// backslash would swallow the next line.
25 #define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
26  if constexpr(post_nop_) \
27  { \
28  asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n" \
29  "s_nop 3" \
30  : dmod_(c_vec) \
31  : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
32  :); \
33  } \
34  else \
35  { \
36  asm volatile(mfma_ " %0, %1, %2, %3\n" \
37  : dmod_(c_vec) \
38  : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
39  :); \
40  }
41 
// DISPATCH_MFMA_CTRL_(mfma_, ctrl_)
// Expands to an `if constexpr / else if constexpr` chain that maps each Raw_* value of
// WGAttrCtlEnum to a DISPATCH_MFMA_ call with the matching register-class constraints.
// There is deliberately no branch for Default_: the chain is left open, so the caller
// must write `else { <builtin path> }` right after the macro. If that `else` is omitted,
// the following block runs unconditionally, even after the inline asm has executed.
42 #define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \
43  if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \
44  { \
45  DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \
46  } \
47  else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \
48  { \
49  DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \
50  } \
51  else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \
52  { \
53  DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \
54  } \
55  else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \
56  { \
57  DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \
58  } \
59  else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \
60  { \
61  DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \
62  }
63 
64 // V_MFMA_F32_16x16x32_BF16
// Warp-level MFMA attribute: C(kM x kN = 16x16, float) += A(16 x kK = 32, bf16) * B(32 x 16, bf16).
// Ctrl = Raw_* issues v_mfma_f32_16x16x32_bf16 via inline asm (DISPATCH_MFMA_CTRL_).
// Ctrl = Default_ uses the builtin, which is compiled only for __gfx950__; on other
// targets both operators are stubs (accumulate leaves c_vec unchanged, product returns zeros).
// NOTE(review): the Raw_* inline-asm path is not behind the __gfx950__ guard.
// NOTE(review): this extract omits the struct name, the *VecType aliases and the first
// line of the accumulate operator() signature (gaps in the gutter line numbers).
65 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
67 {
68  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
69  using ADataType = bf16_t;
70  using BDataType = bf16_t;
71  using CDataType = float;
72 
76 
77  static constexpr index_t kM = 16;
78  static constexpr index_t kN = 16;
79  static constexpr index_t kK = 32;
80 
81  static constexpr index_t kAMBlock = 1;
82  static constexpr index_t kBNBlock = 1;
83 
84  static constexpr index_t kAMLane = 16;
85  static constexpr index_t kBNLane = 16;
86  static constexpr index_t kABKLane = 4;
87  static constexpr index_t kABKPerLane = 8; // kABKLane * kABKPerLane == kK
88 
89  static constexpr index_t kCMLane = 4;
90  static constexpr index_t kCNLane = 16;
91  static constexpr index_t kCM0PerLane = 1;
92  static constexpr index_t kCM1PerLane = 4; // kCMLane * kCM0PerLane * kCM1PerLane == kM
93 
94  // c_vec += a_vec * b_vec
95  template <bool post_nop_ = false>
97  const AVecType& a_vec,
98  const BVecType& b_vec,
99  bool_constant<post_nop_> = {}) const
100  {
101  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32_bf16", Ctrl)
102  else
103  {
104 #if defined(__gfx950__)
105  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, c_vec, 0, 0, 0);
106 #else
107  ck_tile::ignore = c_vec;
108  ck_tile::ignore = a_vec;
109  ck_tile::ignore = b_vec;
110 #endif
111  }
112  }
113 
114  // c_vec = a_vec * b_vec
115  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
116  {
117 #if defined(__gfx950__)
118  return bit_cast<CVecType>(
119  __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
120 #else
121  ck_tile::ignore = a_vec;
122  ck_tile::ignore = b_vec;
123  return CVecType{0.f};
124 #endif
125  }
126 };
127 // FP16
// Warp-level MFMA attribute: C(32x32, float) += A(32x8, fp16) * B(8x32, fp16) using one
// v_mfma_f32_32x32x8f16. Ctrl = Raw_* issues it via inline asm. Ctrl = Default_ uses the
// builtin, which is compiled under __gfx9__; on other targets both operators are stubs.
// NOTE(review): this extract omits the struct name, the *VecType aliases and the first
// line of the accumulate operator() signature.
128 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
130 {
131  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
132  using ADataType = fp16_t;
133  using BDataType = fp16_t;
134  using CDataType = float;
135 
139 
140  static constexpr index_t kM = 32;
141  static constexpr index_t kN = 32;
142  static constexpr index_t kK = 8;
143 
144  static constexpr index_t kAMBlock = 1;
145  static constexpr index_t kBNBlock = 1;
146 
147  static constexpr index_t kAMLane = 32;
148  static constexpr index_t kBNLane = 32;
149  static constexpr index_t kABKLane = 2;
150  static constexpr index_t kABKPerLane = 4; // kABKLane * kABKPerLane == kK
151 
152  static constexpr index_t kCMLane = 2;
153  static constexpr index_t kCNLane = 32;
154  static constexpr index_t kCM0PerLane = 4;
155  static constexpr index_t kCM1PerLane = 4; // kCMLane * kCM0PerLane * kCM1PerLane == kM
156 
157  // c_vec += a_vec * b_vec
158  template <bool post_nop_ = false>
160  const AVecType& a_vec,
161  const BVecType& b_vec,
162  bool_constant<post_nop_> = {}) const
163  {
164  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl)
165  else
166  {
167 #if defined(__gfx9__)
168  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
169 #else
170  ck_tile::ignore = c_vec;
171  ck_tile::ignore = a_vec;
172  ck_tile::ignore = b_vec;
173 #endif
174  }
175  }
176 
177  // c_vec = a_vec * b_vec
178  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
179  {
180 #if defined(__gfx9__)
181  return bit_cast<CVecType>(
182  __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
183 #else
184  ck_tile::ignore = a_vec;
185  ck_tile::ignore = b_vec;
186  return CVecType{0.f};
187 #endif
188  }
189 };
190 
// Warp-level MFMA attribute: C(16x16, float) += A(16x16, fp16) * B(16x16, fp16) using one
// v_mfma_f32_16x16x16f16. Ctrl = Raw_* issues it via inline asm. Ctrl = Default_ uses the
// builtin, which is compiled under __gfx9__; on other targets both operators are stubs.
// NOTE(review): this extract omits the struct name, the *VecType aliases and the first
// line of the accumulate operator() signature.
191 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
193 {
194  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
195  using ADataType = fp16_t;
196  using BDataType = fp16_t;
197  using CDataType = float;
198 
202 
203  static constexpr index_t kM = 16;
204  static constexpr index_t kN = 16;
205  static constexpr index_t kK = 16;
206 
207  static constexpr index_t kAMBlock = 1;
208  static constexpr index_t kBNBlock = 1;
209 
210  static constexpr index_t kAMLane = 16;
211  static constexpr index_t kBNLane = 16;
212  static constexpr index_t kABKLane = 4;
213  static constexpr index_t kABKPerLane = 4; // kABKLane * kABKPerLane == kK
214 
215  static constexpr index_t kCMLane = 4;
216  static constexpr index_t kCNLane = 16;
217  static constexpr index_t kCM0PerLane = 1;
218  static constexpr index_t kCM1PerLane = 4; // kCMLane * kCM0PerLane * kCM1PerLane == kM
219 
220  // c_vec += a_vec * b_vec
221  template <bool post_nop_ = false>
223  const AVecType& a_vec,
224  const BVecType& b_vec,
225  bool_constant<post_nop_> = {}) const
226  {
227  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
228  else
229  {
230 #if defined(__gfx9__)
231  c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
232 #else
233  ck_tile::ignore = c_vec;
234  ck_tile::ignore = a_vec;
235  ck_tile::ignore = b_vec;
236 #endif
237  }
238  }
239 
240  // c_vec = a_vec * b_vec
241  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
242  {
243 #if defined(__gfx9__)
244  return bit_cast<CVecType>(
245  __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
246 #else
247  ck_tile::ignore = a_vec;
248  ck_tile::ignore = b_vec;
249  return CVecType{0.f};
250 #endif
251  }
252 };
253 
// Warp-level MFMA attribute: C(16x16, float) += A(16x32, fp16) * B(32x16, fp16) using one
// v_mfma_f32_16x16x32_f16, which exists only on gfx950. The builtin path is compiled for
// __gfx950__; on other targets both operators are stubs (accumulate leaves c_vec
// unchanged, product returns zeros). Ctrl = Raw_* issues the instruction via inline asm.
// The asm mnemonic uses the gfx950 spelling with an underscore before the type suffix
// (v_mfma_f32_16x16x32_f16), matching __builtin_amdgcn_mfma_f32_16x16x32_f16 and the
// other gfx950 MFMAs in this file; "16x16x32f16" is not an instruction name.
// NOTE(review): this extract omits the struct name, the *VecType aliases and the first
// line of the accumulate operator() signature.
254 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
256 {
257  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
258  using ADataType = fp16_t;
259  using BDataType = fp16_t;
260  using CDataType = float;
261 
265 
266  static constexpr index_t kM = 16;
267  static constexpr index_t kN = 16;
268  static constexpr index_t kK = 32;
269 
270  static constexpr index_t kAMBlock = 1;
271  static constexpr index_t kBNBlock = 1;
272 
273  static constexpr index_t kAMLane = 16;
274  static constexpr index_t kBNLane = 16;
275  static constexpr index_t kABKLane = 4;
276  static constexpr index_t kABKPerLane = 8; // kABKLane * kABKPerLane == kK
277 
278  static constexpr index_t kCMLane = 4;
279  static constexpr index_t kCNLane = 16;
280  static constexpr index_t kCM0PerLane = 1;
281  static constexpr index_t kCM1PerLane = 4; // kCMLane * kCM0PerLane * kCM1PerLane == kM
282 
283  // c_vec += a_vec * b_vec
284  template <bool post_nop_ = false>
286  const AVecType& a_vec,
287  const BVecType& b_vec,
288  bool_constant<post_nop_> = {}) const
289  {
290  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32_f16", Ctrl)
291  else
292  {
293 #if defined(__gfx950__)
294  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, c_vec, 0, 0, 0);
295 #else
296  ck_tile::ignore = c_vec;
297  ck_tile::ignore = a_vec;
298  ck_tile::ignore = b_vec;
299 #endif
300  }
301  }
302 
303  // c_vec = a_vec * b_vec
304  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
305  {
306 #if defined(__gfx950__)
307  return bit_cast<CVecType>(
308  __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
309 #else
310  ck_tile::ignore = a_vec;
311  ck_tile::ignore = b_vec;
312  return CVecType{0.f};
313 #endif
314  }
315 };
316 
// Warp-level MFMA attribute using v_mfma_f32_4x4x4f16: fp16 A/B, float C, kM=4, kN=64, kK=4.
// The N extent is split into kBNBlock = 16 blocks; the lane constants below describe a
// single 4-lane block only (see the in-body comment). The builtin path is compiled under
// __gfx9__; on other targets both operators are stubs. Ctrl = Raw_* uses inline asm.
// NOTE(review): this extract omits the struct name, the *VecType aliases and the first
// line of the accumulate operator() signature.
317 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
319 {
320  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
321  using ADataType = fp16_t;
322  using BDataType = fp16_t;
323  using CDataType = float;
324 
328 
329  static constexpr index_t kM = 4;
330  static constexpr index_t kN = 64;
331  static constexpr index_t kK = 4;
332 
333  static constexpr index_t kAMBlock = 1;
334  static constexpr index_t kBNBlock = 16;
335 
336  // we only write down single block (4 threads) thread mapping here
337  static constexpr index_t kAMLane = 4;
338  static constexpr index_t kBNLane = 4;
339  static constexpr index_t kABKLane = 1;
340  static constexpr index_t kABKPerLane = 4;
341 
342  static constexpr index_t kCMLane = 1;
343  static constexpr index_t kCNLane = 4;
344  static constexpr index_t kCM0PerLane = 1;
345  static constexpr index_t kCM1PerLane = 4;
346 
347  // c_vec += a_vec * b_vec
348  template <bool post_nop_ = false>
350  const AVecType& a_vec,
351  const BVecType& b_vec,
352  bool_constant<post_nop_> = {}) const
353  {
354  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
355  else
356  {
357 #if defined(__gfx9__)
358  c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
359 #else
360  ignore = c_vec;
361  ignore = a_vec;
362  ignore = b_vec;
363 #endif
364  }
365  }
366 
367  // c_vec = a_vec * b_vec
368  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
369  {
370 #if defined(__gfx9__)
371  return bit_cast<CVecType>(
372  __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
373 #else
374  ignore = a_vec;
375  ignore = b_vec;
376  return CVecType{0.f};
377 #endif
378  }
379 };
380 
// Warp-level MFMA attribute using v_mfma_f32_4x4x4f16: fp16 A/B, float C, kM=64, kN=4, kK=4.
// Mirror of the 4x64x4 variant: here the M extent is split into kAMBlock = 16 blocks, and
// the lane constants describe a single 4-lane block only. The builtin path is compiled
// under __gfx9__; on other targets both operators are stubs. Ctrl = Raw_* uses inline asm.
// NOTE(review): this extract omits the struct name, the *VecType aliases and the first
// line of the accumulate operator() signature.
381 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
383 {
384  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
385  using ADataType = fp16_t;
386  using BDataType = fp16_t;
387  using CDataType = float;
388 
392 
393  static constexpr index_t kM = 64;
394  static constexpr index_t kN = 4;
395  static constexpr index_t kK = 4;
396 
397  static constexpr index_t kAMBlock = 16;
398  static constexpr index_t kBNBlock = 1;
399 
400  // we only write down single block (4 threads) thread mapping here
401  static constexpr index_t kAMLane = 4;
402  static constexpr index_t kBNLane = 4;
403  static constexpr index_t kABKLane = 1;
404  static constexpr index_t kABKPerLane = 4;
405 
406  static constexpr index_t kCMLane = 1;
407  static constexpr index_t kCNLane = 4;
408  static constexpr index_t kCM0PerLane = 1;
409  static constexpr index_t kCM1PerLane = 4;
410 
411  // c_vec += a_vec * b_vec
412  template <bool post_nop_ = false>
414  const AVecType& a_vec,
415  const BVecType& b_vec,
416  bool_constant<post_nop_> = {}) const
417  {
418  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
419  else
420  {
421 #if defined(__gfx9__)
422  c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
423 #else
424  ignore = c_vec;
425  ignore = a_vec;
426  ignore = b_vec;
427 #endif
428  }
429  }
430 
431  // c_vec = a_vec * b_vec
432  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
433  {
434 #if defined(__gfx9__)
435  return bit_cast<CVecType>(
436  __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
437 #else
438  ignore = a_vec;
439  ignore = b_vec;
440  return CVecType{0.f};
441 #endif
442  }
443 };
444 
445 // Bf16
// Warp-level MFMA attribute: C(32x32, float) += A(32x8, bf16) * B(8x32, bf16).
// - gfx90a / gfx94*: one v_mfma_f32_32x32x8bf16_1k.
// - gfx908: two v_mfma_f32_32x32x4bf16 steps, each consuming one pair of the 4 bf16 values
//   a lane holds for A and B (kABKPerLane = 4).
// - other targets: stubs (accumulate leaves c_vec unchanged, product returns zeros).
// Ctrl = Raw_* issues the _1k instruction via inline asm regardless of target.
// NOTE(review): this extract omits the struct name, the *VecType aliases and the first
// line of the accumulate operator() signature.
446 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
448 {
449  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
450  using ADataType = bf16_t;
451  using BDataType = bf16_t;
452  using CDataType = float;
453 
457 
458  static constexpr index_t kM = 32;
459  static constexpr index_t kN = 32;
460  static constexpr index_t kK = 8;
461 
462  static constexpr index_t kAMBlock = 1;
463  static constexpr index_t kBNBlock = 1;
464 
465  static constexpr index_t kAMLane = 32;
466  static constexpr index_t kBNLane = 32;
467  static constexpr index_t kABKLane = 2;
468  static constexpr index_t kABKPerLane = 4;
469 
470  static constexpr index_t kCMLane = 2;
471  static constexpr index_t kCNLane = 32;
472  static constexpr index_t kCM0PerLane = 4;
473  static constexpr index_t kCM1PerLane = 4;
474 
475  // c_vec += a_vec * b_vec
476  template <bool post_nop_ = false>
478  const AVecType& a_vec,
479  const BVecType& b_vec,
480  bool_constant<post_nop_> = {}) const
481  {
482  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl)
483  else
484  {
485 #if defined(__gfx90a__) || defined(__gfx94__)
486  c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
487 #elif defined(__gfx908__)
     // Two K=4 steps; step k uses the k-th bf16 pair of the lane's 4-element A/B fragment.
488  static_for<0, 2, 1>{}([&](auto k) {
489  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
490  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
491  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
492  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
493  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
494  c_vec,
495  0,
496  0,
497  0);
498  });
499 #else
500  ck_tile::ignore = c_vec;
501  ck_tile::ignore = a_vec;
502  ck_tile::ignore = b_vec;
503 #endif
504  }
505  }
506 
507  // c_vec = a_vec * b_vec
508  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
509  {
510 #if defined(__gfx90a__) || defined(__gfx94__)
511  return bit_cast<CVecType>(
512  __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
513 #elif defined(__gfx908__)
514  CVecType c_vec{0.f};
515  static_for<0, 2, 1>{}([&](auto k) {
516  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
517  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
518  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
519  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
520  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
521  c_vec,
522  0,
523  0,
524  0);
525  });
526  return c_vec;
527 #else
528  ck_tile::ignore = a_vec;
529  ck_tile::ignore = b_vec;
530  return CVecType{0.f};
531 #endif
532  }
533 };
534 
535 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
537 {
538  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
539  using ADataType = bf16_t;
540  using BDataType = bf16_t;
541  using CDataType = float;
542 
546 
547  static constexpr index_t kM = 16;
548  static constexpr index_t kN = 16;
549  static constexpr index_t kK = 16;
550 
551  static constexpr index_t kAMBlock = 1;
552  static constexpr index_t kBNBlock = 1;
553 
554  static constexpr index_t kAMLane = 16;
555  static constexpr index_t kBNLane = 16;
556  static constexpr index_t kABKLane = 4;
557  static constexpr index_t kABKPerLane = 4;
558 
559  static constexpr index_t kCMLane = 4;
560  static constexpr index_t kCNLane = 16;
561  static constexpr index_t kCM0PerLane = 1;
562  static constexpr index_t kCM1PerLane = 4;
563 
564  // c_vec += a_vec * b_vec
565  template <bool post_nop_ = false>
567  const AVecType& a_vec,
568  const BVecType& b_vec,
569  bool_constant<post_nop_> = {}) const
570  {
571  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
572  {
573 #if defined(__gfx90a__) || defined(__gfx94__)
574  c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
575 #elif defined(__gfx908__)
576  static_for<0, 2, 1>{}([&](auto k) {
577  c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
578  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
579  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
580  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
581  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
582  c_vec,
583  0,
584  0,
585  0);
586  });
587 #else
588  ck_tile::ignore = c_vec;
589  ck_tile::ignore = a_vec;
590  ck_tile::ignore = b_vec;
591 #endif
592  }
593  }
594 
595  // c_vec = a_vec * b_vec
596  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
597  {
598 #if defined(__gfx90a__) || defined(__gfx94__)
599  return bit_cast<CVecType>(
600  __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
601 #elif defined(__gfx908__)
602  CVecType c_vec{0.f};
603  static_for<0, 2, 1>{}([&](auto k) {
604  c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
605  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
606  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
607  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
608  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
609  c_vec,
610  0,
611  0,
612  0);
613  });
614  return c_vec;
615 #else
616  ck_tile::ignore = a_vec;
617  ck_tile::ignore = b_vec;
618  return CVecType{0.f};
619 #endif
620  }
621 };
622 
// Warp-level MFMA attribute: bf16 A/B, float C, kM=4, kN=64, kK=4, with N split into
// kBNBlock = 16 blocks; the lane constants describe a single 4-lane block only.
// - gfx90a / gfx94*: one v_mfma_f32_4x4x4bf16_1k.
// - gfx908: two v_mfma_f32_4x4x2bf16 steps over the lane's two bf16 pairs.
// - other targets: stubs. Ctrl = Raw_* issues the _1k instruction via inline asm.
// NOTE(review): this extract omits the struct name, the *VecType aliases and the first
// line of the accumulate operator() signature.
623 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
625 {
626  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
627  using ADataType = bf16_t;
628  using BDataType = bf16_t;
629  using CDataType = float;
630 
634 
635  static constexpr index_t kM = 4;
636  static constexpr index_t kN = 64;
637  static constexpr index_t kK = 4;
638 
639  static constexpr index_t kAMBlock = 1;
640  static constexpr index_t kBNBlock = 16;
641 
642  // we only write down single block (4 threads) thread mapping here
643  static constexpr index_t kAMLane = 4;
644  static constexpr index_t kBNLane = 4;
645  static constexpr index_t kABKLane = 1;
646  static constexpr index_t kABKPerLane = 4;
647 
648  static constexpr index_t kCMLane = 1;
649  static constexpr index_t kCNLane = 4;
650  static constexpr index_t kCM0PerLane = 1;
651  static constexpr index_t kCM1PerLane = 4;
652 
653  // c_vec += a_vec * b_vec
654  template <bool post_nop_ = false>
656  const AVecType& a_vec,
657  const BVecType& b_vec,
658  bool_constant<post_nop_> = {}) const
659  {
660  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
661  else
662  {
663 #if defined(__gfx90a__) || defined(__gfx94__)
664  c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
665 #elif defined(__gfx908__)
666  static_for<0, 2, 1>{}([&](auto k) {
667  c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
668  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
669  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
670  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
671  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
672  c_vec,
673  0,
674  0,
675  0);
676  });
677 #else
678  ignore = c_vec;
679  ignore = a_vec;
680  ignore = b_vec;
681 #endif
682  }
683  }
684 
685  // c_vec = a_vec * b_vec
686  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
687  {
688 #if defined(__gfx90a__) || defined(__gfx94__)
689  return bit_cast<CVecType>(
690  __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
691 #elif defined(__gfx908__)
692  CVecType c_vec{0.f};
693  static_for<0, 2, 1>{}([&](auto k) {
694  c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
695  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
696  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
697  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
698  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
699  c_vec,
700  0,
701  0,
702  0);
703  });
704  return c_vec;
705 #else
706  ignore = a_vec;
707  ignore = b_vec;
708  return CVecType{0.f};
709 #endif
710  }
711 };
712 
// Warp-level MFMA attribute: bf16 A/B, float C, kM=64, kN=4, kK=4, with M split into
// kAMBlock = 16 blocks; the lane constants describe a single 4-lane block only.
// - gfx90a / gfx94*: one v_mfma_f32_4x4x4bf16_1k.
// - gfx908: two v_mfma_f32_4x4x2bf16 steps over the lane's two bf16 pairs.
// - other targets: stubs. Ctrl = Raw_* issues the _1k instruction via inline asm.
// NOTE(review): this extract omits the struct name, the *VecType aliases and the first
// line of the accumulate operator() signature.
713 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
715 {
716  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
717  using ADataType = bf16_t;
718  using BDataType = bf16_t;
719  using CDataType = float;
720 
724 
725  static constexpr index_t kM = 64;
726  static constexpr index_t kN = 4;
727  static constexpr index_t kK = 4;
728 
729  static constexpr index_t kAMBlock = 16;
730  static constexpr index_t kBNBlock = 1;
731 
732  // we only write down single block (4 threads) thread mapping here
733  static constexpr index_t kAMLane = 4;
734  static constexpr index_t kBNLane = 4;
735  static constexpr index_t kABKLane = 1;
736  static constexpr index_t kABKPerLane = 4;
737 
738  static constexpr index_t kCMLane = 1;
739  static constexpr index_t kCNLane = 4;
740  static constexpr index_t kCM0PerLane = 1;
741  static constexpr index_t kCM1PerLane = 4;
742 
743  // c_vec += a_vec * b_vec
744  template <bool post_nop_ = false>
746  const AVecType& a_vec,
747  const BVecType& b_vec,
748  bool_constant<post_nop_> = {}) const
749  {
750  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
751  else
752  {
753 #if defined(__gfx90a__) || defined(__gfx94__)
754  c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
755 #elif defined(__gfx908__)
756  static_for<0, 2, 1>{}([&](auto k) {
757  c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
758  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
759  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
760  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
761  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
762  c_vec,
763  0,
764  0,
765  0);
766  });
767 #else
768  ignore = c_vec;
769  ignore = a_vec;
770  ignore = b_vec;
771 #endif
772  }
773  }
774 
775  // c_vec = a_vec * b_vec
776  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
777  {
778 #if defined(__gfx90a__) || defined(__gfx94__)
779  return bit_cast<CVecType>(
780  __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
781 #elif defined(__gfx908__)
782  CVecType c_vec{0.f};
783  static_for<0, 2, 1>{}([&](auto k) {
784  c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
785  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
786  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
787  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
788  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
789  c_vec,
790  0,
791  0,
792  0);
793  });
794  return c_vec;
795 #else
796  ignore = a_vec;
797  ignore = b_vec;
798  return CVecType{0.f};
799 #endif
800  }
801 };
802 
803 // gfx950
804 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
806 {
807  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
808  using ADataType = fp16_t;
809  using BDataType = fp16_t;
810  using CDataType = float;
811 
815 
816  static constexpr index_t kM = 32;
817  static constexpr index_t kN = 32;
818  static constexpr index_t kK = 16;
819 
820  static constexpr index_t kAMBlock = 1;
821  static constexpr index_t kBNBlock = 1;
822 
823  static constexpr index_t kAMLane = 32;
824  static constexpr index_t kBNLane = 32;
825  static constexpr index_t kABKLane = 2;
826  static constexpr index_t kABKPerLane = 8;
827 
828  static constexpr index_t kCMLane = 2;
829  static constexpr index_t kCNLane = 32;
830  static constexpr index_t kCM0PerLane = 4;
831  static constexpr index_t kCM1PerLane = 4;
832 
833  // c_vec += a_vec * b_vec
834  template <bool post_nop_ = false>
836  const AVecType& a_vec,
837  const BVecType& b_vec,
838  bool_constant<post_nop_> = {}) const
839  {
840  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_f16", Ctrl)
841  else
842  {
843 #if defined(__gfx950__)
844  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, c_vec, 0, 0, 0);
845 #elif defined(__gfx90a__) || defined(__gfx94__)
846  static_for<0, 2, 1>{}([&](auto k) {
847  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(
848  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
849  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
850  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
851  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
852  c_vec,
853  0,
854  0,
855  0);
856  });
857 #elif defined(__gfx908__)
858  static_for<0, 4, 1>{}([&](auto k) {
859  c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16(
860  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
861  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
862  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
863  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
864  c_vec,
865  0,
866  0,
867  0);
868  });
869 #else
870  ck_tile::ignore = c_vec;
871  ck_tile::ignore = a_vec;
872  ck_tile::ignore = b_vec;
873 #endif
874  }
875  }
876 
877  // c_vec = a_vec * b_vec
878  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
879  {
880 #if defined(__gfx950__)
881  return __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0);
882 #elif defined(__gfx90a__) || defined(__gfx94__)
883  CVecType c_vec{0.f};
884  static_for<0, 2, 1>{}([&](auto k) {
885  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(
886  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
887  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
888  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
889  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
890  c_vec,
891  0,
892  0,
893  0);
894  });
895  return c_vec;
896 #elif defined(__gfx908__)
897  CVecType c_vec{0.f};
898  static_for<0, 4, 1>{}([&](auto k) {
899  c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16(
900  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
901  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
902  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
903  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
904  c_vec,
905  0,
906  0,
907  0);
908  });
909  return c_vec;
910 #else
911  ck_tile::ignore = a_vec;
912  ck_tile::ignore = b_vec;
913  return CVecType{0.f};
914 #endif
915  }
916 };
917 
918 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
920 {
921  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
922  using ADataType = bf16_t;
923  using BDataType = bf16_t;
924  using CDataType = float;
925 
929 
930  static constexpr index_t kM = 32;
931  static constexpr index_t kN = 32;
932  static constexpr index_t kK = 16;
933 
934  static constexpr index_t kAMBlock = 1;
935  static constexpr index_t kBNBlock = 1;
936 
937  static constexpr index_t kAMLane = 32;
938  static constexpr index_t kBNLane = 32;
939  static constexpr index_t kABKLane = 2;
940  static constexpr index_t kABKPerLane = 8;
941 
942  static constexpr index_t kCMLane = 2;
943  static constexpr index_t kCNLane = 32;
944  static constexpr index_t kCM0PerLane = 4;
945  static constexpr index_t kCM1PerLane = 4;
946 
947  // c_vec += a_vec * b_vec
948  template <bool post_nop_ = false>
950  const AVecType& a_vec,
951  const BVecType& b_vec,
952  bool_constant<post_nop_> = {}) const
953  {
954  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_bf16", Ctrl)
955  else
956  {
957 #if defined(__gfx950__)
958  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, c_vec, 0, 0, 0);
959 #elif defined(__gfx90a__) || defined(__gfx94__)
960  static_for<0, 2, 1>{}([&](auto k) {
961  c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
962  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
963  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
964  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
965  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
966  c_vec,
967  0,
968  0,
969  0);
970  });
971 #elif defined(__gfx908__)
972  static_for<0, 4, 1>{}([&](auto k) {
973  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
974  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
975  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
976  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
977  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
978  c_vec,
979  0,
980  0,
981  0);
982  });
983 #else
984  ck_tile::ignore = c_vec;
985  ck_tile::ignore = a_vec;
986  ck_tile::ignore = b_vec;
987 #endif
988  }
989  }
990 
991  // c_vec = a_vec * b_vec
992  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
993  {
994 #if defined(__gfx950__)
995  return __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0);
996 #elif defined(__gfx90a__) || defined(__gfx94__)
997  CVecType c_vec{0.f};
998  static_for<0, 2, 1>{}([&](auto k) {
999  c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
1000  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1001  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1002  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1003  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1004  c_vec,
1005  0,
1006  0,
1007  0);
1008  });
1009  return c_vec;
1010 #elif defined(__gfx908__)
1011  CVecType c_vec{0.f};
1012  static_for<0, 4, 1>{}([&](auto k) {
1013  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
1014  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
1015  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1016  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
1017  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1018  c_vec,
1019  0,
1020  0,
1021  0);
1022  });
1023  return c_vec;
1024 #else
1025  ck_tile::ignore = a_vec;
1026  ck_tile::ignore = b_vec;
1027  return CVecType{0.f};
1028 #endif
1029  }
1030 };
1031 
1032 // FP8
1033 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1035 {
1036  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1037  using ADataType = AType_;
1038  using BDataType = BType_;
1039  using CDataType = float;
1040 
1044 
1045  static constexpr index_t kM = 16;
1046  static constexpr index_t kN = 16;
1047  static constexpr index_t kK = 32;
1048 
1049  static constexpr index_t kAMBlock = 1;
1050  static constexpr index_t kBNBlock = 1;
1051 
1052  static constexpr index_t kAMLane = 16;
1053  static constexpr index_t kBNLane = 16;
1054  static constexpr index_t kABKLane = 4;
1055  static constexpr index_t kABKPerLane = 8;
1056 
1057  static constexpr index_t kCMLane = 4;
1058  static constexpr index_t kCNLane = 16;
1059  static constexpr index_t kCM0PerLane = 1;
1060  static constexpr index_t kCM1PerLane = 4;
1061 
1062  // c_vec += a_vec * b_vec
1063  template <bool post_nop_ = false>
1065  const AVecType& a_vec,
1066  const BVecType& b_vec,
1067  bool_constant<post_nop_> = {}) const
1068  {
1069  if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
1070  {
1071  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1072  {
1073  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "v", "v")
1074  }
1075  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1076  {
1077  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "v", "v")
1078  }
1079  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1080  {
1081  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "v", "v")
1082  }
1083  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1084  {
1085  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "v", "v")
1086  }
1087  }
1088  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
1089  {
1090  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1091  {
1092  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "a", "v")
1093  }
1094  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1095  {
1096  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "a", "v")
1097  }
1098  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1099  {
1100  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "a", "v")
1101  }
1102  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1103  {
1104  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "a", "v")
1105  }
1106  }
1107  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
1108  {
1109  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1110  {
1111  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "v", "v")
1112  }
1113  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1114  {
1115  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "v", "v")
1116  }
1117  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1118  {
1119  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "v", "v")
1120  }
1121  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1122  {
1123  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "v", "v")
1124  }
1125  }
1126  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
1127  {
1128  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1129  {
1130  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "a", "v")
1131  }
1132  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1133  {
1134  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "a", "v")
1135  }
1136  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1137  {
1138  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "a", "v")
1139  }
1140  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1141  {
1142  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "a", "v")
1143  }
1144  }
1145  else
1146  {
1147 #if defined(__gfx94__) or defined(__gfx95__)
1148  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1149  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1150  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1151  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1152  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1153  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1154  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1155  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1156  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1157  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1158  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1159  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1160 #else
1161  ck_tile::ignore = c_vec;
1162  ck_tile::ignore = a_vec;
1163  ck_tile::ignore = b_vec;
1164 #endif
1165  }
1166  }
1167 
1168  // c_vec = a_vec * b_vec
1169  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1170  {
1171 #if defined(__gfx94__) or defined(__gfx95__)
1172  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1173  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1174  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1175  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1176  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1177  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1178  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1179  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1180  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1181  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1182  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1183  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1184 #else
1185  ck_tile::ignore = a_vec;
1186  ck_tile::ignore = b_vec;
1187  return CVecType{0.f};
1188 #endif
1189  }
1190 };
1191 
1192 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1194 {
1195  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1196  using ADataType = AType_;
1197  using BDataType = BType_;
1198  using CDataType = float;
1199 
1203 
1204  static constexpr index_t kM = 32;
1205  static constexpr index_t kN = 32;
1206  static constexpr index_t kK = 16;
1207 
1208  static constexpr index_t kAMBlock = 1;
1209  static constexpr index_t kBNBlock = 1;
1210 
1211  static constexpr index_t kAMLane = 32;
1212  static constexpr index_t kBNLane = 32;
1213  static constexpr index_t kABKLane = 2;
1214  static constexpr index_t kABKPerLane = 8;
1215 
1216  static constexpr index_t kCMLane = 2;
1217  static constexpr index_t kCNLane = 32;
1218  static constexpr index_t kCM0PerLane = 4;
1219  static constexpr index_t kCM1PerLane = 4;
1220 
1221  // c_vec += a_vec * b_vec
1222  template <bool post_nop_ = false>
1224  const AVecType& a_vec,
1225  const BVecType& b_vec,
1226  bool_constant<post_nop_> = {}) const
1227  {
1228  if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
1229  {
1230  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1231  {
1232  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "v", "v")
1233  }
1234  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1235  {
1236  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "v", "v")
1237  }
1238  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1239  {
1240  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "v", "v")
1241  }
1242  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1243  {
1244  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "v", "v")
1245  }
1246  }
1247  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
1248  {
1249  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1250  {
1251  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "a", "v")
1252  }
1253  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1254  {
1255  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "a", "v")
1256  }
1257  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1258  {
1259  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "a", "v")
1260  }
1261  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1262  {
1263  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "a", "v")
1264  }
1265  }
1266  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
1267  {
1268  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1269  {
1270  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "v", "v")
1271  }
1272  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1273  {
1274  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "v", "v")
1275  }
1276  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1277  {
1278  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "v", "v")
1279  }
1280  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1281  {
1282  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "v", "v")
1283  }
1284  }
1285  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
1286  {
1287  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1288  {
1289  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "a", "v")
1290  }
1291  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1292  {
1293  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "a", "v")
1294  }
1295  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1296  {
1297  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "a", "v")
1298  }
1299  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1300  {
1301  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "a", "v")
1302  }
1303  }
1304  else
1305  {
1306 #if defined(__gfx94__) or defined(__gfx95__)
1307  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1308  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1309  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1310  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1311  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1312  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1313  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1314  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1315  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1316  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1317  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1318  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1319 #elif defined(__gfx908__) || defined(__gfx90a__)
1320  static_for<0, 8, 1>{}([&](auto k) {
1321  float a_f32 =
1322  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1323  .template get_as<ADataType>()[number<k>{}]);
1324  float b_f32 =
1325  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1326  .template get_as<BDataType>()[number<k>{}]);
1327 
1328  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1329  });
1330 #else
1331  ck_tile::ignore = c_vec;
1332  ck_tile::ignore = a_vec;
1333  ck_tile::ignore = b_vec;
1334 #endif
1335  }
1336  }
1337 
1338  // c_vec = a_vec * b_vec
1339  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1340  {
1341 #if defined(__gfx94__) or defined(__gfx95__)
1342  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1343  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1344  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1345  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1346  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1347  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1348  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1349  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1350  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1351  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1352  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1353  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1354 #elif defined(__gfx908__) || defined(__gfx90a__)
1355  CVecType c_vec{0.f};
1356  static_for<0, 8, 1>{}([&](auto k) {
1357  float a_f32 =
1358  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1359  .template get_as<ADataType>()[number<k>{}]);
1360  float b_f32 =
1361  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1362  .template get_as<BDataType>()[number<k>{}]);
1363 
1364  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1365  });
1366  return c_vec;
1367 #else
1368  ck_tile::ignore = a_vec;
1369  ck_tile::ignore = b_vec;
1370  return CVecType{0.f};
1371 #endif
1372  }
1373 };
1374 
1375 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1378 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1381 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1384 
1385 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1388 
1389 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1392 
1393 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1396 
1397 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1399 {
1400  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1401  using ADataType = AType_;
1402  using BDataType = BType_;
1403  using CDataType = float;
1404 
1408 
1409  static constexpr index_t kM = 16;
1410  static constexpr index_t kN = 16;
1411  static constexpr index_t kK = 128;
1412 
1413  static constexpr index_t kAMBlock = 1;
1414  static constexpr index_t kBNBlock = 1;
1415 
1416  static constexpr index_t kAMLane = 16;
1417  static constexpr index_t kBNLane = 16;
1418  static constexpr index_t kABKLane = 4;
1419  static constexpr index_t kABKPerLane = 32;
1420 
1421  static constexpr index_t kCMLane = 4;
1422  static constexpr index_t kCNLane = 16;
1423  static constexpr index_t kCM0PerLane = 1;
1424  static constexpr index_t kCM1PerLane = 4;
1425 
1426  // c_vec += a_vec * b_vec
1427  template <bool post_nop_ = false>
1429  const AVecType& a_vec,
1430  const BVecType& b_vec,
1431  bool_constant<post_nop_> = {}) const
1432  {
1433  //__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
1434  // opsel, scale_b)
1435 #if defined(__gfx950__)
1436  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1437  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1438  a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
1439  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1440  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1441  a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
1442  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1443  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1444  a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
1445  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1446  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1447  a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
1448 #else
1449  ck_tile::ignore = c_vec;
1450  ck_tile::ignore = a_vec;
1451  ck_tile::ignore = b_vec;
1452 #endif
1453  }
1454 
1455  // c_vec = a_vec * b_vec
1456  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1457  {
1458 #if defined(__gfx950__)
1459  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1460  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1461  a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
1462  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1463  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1464  a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
1465  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1466  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1467  a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
1468  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1469  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1470  a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
1471 #else
1472  ck_tile::ignore = a_vec;
1473  ck_tile::ignore = b_vec;
1474  return CVecType{0.f};
1475 #endif
1476  }
1477 };
1478 
1479 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1482 
1483 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1486 
1487 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1490 
1491 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1494 
1495 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1497 {
1498  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1499  using ADataType = AType_;
1500  using BDataType = BType_;
1501  using CDataType = float;
1502 
1506 
1507  static constexpr index_t kM = 32;
1508  static constexpr index_t kN = 32;
1509  static constexpr index_t kK = 64;
1510 
1511  static constexpr index_t kAMBlock = 1;
1512  static constexpr index_t kBNBlock = 1;
1513 
1514  static constexpr index_t kAMLane = 32;
1515  static constexpr index_t kBNLane = 32;
1516  static constexpr index_t kABKLane = 2;
1517  static constexpr index_t kABKPerLane = 32;
1518 
1519  static constexpr index_t kCMLane = 2;
1520  static constexpr index_t kCNLane = 32;
1521  static constexpr index_t kCM0PerLane = 4;
1522  static constexpr index_t kCM1PerLane = 4;
1523 
1524  // c_vec += a_vec * b_vec
1525  template <bool post_nop_ = false>
1527  const AVecType& a_vec,
1528  const BVecType& b_vec,
1529  bool_constant<post_nop_> = {}) const
1530  {
1531  //__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
1532  // opsel, scale_b)
1533 #if defined(__gfx950__)
1534  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1535  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1536  a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
1537  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1538  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1539  a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
1540  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1541  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1542  a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
1543  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1544  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1545  a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
1546 #else
1547  ck_tile::ignore = c_vec;
1548  ck_tile::ignore = a_vec;
1549  ck_tile::ignore = b_vec;
1550 #endif
1551  }
1552 
1553  // c_vec = a_vec * b_vec
1554  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1555  {
1556 #if defined(__gfx950__)
1557  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1558  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1559  a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
1560  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1561  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1562  a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
1563  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1564  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1565  a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
1566  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1567  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1568  a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
1569 #else
1570  ck_tile::ignore = a_vec;
1571  ck_tile::ignore = b_vec;
1572  return CVecType{0.f};
1573 #endif
1574  }
1575 };
1576 
1577 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1580 
1581 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1584 
1585 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1588 
1589 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1592 
1593 // int8
1594 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1596 {
1597  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1601 
1605 
1606  static constexpr index_t kM = 32;
1607  static constexpr index_t kN = 32;
1608  static constexpr index_t kK = 16;
1609 
1610  static constexpr index_t kAMBlock = 1;
1611  static constexpr index_t kBNBlock = 1;
1612 
1613  static constexpr index_t kAMLane = 32;
1614  static constexpr index_t kBNLane = 32;
1615  static constexpr index_t kABKLane = 2;
1616  static constexpr index_t kABKPerLane = 8;
1617 
1618  static constexpr index_t kCMLane = 2;
1619  static constexpr index_t kCNLane = 32;
1620  static constexpr index_t kCM0PerLane = 4;
1621  static constexpr index_t kCM1PerLane = 4;
1622 
1623  // c_vec += a_vec * b_vec
1624  template <bool post_nop_ = false>
1626  const AVecType& a_vec,
1627  const BVecType& b_vec,
1628  bool_constant<post_nop_> = {}) const
1629  {
1630  DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
1631  else
1632  {
1633 #if defined(__gfx94__) or defined(__gfx95__)
1634  c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
1635  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1636 #elif defined(__gfx908__) || defined(__gfx90a__)
1637  static_for<0, 8, 1>{}([&](auto k) {
1638  float a_f32 =
1639  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1640  .template get_as<ADataType>()[number<k>{}]);
1641  float b_f32 =
1642  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1643  .template get_as<BDataType>()[number<k>{}]);
1644 
1645  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1646  });
1647 #else
1648  ck_tile::ignore = c_vec;
1649  ck_tile::ignore = a_vec;
1650  ck_tile::ignore = b_vec;
1651 #endif
1652  }
1653  }
1654 
1655  // c_vec = a_vec * b_vec
1656  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1657  {
1658  CVecType c_vec{0};
1659  operator()(c_vec, a_vec, b_vec);
1660  return c_vec;
1661  }
1662 };
1663 
1664 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1666 {
1667  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1671 
1675 
1676  static constexpr index_t kM = 16;
1677  static constexpr index_t kN = 16;
1678  static constexpr index_t kK = 32;
1679 
1680  static constexpr index_t kAMBlock = 1;
1681  static constexpr index_t kBNBlock = 1;
1682 
1683  static constexpr index_t kAMLane = 16;
1684  static constexpr index_t kBNLane = 16;
1685  static constexpr index_t kABKLane = 4;
1686  static constexpr index_t kABKPerLane = 8;
1687 
1688  static constexpr index_t kCMLane = 4;
1689  static constexpr index_t kCNLane = 16;
1690  static constexpr index_t kCM0PerLane = 1;
1691  static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
1692 
1693  // c_vec += a_vec * b_vec
1694  template <bool post_nop_ = false>
1696  const AVecType& a_vec,
1697  const BVecType& b_vec,
1698  bool_constant<post_nop_> = {}) const
1699  {
1700  DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x32_i8", Ctrl)
1701  else
1702  {
1703 #if defined(__gfx94__) or defined(__gfx95__)
1704  c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
1705  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1706 #else
1707  ck_tile::ignore = c_vec;
1708  ck_tile::ignore = a_vec;
1709  ck_tile::ignore = b_vec;
1710 #endif
1711  }
1712  }
1713 
1714  // c_vec = a_vec * b_vec
1715  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1716  {
1717  CVecType c_vec{0};
1718  operator()(c_vec, a_vec, b_vec);
1719  return c_vec;
1720  }
1721 };
1722 
1723 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1725 {
1726  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1730 
1734 
1735  static constexpr index_t kM = 16;
1736  static constexpr index_t kN = 16;
1737  static constexpr index_t kK = 64;
1738 
1739  static constexpr index_t kAMBlock = 1;
1740  static constexpr index_t kBNBlock = 1;
1741 
1742  static constexpr index_t kAMLane = 16;
1743  static constexpr index_t kBNLane = 16;
1744  static constexpr index_t kABKLane = 4;
1745  static constexpr index_t kABKPerLane = 16;
1746 
1747  static constexpr index_t kCMLane = 4;
1748  static constexpr index_t kCNLane = 16;
1749  static constexpr index_t kCM0PerLane = 1;
1750  static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
1751 
1752  // c_vec += a_vec * b_vec
1753  template <bool post_nop_ = false>
1755  const AVecType& a_vec,
1756  const BVecType& b_vec,
1757  bool_constant<post_nop_> = {}) const
1758  {
1759  DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x64_i8", Ctrl)
1760  else
1761  {
1762 #if defined(__gfx95__)
1763  c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
1764  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1765 #else
1766  ck_tile::ignore = c_vec;
1767  ck_tile::ignore = a_vec;
1768  ck_tile::ignore = b_vec;
1769 #endif
1770  }
1771  }
1772 
1773  // c_vec = a_vec * b_vec
1774  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1775  {
1776  CVecType c_vec{0};
1777  operator()(c_vec, a_vec, b_vec);
1778  return c_vec;
1779  }
1780 };
1781 
1782 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1784 {
1785  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1789 
1793 
1794  static constexpr index_t kM = 32;
1795  static constexpr index_t kN = 32;
1796  static constexpr index_t kK = 32;
1797 
1798  static constexpr index_t kAMBlock = 1;
1799  static constexpr index_t kBNBlock = 1;
1800 
1801  static constexpr index_t kAMLane = 32;
1802  static constexpr index_t kBNLane = 32;
1803  static constexpr index_t kABKLane = 2;
1804  static constexpr index_t kABKPerLane = 16;
1805 
1806  static constexpr index_t kCMLane = 2;
1807  static constexpr index_t kCNLane = 32;
1808  static constexpr index_t kCM0PerLane = 4;
1809  static constexpr index_t kCM1PerLane = 4;
1810 
1811  // c_vec += a_vec * b_vec
1812  template <bool post_nop_ = false>
1814  const AVecType& a_vec,
1815  const BVecType& b_vec,
1816  bool_constant<post_nop_> = {}) const
1817  {
1818  DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x32_i8", Ctrl)
1819  else
1820  {
1821 #if defined(__gfx95__)
1822  c_vec = __builtin_amdgcn_mfma_i32_32x32x32_i8(
1823  a_vec, bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1824 #else
1825  ck_tile::ignore = c_vec;
1826  ck_tile::ignore = a_vec;
1827  ck_tile::ignore = b_vec;
1828 #endif
1829  }
1830  }
1831 
1832  // c_vec = a_vec * b_vec
1833  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1834  {
1835  CVecType c_vec{0};
1836  operator()(c_vec, a_vec, b_vec);
1837  return c_vec;
1838  }
1839 };
1840 
// Both dispatch helpers are local to this header. DISPATCH_MFMA_CTRL_ expands to
// DISPATCH_MFMA_, so once the latter is undefined the former is unusable; undefine
// it as well so it does not leak into translation units that include this header.
#undef DISPATCH_MFMA_CTRL_
#undef DISPATCH_MFMA_
1842 
1843 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
WGAttrCtlEnum
Definition: warp_gemm_attribute_mfma_impl.hpp:15
_Float16 fp16_t
Definition: half.hpp:110
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:83
int32_t int32_t
Definition: integer.hpp:10
float fp32x16_t
Definition: vector_type.hpp:119
float fp32x4_t
Definition: vector_type.hpp:117
Definition: warp_gemm_attribute_mfma_impl.hpp:1399
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1409
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1423
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1414
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1400
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1407
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1418
ext_vector_t< ADataType, 32 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1405
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1421
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1403
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1416
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1402
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1424
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1411
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1419
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1413
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1422
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1417
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1410
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1428
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1456
ext_vector_t< BDataType, 32 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1406
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1401
Definition: warp_gemm_attribute_mfma_impl.hpp:1035
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1049
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1050
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1059
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1045
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1037
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1043
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1041
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1036
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1057
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1054
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1064
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1060
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1039
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1169
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1053
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1038
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1058
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1055
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1047
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1052
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1046
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1042
Definition: warp_gemm_attribute_mfma_impl.hpp:1194
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1216
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1195
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1200
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1197
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1219
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1202
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1214
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1198
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1208
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1217
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1223
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1339
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1211
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1204
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1196
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1201
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1206
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1205
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1218
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1209
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1213
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1212
Definition: warp_gemm_attribute_mfma_impl.hpp:1497
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1554
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1508
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1519
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1500
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1517
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1507
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1515
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1514
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1505
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1501
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1526
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1521
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1512
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1511
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1520
ext_vector_t< BDataType, 32 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1504
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1498
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1499
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1516
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1522
ext_vector_t< ADataType, 32 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1503
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1509
Definition: warp_gemm_attribute_mfma_impl.hpp:1666
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1715
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1674
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1673
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1672
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1683
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1695
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1667
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1677
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1681
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1668
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1684
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1686
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1685
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1680
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1688
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1690
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1676
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1691
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1689
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1670
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1678
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1669
Definition: warp_gemm_attribute_mfma_impl.hpp:1725
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1729
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1750
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1737
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1736
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1743
ext_vector_t< ADataType, 16 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1731
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1745
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1749
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1735
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1742
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1754
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1748
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1740
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1744
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1726
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1728
ext_vector_t< BDataType, 16 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1732
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1774
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1727
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1747
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1733
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1739
Definition: warp_gemm_attribute_mfma_impl.hpp:1596
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1598
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1625
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1618
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1602
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1603
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1620
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1600
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1597
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1656
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1599
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1613
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1607
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1611
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1608
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1621
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1604
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1619
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1616
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1610
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1614
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1615
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1606
Definition: warp_gemm_attribute_mfma_impl.hpp:1784
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1794
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1786
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1808
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1803
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1813
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1787
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1807
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1788
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1802
ext_vector_t< BDataType, 16 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1791
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1799
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1801
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1806
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1796
ext_vector_t< ADataType, 16 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1790
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1792
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1833
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1795
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1809
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1804
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1785
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1798
Definition: warp_gemm_attribute_mfma_impl.hpp:537
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:545
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:555
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:544
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:566
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:559
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:596
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:549
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:554
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:538
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:560
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:561
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:556
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:551
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:541
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:562
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:548
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:540
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:552
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:539
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:543
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:557
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:547
Definition: warp_gemm_attribute_mfma_impl.hpp:67
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:90
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:70
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:75
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:78
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:81
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:91
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:86
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:77
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:79
ext_vector_t< bf16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:73
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:68
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:87
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:71
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:85
ext_vector_t< bf16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:74
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:89
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:69
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:96
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:92
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:115
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:82
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:84
Definition: warp_gemm_attribute_mfma_impl.hpp:920
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:943
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:949
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:924
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:921
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:923
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:934
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:935
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:938
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:945
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:992
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:939
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:944
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:937
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:930
ext_vector_t< bf16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:926
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:931
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:942
ext_vector_t< bf16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:927
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:922
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:940
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:928
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:932
Definition: warp_gemm_attribute_mfma_impl.hpp:448
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:455
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:451
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:452
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:463
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:456
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:454
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:508
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:459
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:472
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:458
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:471
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:462
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:467
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:466
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:465
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:460
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:470
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:450
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:449
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:477
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:473
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:468
Definition: warp_gemm_attribute_mfma_impl.hpp:625
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:646
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:627
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:649
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:648
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:650
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:686
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:628
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:637
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:626
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:644
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:655
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:631
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:636
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:635
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:645
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:640
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:632
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:651
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:633
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:643
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:639
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:629
Definition: warp_gemm_attribute_mfma_impl.hpp:715
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:718
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:738
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:733
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:776
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:716
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:739
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:726
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:717
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:729
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:740
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:730
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:734
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:727
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:719
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:735
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:741
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:736
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:721
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:745
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:722
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:725
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:723
Definition: warp_gemm_attribute_mfma_impl.hpp:193
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:194
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:210
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:197
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:216
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:207
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:196
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:215
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:218
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:199
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:208
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:204
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:195
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:212
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:211
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:201
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:217
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:213
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:241
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:222
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:203
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:205
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:200
Definition: warp_gemm_attribute_mfma_impl.hpp:256
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:267
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:279
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:281
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:258
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:278
ext_vector_t< fp16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:262
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:275
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:276
ext_vector_t< fp16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:263
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:266
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:264
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:304
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:259
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:260
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:273
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:280
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:271
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:270
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:257
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:274
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:268
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:285
Definition: warp_gemm_attribute_mfma_impl.hpp:806
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:807
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:808
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:820
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:828
ext_vector_t< fp16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:812
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:829
ext_vector_t< fp16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:813
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:816
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:810
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:835
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:817
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:826
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:878
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:825
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:824
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:831
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:821
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:818
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:823
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:830
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:814
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:809
Definition: warp_gemm_attribute_mfma_impl.hpp:130
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:153
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:147
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:154
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:152
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:159
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:155
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:134
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:142
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:133
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:148
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:138
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:145
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:131
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:140
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:137
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:141
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:178
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:144
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:136
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:132
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:149
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:150
Definition: warp_gemm_attribute_mfma_impl.hpp:319
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:330
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:342
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:321
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:325
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:368
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:326
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:334
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:323
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:349
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:327
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:337
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:338
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:343
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:322
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:331
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:339
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:329
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:344
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:320
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:340
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:345
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:333
Definition: warp_gemm_attribute_mfma_impl.hpp:383
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:404
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:387
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:393
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:401
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:386
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:384
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:398
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:413
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:395
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:397
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:391
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:394
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:432
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:406
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:389
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:402
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:407
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:408
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:403
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:385
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:390
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:409
Definition: integral_constant.hpp:13
Definition: functional.hpp:43
Definition: debug.hpp:67
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_)
Definition: warp_gemm_attribute_mfma_impl.hpp:25
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_)
Definition: warp_gemm_attribute_mfma_impl.hpp:42