include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File

include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File#

Composable Kernel: include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File
warp_gemm_attribute_mfma_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 // TODO: refactor warp-gemm
11 // currently there is a discrepancy for vav/vva if we need to transpose C/D
12 // e.g. if we want A:agpr, B:vgpr, we have to use vva in WGAttrCtlEnum
13 // because the A/B pointers are swapped in the _impl code (which is not known here)
// Selects how a warp-gemm MFMA step is emitted. The Raw_* values are mapped by
// DISPATCH_MFMA_CTRL_ to inline asm with the register constraints ("v" / "a") noted on
// each enumerator, in the order c (accumulator/destination), a, b. Values not handled
// by that macro (e.g. Default_) take the caller's `else` branch, which uses the
// compiler builtins.
14 enum class WGAttrCtlEnum
15 {
16  Default_ = 0,
17  Raw_vvv = 1, // c-vgpr, a-vgpr, b-vgpr
18  Raw_vaa = 2, // c-vgpr, a-agpr, b-agpr
19  Raw_vav = 3, // c-vgpr, a-agpr, b-vgpr
20  Raw_vva = 4, // c-vgpr, a-vgpr, b-agpr
21  Raw_avv = 5, // c-agpr, a-vgpr, b-vgpr
22  // raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr (disabled; note that 3 is already taken by Raw_vav)
23 };
24 
// Emits a single MFMA instruction as inline asm: `mfma_ %0, %1, %2, %3`.
//   mfma_ : instruction mnemonic (string literal, concatenated into the asm template)
//   dmod_ : constraint for the destination operand %0, bound to c_vec
//   amod_ : constraint for source operand %1, bound to a_vec
//   bmod_ : constraint for source operand %2, bound to b_vec
//   cmod_ : constraint for the accumulator input %3, bound to c_vec again
// When the compile-time flag `post_nop_` is true, "s_nop 3" is appended after the
// instruction. The expansion site must provide `post_nop_`, `c_vec`, `a_vec` and `b_vec`.
25 #define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
26  if constexpr(post_nop_) \
27  { \
28  asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n" \
29  "s_nop 3" \
30  : dmod_(c_vec) \
31  : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
32  :); \
33  } \
34  else \
35  { \
36  asm volatile(mfma_ " %0, %1, %2, %3\n" \
37  : dmod_(c_vec) \
38  : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
39  :); \
40  }
41 
// Expands to an `if constexpr` / `else if constexpr` chain over ctrl_ that forwards to
// DISPATCH_MFMA_ with the register constraints matching each WGAttrCtlEnum::Raw_* value.
// The chain is intentionally left open (there is no final `else`): the expansion must be
// followed by `else { ... }` to handle every other value of ctrl_. If that `else` is
// omitted, a following `{ ... }` block becomes an unconditional statement.
42 #define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \
43  if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \
44  { \
45  DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \
46  } \
47  else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \
48  { \
49  DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \
50  } \
51  else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \
52  { \
53  DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \
54  } \
55  else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \
56  { \
57  DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \
58  } \
59  else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \
60  { \
61  DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \
62  }
63 
64 // FP16
65 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
67 {
68  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
69  using ADataType = fp16_t;
70  using BDataType = fp16_t;
71  using CDataType = float;
72 
76 
77  static constexpr index_t kM = 32;
78  static constexpr index_t kN = 32;
79  static constexpr index_t kK = 8;
80 
81  static constexpr index_t kAMBlock = 1;
82  static constexpr index_t kBNBlock = 1;
83 
84  static constexpr index_t kAMLane = 32;
85  static constexpr index_t kBNLane = 32;
86  static constexpr index_t kABKLane = 2;
87  static constexpr index_t kABKPerLane = 4;
88 
89  static constexpr index_t kCMLane = 2;
90  static constexpr index_t kCNLane = 32;
91  static constexpr index_t kCM0PerLane = 4;
92  static constexpr index_t kCM1PerLane = 4;
93 
94  // c_vec += a_vec * b_vec
95  template <bool post_nop_ = false>
97  const AVecType& a_vec,
98  const BVecType& b_vec,
99  bool_constant<post_nop_> = {}) const
100  {
101  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl)
102  else
103  {
104 #if defined(__gfx9__)
105  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
106 #else
107  ck_tile::ignore = c_vec;
108  ck_tile::ignore = a_vec;
109  ck_tile::ignore = b_vec;
110 #endif
111  }
112  }
113 
114  // c_vec = a_vec * b_vec
115  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
116  {
117 #if defined(__gfx9__)
118  return bit_cast<CVecType>(
119  __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
120 #else
121  ck_tile::ignore = a_vec;
122  ck_tile::ignore = b_vec;
123  return CVecType{0.f};
124 #endif
125  }
126 };
127 
128 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
130 {
131  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
132  using ADataType = fp16_t;
133  using BDataType = fp16_t;
134  using CDataType = float;
135 
139 
140  static constexpr index_t kM = 16;
141  static constexpr index_t kN = 16;
142  static constexpr index_t kK = 16;
143 
144  static constexpr index_t kAMBlock = 1;
145  static constexpr index_t kBNBlock = 1;
146 
147  static constexpr index_t kAMLane = 16;
148  static constexpr index_t kBNLane = 16;
149  static constexpr index_t kABKLane = 4;
150  static constexpr index_t kABKPerLane = 4;
151 
152  static constexpr index_t kCMLane = 4;
153  static constexpr index_t kCNLane = 16;
154  static constexpr index_t kCM0PerLane = 1;
155  static constexpr index_t kCM1PerLane = 4;
156 
157  // c_vec += a_vec * b_vec
158  template <bool post_nop_ = false>
160  const AVecType& a_vec,
161  const BVecType& b_vec,
162  bool_constant<post_nop_> = {}) const
163  {
164  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
165  else
166  {
167 #if defined(__gfx9__)
168  c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
169 #else
170  ck_tile::ignore = c_vec;
171  ck_tile::ignore = a_vec;
172  ck_tile::ignore = b_vec;
173 #endif
174  }
175  }
176 
177  // c_vec = a_vec * b_vec
178  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
179  {
180 #if defined(__gfx9__)
181  return bit_cast<CVecType>(
182  __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
183 #else
184  ck_tile::ignore = a_vec;
185  ck_tile::ignore = b_vec;
186  return CVecType{0.f};
187 #endif
188  }
189 };
190 
191 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
193 {
194  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
195  using ADataType = fp16_t;
196  using BDataType = fp16_t;
197  using CDataType = float;
198 
202 
203  static constexpr index_t kM = 4;
204  static constexpr index_t kN = 64;
205  static constexpr index_t kK = 4;
206 
207  static constexpr index_t kAMBlock = 1;
208  static constexpr index_t kBNBlock = 16;
209 
210  // we only write down single block (4 threads) thread mapping here
211  static constexpr index_t kAMLane = 4;
212  static constexpr index_t kBNLane = 4;
213  static constexpr index_t kABKLane = 1;
214  static constexpr index_t kABKPerLane = 4;
215 
216  static constexpr index_t kCMLane = 1;
217  static constexpr index_t kCNLane = 4;
218  static constexpr index_t kCM0PerLane = 1;
219  static constexpr index_t kCM1PerLane = 4;
220 
221  // c_vec += a_vec * b_vec
222  template <bool post_nop_ = false>
224  const AVecType& a_vec,
225  const BVecType& b_vec,
226  bool_constant<post_nop_> = {}) const
227  {
228  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
229  else
230  {
231 #if defined(__gfx9__)
232  c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
233 #else
234  ignore = c_vec;
235  ignore = a_vec;
236  ignore = b_vec;
237 #endif
238  }
239  }
240 
241  // c_vec = a_vec * b_vec
242  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
243  {
244 #if defined(__gfx9__)
245  return bit_cast<CVecType>(
246  __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
247 #else
248  ignore = a_vec;
249  ignore = b_vec;
250  return CVecType{0.f};
251 #endif
252  }
253 };
254 
255 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
257 {
258  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
259  using ADataType = fp16_t;
260  using BDataType = fp16_t;
261  using CDataType = float;
262 
266 
267  static constexpr index_t kM = 64;
268  static constexpr index_t kN = 4;
269  static constexpr index_t kK = 4;
270 
271  static constexpr index_t kAMBlock = 16;
272  static constexpr index_t kBNBlock = 1;
273 
274  // we only write down single block (4 threads) thread mapping here
275  static constexpr index_t kAMLane = 4;
276  static constexpr index_t kBNLane = 4;
277  static constexpr index_t kABKLane = 1;
278  static constexpr index_t kABKPerLane = 4;
279 
280  static constexpr index_t kCMLane = 1;
281  static constexpr index_t kCNLane = 4;
282  static constexpr index_t kCM0PerLane = 1;
283  static constexpr index_t kCM1PerLane = 4;
284 
285  // c_vec += a_vec * b_vec
286  template <bool post_nop_ = false>
288  const AVecType& a_vec,
289  const BVecType& b_vec,
290  bool_constant<post_nop_> = {}) const
291  {
292  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
293  else
294  {
295 #if defined(__gfx9__)
296  c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
297 #else
298  ignore = c_vec;
299  ignore = a_vec;
300  ignore = b_vec;
301 #endif
302  }
303  }
304 
305  // c_vec = a_vec * b_vec
306  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
307  {
308 #if defined(__gfx9__)
309  return bit_cast<CVecType>(
310  __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
311 #else
312  ignore = a_vec;
313  ignore = b_vec;
314  return CVecType{0.f};
315 #endif
316  }
317 };
318 
319 // Bf16
320 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
322 {
323  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
324  using ADataType = bf16_t;
325  using BDataType = bf16_t;
326  using CDataType = float;
327 
331 
332  static constexpr index_t kM = 32;
333  static constexpr index_t kN = 32;
334  static constexpr index_t kK = 8;
335 
336  static constexpr index_t kAMBlock = 1;
337  static constexpr index_t kBNBlock = 1;
338 
339  static constexpr index_t kAMLane = 32;
340  static constexpr index_t kBNLane = 32;
341  static constexpr index_t kABKLane = 2;
342  static constexpr index_t kABKPerLane = 4;
343 
344  static constexpr index_t kCMLane = 2;
345  static constexpr index_t kCNLane = 32;
346  static constexpr index_t kCM0PerLane = 4;
347  static constexpr index_t kCM1PerLane = 4;
348 
349  // c_vec += a_vec * b_vec
350  template <bool post_nop_ = false>
352  const AVecType& a_vec,
353  const BVecType& b_vec,
354  bool_constant<post_nop_> = {}) const
355  {
356  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl)
357  else
358  {
359 #if defined(__gfx90a__) || defined(__gfx94__)
360  c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
361 #elif defined(__gfx908__)
362  static_for<0, 2, 1>{}([&](auto k) {
363  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
364  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
365  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
366  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
367  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
368  c_vec,
369  0,
370  0,
371  0);
372  });
373 #else
374  ck_tile::ignore = c_vec;
375  ck_tile::ignore = a_vec;
376  ck_tile::ignore = b_vec;
377 #endif
378  }
379  }
380 
381  // c_vec = a_vec * b_vec
382  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
383  {
384 #if defined(__gfx90a__) || defined(__gfx94__)
385  return bit_cast<CVecType>(
386  __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
387 #elif defined(__gfx908__)
388  CVecType c_vec{0.f};
389  static_for<0, 2, 1>{}([&](auto k) {
390  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
391  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
392  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
393  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
394  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
395  c_vec,
396  0,
397  0,
398  0);
399  });
400  return c_vec;
401 #else
402  ck_tile::ignore = a_vec;
403  ck_tile::ignore = b_vec;
404  return CVecType{0.f};
405 #endif
406  }
407 };
408 
409 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
411 {
412  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
413  using ADataType = bf16_t;
414  using BDataType = bf16_t;
415  using CDataType = float;
416 
420 
421  static constexpr index_t kM = 16;
422  static constexpr index_t kN = 16;
423  static constexpr index_t kK = 16;
424 
425  static constexpr index_t kAMBlock = 1;
426  static constexpr index_t kBNBlock = 1;
427 
428  static constexpr index_t kAMLane = 16;
429  static constexpr index_t kBNLane = 16;
430  static constexpr index_t kABKLane = 4;
431  static constexpr index_t kABKPerLane = 4;
432 
433  static constexpr index_t kCMLane = 4;
434  static constexpr index_t kCNLane = 16;
435  static constexpr index_t kCM0PerLane = 1;
436  static constexpr index_t kCM1PerLane = 4;
437 
438  // c_vec += a_vec * b_vec
439  template <bool post_nop_ = false>
441  const AVecType& a_vec,
442  const BVecType& b_vec,
443  bool_constant<post_nop_> = {}) const
444  {
445  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
446  {
447 #if defined(__gfx90a__) || defined(__gfx94__)
448  c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
449 #elif defined(__gfx908__)
450  static_for<0, 2, 1>{}([&](auto k) {
451  c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
452  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
453  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
454  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
455  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
456  c_vec,
457  0,
458  0,
459  0);
460  });
461 #else
462  ck_tile::ignore = c_vec;
463  ck_tile::ignore = a_vec;
464  ck_tile::ignore = b_vec;
465 #endif
466  }
467  }
468 
469  // c_vec = a_vec * b_vec
470  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
471  {
472 #if defined(__gfx90a__) || defined(__gfx94__)
473  return bit_cast<CVecType>(
474  __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
475 #elif defined(__gfx908__)
476  CVecType c_vec{0.f};
477  static_for<0, 2, 1>{}([&](auto k) {
478  c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
479  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
480  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
481  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
482  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
483  c_vec,
484  0,
485  0,
486  0);
487  });
488  return c_vec;
489 #else
490  ck_tile::ignore = a_vec;
491  ck_tile::ignore = b_vec;
492  return CVecType{0.f};
493 #endif
494  }
495 };
496 
497 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
499 {
500  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
501  using ADataType = bf16_t;
502  using BDataType = bf16_t;
503  using CDataType = float;
504 
508 
509  static constexpr index_t kM = 4;
510  static constexpr index_t kN = 64;
511  static constexpr index_t kK = 4;
512 
513  static constexpr index_t kAMBlock = 1;
514  static constexpr index_t kBNBlock = 16;
515 
516  // we only write down single block (4 threads) thread mapping here
517  static constexpr index_t kAMLane = 4;
518  static constexpr index_t kBNLane = 4;
519  static constexpr index_t kABKLane = 1;
520  static constexpr index_t kABKPerLane = 4;
521 
522  static constexpr index_t kCMLane = 1;
523  static constexpr index_t kCNLane = 4;
524  static constexpr index_t kCM0PerLane = 1;
525  static constexpr index_t kCM1PerLane = 4;
526 
527  // c_vec += a_vec * b_vec
528  template <bool post_nop_ = false>
530  const AVecType& a_vec,
531  const BVecType& b_vec,
532  bool_constant<post_nop_> = {}) const
533  {
534  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
535  else
536  {
537 #if defined(__gfx9__)
538  c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
539 #else
540  ignore = c_vec;
541  ignore = a_vec;
542  ignore = b_vec;
543 #endif
544  }
545  }
546 
547  // c_vec = a_vec * b_vec
548  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
549  {
550 #if defined(__gfx9__)
551  return bit_cast<CVecType>(
552  __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
553 #else
554  ignore = a_vec;
555  ignore = b_vec;
556  return CVecType{0.f};
557 #endif
558  }
559 };
560 
561 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
563 {
564  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
565  using ADataType = bf16_t;
566  using BDataType = bf16_t;
567  using CDataType = float;
568 
572 
573  static constexpr index_t kM = 64;
574  static constexpr index_t kN = 4;
575  static constexpr index_t kK = 4;
576 
577  static constexpr index_t kAMBlock = 16;
578  static constexpr index_t kBNBlock = 1;
579 
580  // we only write down single block (4 threads) thread mapping here
581  static constexpr index_t kAMLane = 4;
582  static constexpr index_t kBNLane = 4;
583  static constexpr index_t kABKLane = 1;
584  static constexpr index_t kABKPerLane = 4;
585 
586  static constexpr index_t kCMLane = 1;
587  static constexpr index_t kCNLane = 4;
588  static constexpr index_t kCM0PerLane = 1;
589  static constexpr index_t kCM1PerLane = 4;
590 
591  // c_vec += a_vec * b_vec
592  template <bool post_nop_ = false>
594  const AVecType& a_vec,
595  const BVecType& b_vec,
596  bool_constant<post_nop_> = {}) const
597  {
598  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
599  else
600  {
601 #if defined(__gfx9__)
602  c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
603 #else
604  ignore = c_vec;
605  ignore = a_vec;
606  ignore = b_vec;
607 #endif
608  }
609  }
610 
611  // c_vec = a_vec * b_vec
612  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
613  {
614 #if defined(__gfx9__)
615  return bit_cast<CVecType>(
616  __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
617 #else
618  ignore = a_vec;
619  ignore = b_vec;
620  return CVecType{0.f};
621 #endif
622  }
623 };
624 
625 // FP8
626 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
628 {
629  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
630  using ADataType = AType_;
631  using BDataType = BType_;
632  using CDataType = float;
633 
637 
638  static constexpr index_t kM = 32;
639  static constexpr index_t kN = 32;
640  static constexpr index_t kK = 16;
641 
642  static constexpr index_t kAMBlock = 1;
643  static constexpr index_t kBNBlock = 1;
644 
645  static constexpr index_t kAMLane = 32;
646  static constexpr index_t kBNLane = 32;
647  static constexpr index_t kABKLane = 2;
648  static constexpr index_t kABKPerLane = 8;
649 
650  static constexpr index_t kCMLane = 2;
651  static constexpr index_t kCNLane = 32;
652  static constexpr index_t kCM0PerLane = 4;
653  static constexpr index_t kCM1PerLane = 4;
654 
655  // c_vec += a_vec * b_vec
656  template <bool post_nop_ = false>
658  const AVecType& a_vec,
659  const BVecType& b_vec,
660  bool_constant<post_nop_> = {}) const
661  {
662  if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
663  {
664  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
665  {
666  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "v", "v")
667  }
668  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
669  {
670  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "v", "v")
671  }
672  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
673  {
674  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "v", "v")
675  }
676  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
677  {
678  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "v", "v")
679  }
680  }
681  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
682  {
683  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
684  {
685  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "a", "v")
686  }
687  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
688  {
689  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "a", "v")
690  }
691  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
692  {
693  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "a", "v")
694  }
695  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
696  {
697  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "a", "v")
698  }
699  }
700  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
701  {
702  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
703  {
704  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "v", "v")
705  }
706  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
707  {
708  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "v", "v")
709  }
710  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
711  {
712  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "v", "v")
713  }
714  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
715  {
716  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "v", "v")
717  }
718  }
719  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
720  {
721  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
722  {
723  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "a", "v")
724  }
725  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
726  {
727  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "a", "v")
728  }
729  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
730  {
731  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "a", "v")
732  }
733  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
734  {
735  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "a", "v")
736  }
737  }
738  else
739  {
740 #if defined(__gfx94__)
741  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
742  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
743  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
744  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
745  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
746  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
747  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
748  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
749  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
750  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
751  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
752  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
753 #elif defined(__gfx908__) || defined(__gfx90a__)
754  static_for<0, 8, 1>{}([&](auto k) {
755  float a_f32 =
756  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
757  .template get_as<ADataType>()[number<k>{}]);
758  float b_f32 =
759  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
760  .template get_as<BDataType>()[number<k>{}]);
761 
762  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
763  });
764 #else
765  ck_tile::ignore = c_vec;
766  ck_tile::ignore = a_vec;
767  ck_tile::ignore = b_vec;
768 #endif
769  }
770  }
771 
772  // c_vec = a_vec * b_vec
773  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
774  {
775 #if defined(__gfx94__)
776  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
777  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
778  bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
779  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
780  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
781  bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
782  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
783  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
784  bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
785  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
786  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
787  bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
788 #elif defined(__gfx908__) || defined(__gfx90a__)
789  CVecType c_vec{0.f};
790  static_for<0, 8, 1>{}([&](auto k) {
791  float a_f32 =
792  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
793  .template get_as<ADataType>()[number<k>{}]);
794  float b_f32 =
795  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
796  .template get_as<BDataType>()[number<k>{}]);
797 
798  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
799  });
800  return c_vec;
801 #else
802  ck_tile::ignore = a_vec;
803  ck_tile::ignore = b_vec;
804  return CVecType{0.f};
805 #endif
806  }
807 };
808 
809 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
812 
813 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
816 
817 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
820 
821 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
824 
825 // int8
826 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
828 {
829  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
830  using ADataType = int8_t;
831  using BDataType = int8_t;
832  using CDataType = int32_t;
833 
837 
838  static constexpr index_t kM = 32;
839  static constexpr index_t kN = 32;
840  static constexpr index_t kK = 16;
841 
842  static constexpr index_t kAMBlock = 1;
843  static constexpr index_t kBNBlock = 1;
844 
845  static constexpr index_t kAMLane = 32;
846  static constexpr index_t kBNLane = 32;
847  static constexpr index_t kABKLane = 2;
848  static constexpr index_t kABKPerLane = 8;
849 
850  static constexpr index_t kCMLane = 2;
851  static constexpr index_t kCNLane = 32;
852  static constexpr index_t kCM0PerLane = 4;
853  static constexpr index_t kCM1PerLane = 4;
854 
855  // c_vec += a_vec * b_vec
856  template <bool post_nop_ = false>
858  const AVecType& a_vec,
859  const BVecType& b_vec,
860  bool_constant<post_nop_> = {}) const
861  {
862  DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
863  else
864  {
865 #if defined(__gfx94__)
866  c_vec = __builtin_amdgcn_mfma_i32_32x32x8i8(
867  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
868 #elif defined(__gfx908__) || defined(__gfx90a__)
869  static_for<0, 8, 1>{}([&](auto k) {
870  float a_f32 =
871  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
872  .template get_as<ADataType>()[number<k>{}]);
873  float b_f32 =
874  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
875  .template get_as<BDataType>()[number<k>{}]);
876 
877  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
878  });
879 #else
880  ck_tile::ignore = c_vec;
881  ck_tile::ignore = a_vec;
882  ck_tile::ignore = b_vec;
883 #endif
884  }
885  }
886 
887  // c_vec = a_vec * b_vec
888  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
889  {
890  CVecType c_vec{0};
891  operator()(c_vec, a_vec, b_vec);
892  return c_vec;
893  }
894 };
895 
// Both dispatch macros are implementation details of this header; undefine them together
// so that neither leaks into including translation units (DISPATCH_MFMA_CTRL_ cannot be
// expanded once DISPATCH_MFMA_ is gone anyway).
#undef DISPATCH_MFMA_CTRL_
#undef DISPATCH_MFMA_
897 
898 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
WGAttrCtlEnum
Definition: warp_gemm_attribute_mfma_impl.hpp:15
_Float16 fp16_t
Definition: half.hpp:110
tuple_array< T, N > thread_buffer
Definition: thread_buffer.hpp:14
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
int32_t index_t
Definition: integer.hpp:9
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:54
float fp32x16_t
Definition: vector_type.hpp:89
float fp32x4_t
Definition: vector_type.hpp:87
Definition: warp_gemm_attribute_mfma_impl.hpp:628
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:650
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:629
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:634
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:631
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:653
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:636
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:648
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:632
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:642
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:651
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:657
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:773
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:645
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:638
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:630
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:635
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:640
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:639
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:652
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:643
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:647
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:646
Definition: warp_gemm_attribute_mfma_impl.hpp:828
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:830
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:857
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:850
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:834
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:835
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:852
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:832
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:829
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:888
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:831
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:845
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:839
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:843
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:840
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:853
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:836
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:851
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:848
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:842
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:846
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:847
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:838
Definition: warp_gemm_attribute_mfma_impl.hpp:411
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:419
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:429
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:418
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:440
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:433
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:470
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:423
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:428
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:412
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:434
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:435
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:430
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:425
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:415
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:436
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:422
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:414
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:426
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:413
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:417
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:431
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:421
Definition: warp_gemm_attribute_mfma_impl.hpp:322
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:329
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:325
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:326
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:337
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:330
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:328
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:382
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:333
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:346
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:332
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:345
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:336
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:341
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:340
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:339
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:334
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:344
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:324
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:323
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:351
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:347
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:342
Definition: warp_gemm_attribute_mfma_impl.hpp:499
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:520
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:501
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:523
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:522
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:524
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:548
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:502
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:511
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:500
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:518
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:529
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:505
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:510
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:509
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:519
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:514
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:506
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:525
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:507
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:517
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:513
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:503
Definition: warp_gemm_attribute_mfma_impl.hpp:563
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:566
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:586
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:581
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:612
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:564
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:587
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:574
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:565
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:577
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:588
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:578
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:582
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:575
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:567
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:583
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:589
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:584
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:569
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:593
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:570
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:573
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:571
Definition: warp_gemm_attribute_mfma_impl.hpp:130
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:131
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:147
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:134
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:153
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:144
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:133
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:152
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:155
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:136
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:145
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:141
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:132
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:149
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:148
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:138
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:154
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:150
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:178
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:159
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:140
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:142
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:137
Definition: warp_gemm_attribute_mfma_impl.hpp:67
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:90
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:84
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:91
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:89
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:96
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:92
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:71
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:79
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:70
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:85
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:75
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:82
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:68
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:77
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:74
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:78
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:115
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:81
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:73
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:69
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:86
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:87
Definition: warp_gemm_attribute_mfma_impl.hpp:193
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:204
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:216
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:195
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:199
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:242
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:200
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:208
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:197
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:223
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:201
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:211
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:212
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:217
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:196
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:205
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:213
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:203
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:218
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:194
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:214
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:219
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:207
Definition: warp_gemm_attribute_mfma_impl.hpp:257
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:278
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:261
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:267
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:275
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:260
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:258
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:272
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:287
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:269
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:271
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:265
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:268
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:306
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:280
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:263
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:276
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:281
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:282
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:277
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:259
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:264
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:283
Definition: integral_constant.hpp:13
Definition: functional.hpp:43
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_)
Definition: warp_gemm_attribute_mfma_impl.hpp:25
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_)
Definition: warp_gemm_attribute_mfma_impl.hpp:42