/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File
warp_gemm_attribute_mfma.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 // Number of groups of consecutive elements to fill in an ABKLane
13 {
14  Single = 1,
15  Double = 2,
16  Quad = 4,
17  Invalid = -1
18 };
19 
20 template <typename WarpGemmAttributeMfmaImpl_,
23 {
25  static constexpr auto AttrNumAccess = AttrNumAccess_;
26  static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
27 
28  using ADataType = typename Impl::ADataType;
29  using BDataType = typename Impl::BDataType;
30  using CDataType = typename Impl::CDataType;
31 
32  using AVecType = typename Impl::AVecType;
33  using BVecType = typename Impl::BVecType;
34  using CVecType = typename Impl::CVecType;
35 
36  static constexpr index_t kM = Impl::kM;
37  static constexpr index_t kN = Impl::kN;
38  static constexpr index_t kK = Impl::kK;
39  static constexpr index_t kKPerThread = Impl::kABKPerLane;
40  static constexpr index_t kCMLane = Impl::kCMLane;
41 
42  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
43 
44  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
45  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
46 
47  template <index_t kMNLane>
48  static constexpr auto get_warp_dstr_encoding()
49  {
50  if constexpr(AttrNumAccessV == 1)
51  {
53  sequence<>,
58  sequence<1>>{};
59  }
60  else
61  {
62  static_assert(kKPerThread % AttrNumAccessV == 0,
63  "kKPerThread must be divisible by NumAccess");
65  sequence<>,
67  sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
71  sequence<0, 2>>{};
72  }
73  }
74  using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
75  using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
76 
78  sequence<>,
85 
86  // c_vec += a_vec * b_vec
87  template <bool post_nop_ = false>
89  const AVecType& a_vec,
90  const BVecType& b_vec,
91  bool_constant<post_nop_> = {}) const
92  {
93  Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
94  }
95 
96  // c_vec += a_vec * b_vec
97  template <index_t opselA, index_t opselB, bool post_nop_ = false>
99  const AVecType& a_vec,
100  const int32_t& a_scale,
101  const BVecType& b_vec,
102  const int32_t& b_scale,
103  bool_constant<post_nop_> = {}) const
104  {
105  Impl{}.template operator()<opselA, opselB>(
106  c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant<post_nop_>{});
107  }
108 
109  // c_vec = a_vec * b_vec
110  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
111  {
112  return Impl{}(a_vec, b_vec);
113  }
114 
115  // c_vec = a_vec * b_vec (scaled overload): forwards both operands and their int32 scale arguments to Impl's opsel-templated call and returns the CVecType it produces
116  template <index_t opselA, index_t opselB>
118  const int32_t& a_scale,
119  const BVecType& b_vec,
120  const int32_t& b_scale) const
121  {
122  return Impl{}.template operator()<opselA, opselB>(a_vec, a_scale, b_vec, b_scale); // the result must be returned: this overload is declared to return CVecType
123  }
124 };
125 
126 template <typename WarpGemmAttributeMfmaImpl_,
127  index_t kKIter,
130 {
131  static_assert(kKIter > 0, "wrong!");
132 
134  static constexpr auto AttrNumAccess = AttrNumAccess_;
135  static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
136 
137  using ADataType = typename Impl::ADataType;
138  using BDataType = typename Impl::BDataType;
139  using CDataType = typename Impl::CDataType;
140 
141  using AVecType =
143  using BVecType =
145  using CVecType = typename Impl::CVecType;
146 
147  static constexpr index_t kM = Impl::kM;
148  static constexpr index_t kN = Impl::kN;
149  static constexpr index_t kK = Impl::kK * kKIter;
150  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
151  static constexpr index_t kCMLane = Impl::kCMLane;
152 
153  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
154 
155  static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
156  "Multi-block on both M & N directions is not supported");
157 
158  CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
159  {
160  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
161  {
162  if constexpr(AttrNumAccessV == 1)
163  {
165  sequence<>,
170  sequence<2>,
171  sequence<1>>{};
172  }
173  else
174  {
175  static_assert(kKPerThread % AttrNumAccessV == 0,
176  "kKPerThread must be divisible by NumAccess");
178  sequence<>,
181  Impl::kABKLane,
182  Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
186  sequence<0, 2>>{};
187  }
188  }
189  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
190  {
191  static_assert(AttrNumAccessV == 1,
192  "Multiple access is not supported when using multi-block");
193  // each M block shares the same data
200  sequence<2>,
201  sequence<1>>{};
202  }
203  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
204  {
205  static_assert(AttrNumAccessV == 1,
206  "Multiple access is not supported when using multi-block");
207  // single block to multi-block thread mapping
209  sequence<>,
214  sequence<2>,
215  sequence<1>>{};
216  }
217  }
218 
219  CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
220  {
221  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
222  {
223  if constexpr(AttrNumAccessV == 1)
224  {
226  sequence<>,
231  sequence<2>,
232  sequence<1>>{};
233  }
234  else
235  {
236 
237  static_assert(kKPerThread % AttrNumAccessV == 0,
238  "kKPerThread must be divisible by NumAccess");
240  sequence<>,
243  Impl::kABKLane,
244  Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
248  sequence<0, 2>>{};
249  }
250  }
251  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
252  {
253  static_assert(AttrNumAccessV == 1,
254  "Multiple access is not supported when using multi-block");
255  // single block to multi-block thread mapping
257  sequence<>,
262  sequence<2>,
263  sequence<1>>{};
264  }
265  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
266  {
267  static_assert(AttrNumAccessV == 1,
268  "Multiple access is not supported when using multi-block");
269  // each N block shares the same data
276  sequence<2>,
277  sequence<1>>{};
278  }
279  }
280 
281  CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
282  {
283  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
284  {
286  sequence<>,
292  sequence<0, 2>>{};
293  }
294  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
295  {
297  sequence<>,
303  sequence<0, 2>>{};
304  }
305  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
306  {
308  sequence<>,
309  tuple<
315  sequence<0, 2>>{};
316  }
317  }
318 
320 
322 
324 
325  // c_vec += a_vec * b_vec
326  template <bool post_nop_ = false>
328  const AVecType& a_vec,
329  const BVecType& b_vec,
330  bool_constant<post_nop_> = {}) const
331  {
334 
335  static_for<0, kKIter, 1>{}([&](auto iKIter) {
336  Impl{}(c_vec,
337  reinterpret_cast<const buf_a&>(a_vec)
338  .template get_as<typename Impl::AVecType>()[iKIter],
339  reinterpret_cast<const buf_b&>(b_vec)
340  .template get_as<typename Impl::BVecType>()[iKIter],
341  bool_constant<post_nop_>{});
342  });
343  }
344 
345  template <index_t iKIter, bool post_nop_ = false>
347  const AVecType& a_vec,
348  const BVecType& b_vec,
350  bool_constant<post_nop_> = {}) const
351  {
354 
355  static_assert(iKIter < kKIter);
356 
357  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
358  Impl{}(c_vec,
359  reinterpret_cast<const buf_a&>(a_vec)
360  .template get_as<typename Impl::AVecType>()[iKIter],
361  reinterpret_cast<const buf_b&>(b_vec)
362  .template get_as<typename Impl::BVecType>()[iKIter],
363  bool_constant<post_nop_>{});
364  //});
365  }
366 
367  // c_vec = a_vec * b_vec
368  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
369  {
370  constexpr auto I0 = number<0>{};
373 
374  // c = a * b
375  auto c_vec = Impl{}(
376  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
377  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
378 
379  // c += a * b
380  static_for<1, kKIter, 1>{}([&](auto iKIter) {
381  Impl{}(c_vec,
382  reinterpret_cast<const buf_a&>(a_vec)
383  .template get_as<typename Impl::AVecType>()[iKIter],
384  reinterpret_cast<const buf_b&>(b_vec)
385  .template get_as<typename Impl::BVecType>()[iKIter]);
386  });
387 
388  return c_vec;
389  }
390 };
391 
392 template <typename WarpGemmAttributeMfmaImpl_,
395 {
397  static constexpr auto AttrNumAccess = AttrNumAccess_;
398  static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
399 
400  using ADataType = typename Impl::BDataType;
401  using BDataType = typename Impl::ADataType;
402  using CDataType = typename Impl::CDataType;
403 
404  using AVecType = typename Impl::BVecType;
405  using BVecType = typename Impl::AVecType;
406  using CVecType = typename Impl::CVecType;
407 
408  static constexpr index_t kM = Impl::kN;
409  static constexpr index_t kN = Impl::kM;
410  static constexpr index_t kK = Impl::kK;
411  static constexpr index_t kKPerThread = Impl::kABKPerLane;
412  static constexpr index_t kCMLane = Impl::kCMLane;
413 
414  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
415 
416  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
417  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
418 
419  template <index_t kMNLane>
420  static constexpr auto get_warp_dstr_encoding()
421  {
422  if constexpr(AttrNumAccessV == 1)
423  {
425  sequence<>,
429  sequence<2>,
430  sequence<1>>{};
431  }
432  else
433  {
434  static_assert(kKPerThread % AttrNumAccessV == 0,
435  "kKPerThread must be divisible by NumAccess");
437  sequence<>,
439  sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
443  sequence<0, 2>>{};
444  }
445  }
446  using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
447  using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
448 
450  sequence<>,
457 
458  // c_vec += a_vec * b_vec
459  template <bool post_nop_ = false>
461  const AVecType& a_vec,
462  const BVecType& b_vec,
463  bool_constant<post_nop_> = {}) const
464  {
465  // swap A and B
466  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
467  }
468 
469  // c_vec = a_vec * b_vec
470  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
471  {
472  // swap A and B
473  return Impl{}(b_vec, a_vec);
474  }
475 };
476 
477 template <typename WarpGemmAttributeMfmaImpl_, index_t SFactor_ = 2>
479 {
481 
482  using ADataType = typename Impl::BDataType;
483  using BDataType = typename Impl::ADataType;
484  using CDataType = typename Impl::CDataType;
485 
486  using AVecType = typename Impl::BVecType;
487  using BVecType = typename Impl::AVecType;
488  using CVecType = typename Impl::CVecType;
489 
490  static constexpr index_t kM = Impl::kN;
491  static constexpr index_t kN = Impl::kM;
492  static constexpr index_t kK = Impl::kK;
493  static constexpr index_t kKPerThread = Impl::kABKPerLane;
494  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
495 
496  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
497 
498  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
499  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
500 
502  sequence<>,
506  sequence<2>,
507  sequence<1>>;
508 #if 0
510  sequence<>,
511  tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
512  Impl::kABKLane,
513  2,
514  Impl::kABKPerLane>,
518  sequence<2>,
519  sequence<1>>;
520 
522  sequence<>,
524  sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
529 #else
530  // TODO: add more tests, not only for 32x32
532  sequence<>,
533  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
534  Impl::kCMLane,
535  SFactor,
536  Impl::kCM1PerLane>,
540  sequence<2>,
541  sequence<1>>;
542 
544  sequence<>,
546  sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
551 #endif
552  template <bool post_nop_ = false>
553  // c_vec += a_vec * b_vec
555  const AVecType& a_vec,
556  const BVecType& b_vec,
557  bool_constant<post_nop_> = {}) const
558  {
559  // swap A and B
560  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
561  }
562 
563  // c_vec = a_vec * b_vec
564  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
565  {
566  // swap A and B
567  return Impl{}(b_vec, a_vec);
568  }
569 };
570 
571 template <typename WarpGemmAttributeMfmaImpl_,
572  index_t kKIter,
575 {
577  static constexpr auto AttrNumAccess = AttrNumAccess_;
578 
579  // swap A and B
580  using ADataType = typename Impl::BDataType;
581  using BDataType = typename Impl::ADataType;
582  using CDataType = typename Impl::CDataType;
583 
584  using AVecType =
586  using BVecType =
588  using CVecType = typename Impl::CVecType;
589 
590  static constexpr index_t kM = Impl::kN;
591  static constexpr index_t kN = Impl::kM;
592  static constexpr index_t kK = Impl::kK * kKIter;
593  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
594 
595  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
596 
597  static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
598  "Multi-block on both M & N directions is not supported");
599 
600  CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
601  {
604  }
605 
606  CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
607  {
610  }
611 
612  CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
613  {
614  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
615  {
617  sequence<>,
623  sequence<0, 2>>{};
624  }
625  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
626  {
628  sequence<>,
634  sequence<0, 2>>{};
635  }
636  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
637  {
639  sequence<>,
640  tuple<
646  sequence<0, 2>>{};
647  }
648  }
649 
651 
653 
655 
656  template <bool post_nop_ = false>
657  // c_vec += a_vec * b_vec
659  const AVecType& a_vec,
660  const BVecType& b_vec,
661  bool_constant<post_nop_> = {}) const
662  {
665  // swap A and B, value and type
666  static_for<0, kKIter, 1>{}([&](auto iKIter) {
667  Impl{}(c_vec,
668  reinterpret_cast<const buf_b&>(b_vec)
669  .template get_as<typename Impl::BVecType>()[iKIter],
670  reinterpret_cast<const buf_a&>(a_vec)
671  .template get_as<typename Impl::AVecType>()[iKIter],
672  bool_constant<post_nop_>{});
673  });
674  }
675 
676  template <index_t iKIter, bool post_nop_ = false>
677  // c_vec += a_vec * b_vec
679  const AVecType& a_vec,
680  const BVecType& b_vec,
682  bool_constant<post_nop_> = {}) const
683  {
686 
687  static_assert(iKIter < kKIter);
688  // swap A and B, value and type
689  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
690  Impl{}(c_vec,
691  reinterpret_cast<const buf_b&>(b_vec)
692  .template get_as<typename Impl::BVecType>()[iKIter],
693  reinterpret_cast<const buf_a&>(a_vec)
694  .template get_as<typename Impl::AVecType>()[iKIter],
695  bool_constant<post_nop_>{});
696  //});
697  }
698 
699  // c_vec = a_vec * b_vec
700  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
701  {
702  constexpr auto I0 = number<0>{};
705 
706  // swap A and B, value and type
707  auto c_vec = Impl{}(
708  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
709  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
710 
711  static_for<1, kKIter, 1>{}([&](auto iKIter) {
712  Impl{}(c_vec,
713  reinterpret_cast<const buf_b&>(b_vec)
714  .template get_as<typename Impl::BVecType>()[iKIter],
715  reinterpret_cast<const buf_a&>(a_vec)
716  .template get_as<typename Impl::AVecType>()[iKIter]);
717  });
718 
719  return c_vec;
720  }
721 };
722 
723 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
725 {
727 
728  // swap A and B
729  using ADataType = typename Impl::BDataType;
730  using BDataType = typename Impl::ADataType;
731  using CDataType = typename Impl::CDataType;
732 
733  using AVecType =
735  using BVecType =
737  using CVecType = typename Impl::CVecType;
738 
739  static constexpr index_t kM = Impl::kN;
740  static constexpr index_t kN = Impl::kM;
741  static constexpr index_t kK = Impl::kK * kKIter;
742  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
743  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
744 
745  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
746 
747  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
748  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
749 
751  sequence<>,
755  sequence<2>,
756  sequence<1>>;
757 #if 0
759  sequence<>,
760  tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
761  Impl::kABKLane,
762  2,
763  Impl::kABKPerLane>,
767  sequence<2>,
768  sequence<1>>;
769 
771  sequence<>,
773  sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
778 #else
779  // TODO: add more tests, not only for 32x32
781  sequence<>,
782  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
783  Impl::kCMLane,
784  SFactor,
785  Impl::kCM1PerLane>,
789  sequence<2>,
790  sequence<1>>;
791 
793  sequence<>,
795  sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
800 #endif
801  // c_vec += a_vec * b_vec
802  template <bool post_nop_ = false>
804  const AVecType& a_vec,
805  const BVecType& b_vec,
806  bool_constant<post_nop_> = {}) const
807  {
810  // swap A and B, value and type
811  static_for<0, kKIter, 1>{}([&](auto iKIter) {
812  Impl{}(c_vec,
813  reinterpret_cast<const buf_b&>(b_vec)
814  .template get_as<typename Impl::BVecType>()[iKIter],
815  reinterpret_cast<const buf_a&>(a_vec)
816  .template get_as<typename Impl::AVecType>()[iKIter],
817  bool_constant<post_nop_>{});
818  });
819  }
820 
821  template <index_t iKIter, bool post_nop_ = false>
823  const AVecType& a_vec,
824  const BVecType& b_vec,
826  bool_constant<post_nop_> = {}) const
827  {
830 
831  static_assert(iKIter < kKIter);
832  // swap A and B, value and type
833  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
834  Impl{}(c_vec,
835  reinterpret_cast<const buf_b&>(b_vec)
836  .template get_as<typename Impl::BVecType>()[iKIter],
837  reinterpret_cast<const buf_a&>(a_vec)
838  .template get_as<typename Impl::AVecType>()[iKIter],
839  bool_constant<post_nop_>{});
840  //});
841  }
842 
843  // c_vec = a_vec * b_vec
844  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
845  {
848  constexpr auto I0 = number<0>{};
849 
850  // swap A and B, value and type
851  auto c_vec = Impl{}(
852  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
853  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
854 
855  static_for<1, kKIter, 1>{}([&](auto iKIter) {
856  Impl{}(c_vec,
857  reinterpret_cast<const buf_b&>(b_vec)
858  .template get_as<typename Impl::BVecType>()[iKIter],
859  reinterpret_cast<const buf_a&>(a_vec)
860  .template get_as<typename Impl::AVecType>()[iKIter]);
861  });
862 
863  return c_vec;
864  }
865 };
866 
867 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
869 {
871 
872  using ADataType = typename Impl::ADataType;
873  using BDataType = typename Impl::BDataType;
874  using CDataType = typename Impl::CDataType;
875 
876  using AVecType =
878  using BVecType =
880  using CVecType = typename Impl::CVecType;
881 
882  static constexpr index_t kM = Impl::kM;
883  static constexpr index_t kN = Impl::kN;
884  static constexpr index_t kK = Impl::kK * kKIter;
885  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
886  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
887 
888  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
889 
890  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
891  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
892 
894  sequence<>,
895  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
896  Impl::kCMLane,
897  SFactor,
898  Impl::kCM1PerLane>,
902  sequence<2>,
903  sequence<1>>;
904 
906  sequence<>,
910  sequence<2>,
911  sequence<1>>;
912 
914  sequence<>,
915  tuple<sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>,
921 
922  // c_vec += a_vec * b_vec
923  template <bool post_nop_ = false>
925  const AVecType& a_vec,
926  const BVecType& b_vec,
927  bool_constant<post_nop_> = {}) const
928  {
931 
932  static_for<0, kKIter, 1>{}([&](auto iKIter) {
933  Impl{}(c_vec,
934  reinterpret_cast<const buf_a&>(a_vec)
935  .template get_as<typename Impl::AVecType>()[iKIter],
936  reinterpret_cast<const buf_b&>(b_vec)
937  .template get_as<typename Impl::BVecType>()[iKIter],
938  bool_constant<post_nop_>{});
939  });
940  }
941 
942  template <index_t iKIter, bool post_nop_ = false>
944  const AVecType& a_vec,
945  const BVecType& b_vec,
947  bool_constant<post_nop_> = {}) const
948  {
951 
952  static_assert(iKIter < kKIter);
953 
954  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
955  Impl{}(c_vec,
956  reinterpret_cast<const buf_a&>(a_vec)
957  .template get_as<typename Impl::AVecType>()[iKIter],
958  reinterpret_cast<const buf_b&>(b_vec)
959  .template get_as<typename Impl::BVecType>()[iKIter],
960  bool_constant<post_nop_>{});
961  //});
962  }
963 
964  // c_vec = a_vec * b_vec
965  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
966  {
967  constexpr auto I0 = number<0>{};
970 
971  auto c_vec = Impl{}(
972  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
973  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
974 
975  static_for<1, kKIter, 1>{}([&](auto iKIter) {
976  Impl{}(c_vec,
977  reinterpret_cast<const buf_a&>(a_vec)
978  .template get_as<typename Impl::AVecType>()[iKIter],
979  reinterpret_cast<const buf_b&>(b_vec)
980  .template get_as<typename Impl::BVecType>()[iKIter]);
981  });
982 
983  return c_vec;
984  }
985 };
986 
987 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
WGAttrNumAccessEnum
Definition: warp_gemm_attribute_mfma.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:84
int32_t int32_t
Definition: integer.hpp:10
Definition: warp_gemm_attribute_mfma.hpp:23
static constexpr auto get_warp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:48
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const int32_t &a_scale, const BVecType &b_vec, const int32_t &b_scale) const
Definition: warp_gemm_attribute_mfma.hpp:117
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:38
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:29
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma.hpp:40
typename Impl::AVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:32
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:34
decltype(get_warp_dstr_encoding< Impl::kAMLane >()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:74
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:25
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:30
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:110
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:42
static constexpr auto AttrNumAccessV
Definition: warp_gemm_attribute_mfma.hpp:26
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:36
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:24
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:88
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:39
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:37
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const int32_t &a_scale, const BVecType &b_vec, const int32_t &b_scale, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:98
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:28
decltype(get_warp_dstr_encoding< Impl::kBNLane >()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:75
typename Impl::BVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:33
Definition: warp_gemm_attribute_mfma.hpp:869
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:870
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:943
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:883
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:882
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:884
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:965
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:885
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:879
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:924
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:873
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:880
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:886
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:888
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:872
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:874
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:877
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:729
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:730
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:742
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:734
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:743
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:741
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:803
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:737
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:726
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:731
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:740
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:844
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:822
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:736
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:739
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:745
Definition: warp_gemm_attribute_mfma.hpp:575
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:588
static constexpr CK_TILE_DEVICE auto get_cwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:612
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:587
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:593
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:582
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:595
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:580
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:590
static constexpr CK_TILE_DEVICE auto get_awarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:600
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:592
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:678
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:591
decltype(get_awarp_dstr_encoding()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:650
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:700
decltype(get_bwarp_dstr_encoding()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:652
static constexpr CK_TILE_DEVICE auto get_bwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:606
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:654
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:576
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:577
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:585
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:658
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:581
Definition: warp_gemm_attribute_mfma.hpp:130
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:134
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma.hpp:151
decltype(get_bwarp_dstr_encoding()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:321
static constexpr CK_TILE_DEVICE auto get_cwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:281
static constexpr CK_TILE_DEVICE auto get_bwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:219
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:145
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:138
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:142
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:346
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:323
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:144
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:368
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:153
static constexpr CK_TILE_DEVICE auto get_awarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:158
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:139
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:327
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:137
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:147
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:149
static constexpr auto AttrNumAccessV
Definition: warp_gemm_attribute_mfma.hpp:135
decltype(get_awarp_dstr_encoding()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:319
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:148
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:133
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:150
Definition: warp_gemm_attribute_mfma.hpp:479
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:491
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:494
typename Impl::BVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:486
typename Impl::AVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:487
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:483
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:564
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:480
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:554
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:493
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:484
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:490
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:482
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:496
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:492
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:488
Definition: warp_gemm_attribute_mfma.hpp:395
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:470
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma.hpp:412
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:400
typename Impl::AVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:405
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:411
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:414
decltype(get_warp_dstr_encoding< Impl::kAMLane >()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:447
typename Impl::BVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:404
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:410
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:401
static constexpr auto AttrNumAccessV
Definition: warp_gemm_attribute_mfma.hpp:398
static constexpr auto get_warp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:420
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:402
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:397
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:406
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:409
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:396
decltype(get_warp_dstr_encoding< Impl::kBNLane >()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:446
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:408
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:460
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: debug.hpp:67
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192