/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File
warp_gemm_attribute_mfma.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 // Number of groups of consecutive elements to fill in an ABKLane
13 {
14  Single = 1,
15  Double = 2,
16  Quad = 4,
17  Invalid = -1
18 };
19 
20 template <typename WarpGemmAttributeMfmaImpl_,
23 {
25  static constexpr auto AttrNumAccess = AttrNumAccess_;
26  static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
27 
28  using ADataType = typename Impl::ADataType;
29  using BDataType = typename Impl::BDataType;
30  using CDataType = typename Impl::CDataType;
31 
32  using AVecType = typename Impl::AVecType;
33  using BVecType = typename Impl::BVecType;
34  using CVecType = typename Impl::CVecType;
35 
36  static constexpr index_t kM = Impl::kM;
37  static constexpr index_t kN = Impl::kN;
38  static constexpr index_t kK = Impl::kK;
39  static constexpr index_t kKPerThread = Impl::kABKPerLane;
40  static constexpr index_t kCMLane = Impl::kCMLane;
41 
// Reports the access count for this attribute. This variant has no kKIter
// parameter and always returns the literal 1; being constexpr, the result can be
// used in constant expressions.
42  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
43 
44  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
45  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
46 
47  template <index_t kMNLane>
48  static constexpr auto get_warp_dstr_encoding()
49  {
50  if constexpr(AttrNumAccessV == 1)
51  {
53  sequence<>,
58  sequence<1>>{};
59  }
60  else
61  {
62  static_assert(kKPerThread % AttrNumAccessV == 0,
63  "kKPerThread must be divisible by NumAccess");
65  sequence<>,
67  sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
71  sequence<0, 2>>{};
72  }
73  }
74  using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
75  using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
76 
78  sequence<>,
85 
86  // c_vec += a_vec * b_vec
87  template <bool post_nop_ = false>
89  const AVecType& a_vec,
90  const BVecType& b_vec,
91  bool_constant<post_nop_> = {}) const
92  {
93  Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
94  }
95 
96  // c_vec = a_vec * b_vec
 // Non-accumulating overload: takes only the two input vectors and returns the
 // result by value instead of updating a caller-supplied c_vec.
 //   a_vec, b_vec : operands, forwarded unchanged and in the same order.
 //   returns      : whatever Impl's two-argument call operator yields (CVecType).
97  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
98  {
 // Impl is invoked on a temporary (Impl{}); no Impl instance is stored here.
99  return Impl{}(a_vec, b_vec);
100  }
101 };
102 
103 template <typename WarpGemmAttributeMfmaImpl_,
104  index_t kKIter,
107 {
108  static_assert(kKIter > 0, "wrong!");
109 
111  static constexpr auto AttrNumAccess = AttrNumAccess_;
112  static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
113 
114  using ADataType = typename Impl::ADataType;
115  using BDataType = typename Impl::BDataType;
116  using CDataType = typename Impl::CDataType;
117 
118  using AVecType =
120  using BVecType =
122  using CVecType = typename Impl::CVecType;
123 
124  static constexpr index_t kM = Impl::kM;
125  static constexpr index_t kN = Impl::kN;
126  static constexpr index_t kK = Impl::kK * kKIter;
127  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
128  static constexpr index_t kCMLane = Impl::kCMLane;
129 
// Reports the access count for this attribute: the kKIter template parameter,
// which the static_assert at the top of this struct requires to be > 0. The same
// factor scales kK and kKPerThread above.
130  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
131 
132  static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
133  "Multi-block on both M & N directions is not supported");
134 
135  CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
136  {
137  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
138  {
139  if constexpr(AttrNumAccessV == 1)
140  {
142  sequence<>,
147  sequence<2>,
148  sequence<1>>{};
149  }
150  else
151  {
152  static_assert(kKPerThread % AttrNumAccessV == 0,
153  "kKPerThread must be divisible by NumAccess");
155  sequence<>,
158  Impl::kABKLane,
159  Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
163  sequence<0, 2>>{};
164  }
165  }
166  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
167  {
168  static_assert(AttrNumAccessV == 1,
169  "Multiple access is not supported when using multi-block");
170  // each M blocks share the same data
177  sequence<2>,
178  sequence<1>>{};
179  }
180  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
181  {
182  static_assert(AttrNumAccessV == 1,
183  "Multiple access is not supported when using multi-block");
184  // single block to multi-block thread mapping
186  sequence<>,
191  sequence<2>,
192  sequence<1>>{};
193  }
194  }
195 
196  CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
197  {
198  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
199  {
200  if constexpr(AttrNumAccessV == 1)
201  {
203  sequence<>,
208  sequence<2>,
209  sequence<1>>{};
210  }
211  else
212  {
213 
214  static_assert(kKPerThread % AttrNumAccessV == 0,
215  "kKPerThread must be divisible by NumAccess");
217  sequence<>,
220  Impl::kABKLane,
221  Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
225  sequence<0, 2>>{};
226  }
227  }
228  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
229  {
230  static_assert(AttrNumAccessV == 1,
231  "Multiple access is not supported when using multi-block");
232  // single block to multi-block thread mapping
234  sequence<>,
239  sequence<2>,
240  sequence<1>>{};
241  }
242  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
243  {
244  static_assert(AttrNumAccessV == 1,
245  "Multiple access is not supported when using multi-block");
246  // each N blocks share the same data
253  sequence<2>,
254  sequence<1>>{};
255  }
256  }
257 
258  CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
259  {
260  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
261  {
263  sequence<>,
269  sequence<0, 2>>{};
270  }
271  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
272  {
274  sequence<>,
280  sequence<0, 2>>{};
281  }
282  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
283  {
285  sequence<>,
286  tuple<
292  sequence<0, 2>>{};
293  }
294  }
295 
297 
299 
301 
302  // c_vec += a_vec * b_vec
303  template <bool post_nop_ = false>
305  const AVecType& a_vec,
306  const BVecType& b_vec,
307  bool_constant<post_nop_> = {}) const
308  {
311 
312  static_for<0, kKIter, 1>{}([&](auto iKIter) {
313  Impl{}(c_vec,
314  reinterpret_cast<const buf_a&>(a_vec)
315  .template get_as<typename Impl::AVecType>()[iKIter],
316  reinterpret_cast<const buf_b&>(b_vec)
317  .template get_as<typename Impl::BVecType>()[iKIter],
318  bool_constant<post_nop_>{});
319  });
320  }
321 
322  template <index_t iKIter, bool post_nop_ = false>
324  const AVecType& a_vec,
325  const BVecType& b_vec,
327  bool_constant<post_nop_> = {}) const
328  {
331 
332  static_assert(iKIter < kKIter);
333 
334  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
335  Impl{}(c_vec,
336  reinterpret_cast<const buf_a&>(a_vec)
337  .template get_as<typename Impl::AVecType>()[iKIter],
338  reinterpret_cast<const buf_b&>(b_vec)
339  .template get_as<typename Impl::BVecType>()[iKIter],
340  bool_constant<post_nop_>{});
341  //});
342  }
343 
344  // c_vec = a_vec * b_vec
 // Non-accumulating overload for the K-iterated attribute. The result is built in
 // two phases: slice 0 of each operand initialises c_vec through Impl's
 // two-argument call, then slices 1 .. kKIter-1 are accumulated through Impl's
 // three-argument call. With kKIter == 1 the static_for range is empty and only
 // the initial call runs.
 //   a_vec, b_vec : operand vectors; each is indexed kKIter times below.
 //   returns      : the accumulated c_vec, by value.
 // NOTE(review): buf_a / buf_b are declared on lines that are absent from this
 // listing (original lines 348-349). They are presumably buffer views over
 // AVecType / BVecType whose get_as<T>() exposes the data as an indexable sequence
 // of T — confirm against the real header.
345  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
346  {
 // Compile-time index selecting the first slice.
347  constexpr auto I0 = number<0>{};
350 
351  // c = a * b
352  auto c_vec = Impl{}(
353  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
354  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
355 
356  // c += a * b
 // Starts at 1 because slice 0 was consumed by the initialising call above.
357  static_for<1, kKIter, 1>{}([&](auto iKIter) {
358  Impl{}(c_vec,
359  reinterpret_cast<const buf_a&>(a_vec)
360  .template get_as<typename Impl::AVecType>()[iKIter],
361  reinterpret_cast<const buf_b&>(b_vec)
362  .template get_as<typename Impl::BVecType>()[iKIter]);
363  });
364 
365  return c_vec;
366  }
367 };
368 
369 template <typename WarpGemmAttributeMfmaImpl_,
372 {
374  static constexpr auto AttrNumAccess = AttrNumAccess_;
375  static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
376 
377  using ADataType = typename Impl::BDataType;
378  using BDataType = typename Impl::ADataType;
379  using CDataType = typename Impl::CDataType;
380 
381  using AVecType = typename Impl::BVecType;
382  using BVecType = typename Impl::AVecType;
383  using CVecType = typename Impl::CVecType;
384 
385  static constexpr index_t kM = Impl::kN;
386  static constexpr index_t kN = Impl::kM;
387  static constexpr index_t kK = Impl::kK;
388  static constexpr index_t kKPerThread = Impl::kABKPerLane;
389  static constexpr index_t kCMLane = Impl::kCMLane;
390 
// Reports the access count for this attribute. This variant has no kKIter
// parameter and always returns the literal 1.
391  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
392 
393  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
394  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
395 
396  template <index_t kMNLane>
397  static constexpr auto get_warp_dstr_encoding()
398  {
399  if constexpr(AttrNumAccessV == 1)
400  {
402  sequence<>,
406  sequence<2>,
407  sequence<1>>{};
408  }
409  else
410  {
411  static_assert(kKPerThread % AttrNumAccessV == 0,
412  "kKPerThread must be divisible by NumAccess");
414  sequence<>,
416  sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
420  sequence<0, 2>>{};
421  }
422  }
423  using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
424  using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
425 
427  sequence<>,
434 
435  // c_vec += a_vec * b_vec
436  template <bool post_nop_ = false>
438  const AVecType& a_vec,
439  const BVecType& b_vec,
440  bool_constant<post_nop_> = {}) const
441  {
442  // swap A and B
443  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
444  }
445 
446  // c_vec = a_vec * b_vec
 // Non-accumulating overload of the operand-swapped attribute. In this struct
 // AVecType aliases Impl::BVecType and BVecType aliases Impl::AVecType, so the
 // caller-facing a_vec is handed to Impl as its second operand and b_vec as its
 // first.
 //   returns : whatever Impl's two-argument call operator yields (CVecType).
447  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
448  {
449  // swap A and B
450  return Impl{}(b_vec, a_vec);
451  }
452 };
453 
454 template <typename WarpGemmAttributeMfmaImpl_, index_t SFactor_ = 2>
456 {
458 
459  using ADataType = typename Impl::BDataType;
460  using BDataType = typename Impl::ADataType;
461  using CDataType = typename Impl::CDataType;
462 
463  using AVecType = typename Impl::BVecType;
464  using BVecType = typename Impl::AVecType;
465  using CVecType = typename Impl::CVecType;
466 
467  static constexpr index_t kM = Impl::kN;
468  static constexpr index_t kN = Impl::kM;
469  static constexpr index_t kK = Impl::kK;
470  static constexpr index_t kKPerThread = Impl::kABKPerLane;
471  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
472 
// Reports the access count for this attribute. SFactor does not enter into it:
// the function always returns the literal 1.
473  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
474 
475  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
476  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
477 
479  sequence<>,
483  sequence<2>,
484  sequence<1>>;
485 #if 0
487  sequence<>,
488  tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
489  Impl::kABKLane,
490  2,
491  Impl::kABKPerLane>,
495  sequence<2>,
496  sequence<1>>;
497 
499  sequence<>,
501  sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
506 #else
507  // TODO: needs more testing beyond 32x32
509  sequence<>,
510  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
511  Impl::kCMLane,
512  SFactor,
513  Impl::kCM1PerLane>,
517  sequence<2>,
518  sequence<1>>;
519 
521  sequence<>,
523  sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
528 #endif
529  template <bool post_nop_ = false>
530  // c_vec += a_vec * b_vec
532  const AVecType& a_vec,
533  const BVecType& b_vec,
534  bool_constant<post_nop_> = {}) const
535  {
536  // swap A and B
537  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
538  }
539 
540  // c_vec = a_vec * b_vec
 // Non-accumulating overload of the operand-swapped attribute. AVecType / BVecType
 // in this struct alias Impl::BVecType / Impl::AVecType respectively, so the
 // arguments are passed to Impl in reversed order.
 //   returns : whatever Impl's two-argument call operator yields (CVecType).
541  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
542  {
543  // swap A and B
544  return Impl{}(b_vec, a_vec);
545  }
546 };
547 
548 template <typename WarpGemmAttributeMfmaImpl_,
549  index_t kKIter,
552 {
554  static constexpr auto AttrNumAccess = AttrNumAccess_;
555 
556  // swap A and B
557  using ADataType = typename Impl::BDataType;
558  using BDataType = typename Impl::ADataType;
559  using CDataType = typename Impl::CDataType;
560 
561  using AVecType =
563  using BVecType =
565  using CVecType = typename Impl::CVecType;
566 
567  static constexpr index_t kM = Impl::kN;
568  static constexpr index_t kN = Impl::kM;
569  static constexpr index_t kK = Impl::kK * kKIter;
570  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
571 
// Reports the access count for this attribute: the kKIter template parameter,
// the same factor that scales kK and kKPerThread above.
572  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
573 
574  static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
575  "Multi-block on both M & N directions is not supported");
576 
577  CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
578  {
581  }
582 
583  CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
584  {
587  }
588 
589  CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
590  {
591  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
592  {
594  sequence<>,
600  sequence<0, 2>>{};
601  }
602  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
603  {
605  sequence<>,
611  sequence<0, 2>>{};
612  }
613  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
614  {
616  sequence<>,
617  tuple<
623  sequence<0, 2>>{};
624  }
625  }
626 
628 
630 
632 
633  template <bool post_nop_ = false>
634  // c_vec += a_vec * b_vec
636  const AVecType& a_vec,
637  const BVecType& b_vec,
638  bool_constant<post_nop_> = {}) const
639  {
642  // swap A and B, value and type
643  static_for<0, kKIter, 1>{}([&](auto iKIter) {
644  Impl{}(c_vec,
645  reinterpret_cast<const buf_b&>(b_vec)
646  .template get_as<typename Impl::BVecType>()[iKIter],
647  reinterpret_cast<const buf_a&>(a_vec)
648  .template get_as<typename Impl::AVecType>()[iKIter],
649  bool_constant<post_nop_>{});
650  });
651  }
652 
653  template <index_t iKIter, bool post_nop_ = false>
654  // c_vec += a_vec * b_vec
656  const AVecType& a_vec,
657  const BVecType& b_vec,
659  bool_constant<post_nop_> = {}) const
660  {
663 
664  static_assert(iKIter < kKIter);
665  // swap A and B, value and type
666  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
667  Impl{}(c_vec,
668  reinterpret_cast<const buf_b&>(b_vec)
669  .template get_as<typename Impl::BVecType>()[iKIter],
670  reinterpret_cast<const buf_a&>(a_vec)
671  .template get_as<typename Impl::AVecType>()[iKIter],
672  bool_constant<post_nop_>{});
673  //});
674  }
675 
676  // c_vec = a_vec * b_vec
 // Non-accumulating overload for the operand-swapped, K-iterated attribute.
 // Slice 0 initialises c_vec through Impl's two-argument call; slices
 // 1 .. kKIter-1 are then accumulated through Impl's three-argument call. In every
 // call the b_vec slice is passed first and the a_vec slice second.
 //   returns : the accumulated c_vec, by value.
 // NOTE(review): buf_a / buf_b are declared on lines that are absent from this
 // listing (original lines 680-681); presumably buffer views whose get_as<T>()
 // yields an indexable sequence of T — confirm against the real header.
677  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
678  {
 // Compile-time index selecting the first slice.
679  constexpr auto I0 = number<0>{};
682 
683  // swap A and B, value and type
684  auto c_vec = Impl{}(
685  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
686  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
687 
 // Starts at 1 because slice 0 was consumed by the initialising call above.
688  static_for<1, kKIter, 1>{}([&](auto iKIter) {
689  Impl{}(c_vec,
690  reinterpret_cast<const buf_b&>(b_vec)
691  .template get_as<typename Impl::BVecType>()[iKIter],
692  reinterpret_cast<const buf_a&>(a_vec)
693  .template get_as<typename Impl::AVecType>()[iKIter]);
694  });
695 
696  return c_vec;
697  }
698 };
699 
700 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
702 {
704 
705  // swap A and B
706  using ADataType = typename Impl::BDataType;
707  using BDataType = typename Impl::ADataType;
708  using CDataType = typename Impl::CDataType;
709 
710  using AVecType =
712  using BVecType =
714  using CVecType = typename Impl::CVecType;
715 
716  static constexpr index_t kM = Impl::kN;
717  static constexpr index_t kN = Impl::kM;
718  static constexpr index_t kK = Impl::kK * kKIter;
719  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
720  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
721 
// Reports the access count for this attribute: the kKIter template parameter.
// SFactor does not enter into it.
722  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
723 
724  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
725  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
726 
728  sequence<>,
732  sequence<2>,
733  sequence<1>>;
734 #if 0
736  sequence<>,
737  tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
738  Impl::kABKLane,
739  2,
740  Impl::kABKPerLane>,
744  sequence<2>,
745  sequence<1>>;
746 
748  sequence<>,
750  sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
755 #else
756  // TODO: needs more testing beyond 32x32
758  sequence<>,
759  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
760  Impl::kCMLane,
761  SFactor,
762  Impl::kCM1PerLane>,
766  sequence<2>,
767  sequence<1>>;
768 
770  sequence<>,
772  sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
777 #endif
778  // c_vec += a_vec * b_vec
779  template <bool post_nop_ = false>
781  const AVecType& a_vec,
782  const BVecType& b_vec,
783  bool_constant<post_nop_> = {}) const
784  {
787  // swap A and B, value and type
788  static_for<0, kKIter, 1>{}([&](auto iKIter) {
789  Impl{}(c_vec,
790  reinterpret_cast<const buf_b&>(b_vec)
791  .template get_as<typename Impl::BVecType>()[iKIter],
792  reinterpret_cast<const buf_a&>(a_vec)
793  .template get_as<typename Impl::AVecType>()[iKIter],
794  bool_constant<post_nop_>{});
795  });
796  }
797 
798  template <index_t iKIter, bool post_nop_ = false>
800  const AVecType& a_vec,
801  const BVecType& b_vec,
803  bool_constant<post_nop_> = {}) const
804  {
807 
808  static_assert(iKIter < kKIter);
809  // swap A and B, value and type
810  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
811  Impl{}(c_vec,
812  reinterpret_cast<const buf_b&>(b_vec)
813  .template get_as<typename Impl::BVecType>()[iKIter],
814  reinterpret_cast<const buf_a&>(a_vec)
815  .template get_as<typename Impl::AVecType>()[iKIter],
816  bool_constant<post_nop_>{});
817  //});
818  }
819 
820  // c_vec = a_vec * b_vec
 // Non-accumulating overload for the operand-swapped, K-iterated attribute.
 // Slice 0 initialises c_vec through Impl's two-argument call; slices
 // 1 .. kKIter-1 are then accumulated through Impl's three-argument call. In every
 // call the b_vec slice is passed first and the a_vec slice second.
 //   returns : the accumulated c_vec, by value.
 // NOTE(review): buf_a / buf_b are declared on lines that are absent from this
 // listing (original lines 823-824); presumably buffer views whose get_as<T>()
 // yields an indexable sequence of T — confirm against the real header.
821  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
822  {
 // Compile-time index selecting the first slice.
825  constexpr auto I0 = number<0>{};
826 
827  // swap A and B, value and type
828  auto c_vec = Impl{}(
829  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
830  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
831 
 // Starts at 1 because slice 0 was consumed by the initialising call above.
832  static_for<1, kKIter, 1>{}([&](auto iKIter) {
833  Impl{}(c_vec,
834  reinterpret_cast<const buf_b&>(b_vec)
835  .template get_as<typename Impl::BVecType>()[iKIter],
836  reinterpret_cast<const buf_a&>(a_vec)
837  .template get_as<typename Impl::AVecType>()[iKIter]);
838  });
839 
840  return c_vec;
841  }
842 };
843 
844 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
846 {
848 
849  using ADataType = typename Impl::ADataType;
850  using BDataType = typename Impl::BDataType;
851  using CDataType = typename Impl::CDataType;
852 
853  using AVecType =
855  using BVecType =
857  using CVecType = typename Impl::CVecType;
858 
859  static constexpr index_t kM = Impl::kM;
860  static constexpr index_t kN = Impl::kN;
861  static constexpr index_t kK = Impl::kK * kKIter;
862  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
863  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
864 
// Reports the access count for this attribute: the kKIter template parameter.
// SFactor does not enter into it.
865  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
866 
867  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
868  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
869 
871  sequence<>,
872  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
873  Impl::kCMLane,
874  SFactor,
875  Impl::kCM1PerLane>,
879  sequence<2>,
880  sequence<1>>;
881 
883  sequence<>,
887  sequence<2>,
888  sequence<1>>;
889 
891  sequence<>,
892  tuple<sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>,
898 
899  // c_vec += a_vec * b_vec
900  template <bool post_nop_ = false>
902  const AVecType& a_vec,
903  const BVecType& b_vec,
904  bool_constant<post_nop_> = {}) const
905  {
908 
909  static_for<0, kKIter, 1>{}([&](auto iKIter) {
910  Impl{}(c_vec,
911  reinterpret_cast<const buf_a&>(a_vec)
912  .template get_as<typename Impl::AVecType>()[iKIter],
913  reinterpret_cast<const buf_b&>(b_vec)
914  .template get_as<typename Impl::BVecType>()[iKIter],
915  bool_constant<post_nop_>{});
916  });
917  }
918 
919  template <index_t iKIter, bool post_nop_ = false>
921  const AVecType& a_vec,
922  const BVecType& b_vec,
924  bool_constant<post_nop_> = {}) const
925  {
928 
929  static_assert(iKIter < kKIter);
930 
931  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
932  Impl{}(c_vec,
933  reinterpret_cast<const buf_a&>(a_vec)
934  .template get_as<typename Impl::AVecType>()[iKIter],
935  reinterpret_cast<const buf_b&>(b_vec)
936  .template get_as<typename Impl::BVecType>()[iKIter],
937  bool_constant<post_nop_>{});
938  //});
939  }
940 
941  // c_vec = a_vec * b_vec
 // Non-accumulating overload for the K-iterated attribute (operands are NOT
 // swapped here: the a_vec slice is passed to Impl first, the b_vec slice second).
 // Slice 0 initialises c_vec through Impl's two-argument call; slices
 // 1 .. kKIter-1 are then accumulated through Impl's three-argument call.
 //   returns : the accumulated c_vec, by value.
 // NOTE(review): buf_a / buf_b are declared on lines that are absent from this
 // listing (original lines 945-946); presumably buffer views whose get_as<T>()
 // yields an indexable sequence of T — confirm against the real header.
942  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
943  {
 // Compile-time index selecting the first slice.
944  constexpr auto I0 = number<0>{};
947 
948  auto c_vec = Impl{}(
949  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
950  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
951 
 // Starts at 1 because slice 0 was consumed by the initialising call above.
952  static_for<1, kKIter, 1>{}([&](auto iKIter) {
953  Impl{}(c_vec,
954  reinterpret_cast<const buf_a&>(a_vec)
955  .template get_as<typename Impl::AVecType>()[iKIter],
956  reinterpret_cast<const buf_b&>(b_vec)
957  .template get_as<typename Impl::BVecType>()[iKIter]);
958  });
959 
960  return c_vec;
961  }
962 };
963 
964 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
WGAttrNumAccessEnum
Definition: warp_gemm_attribute_mfma.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:83
Definition: warp_gemm_attribute_mfma.hpp:23
static constexpr auto get_warp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:48
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:38
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:29
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma.hpp:40
typename Impl::AVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:32
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:34
decltype(get_warp_dstr_encoding< Impl::kAMLane >()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:74
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:25
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:30
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:97
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:42
static constexpr auto AttrNumAccessV
Definition: warp_gemm_attribute_mfma.hpp:26
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:36
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:24
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:88
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:39
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:37
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:28
decltype(get_warp_dstr_encoding< Impl::kBNLane >()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:75
typename Impl::BVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:33
Definition: warp_gemm_attribute_mfma.hpp:846
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:847
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:920
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:860
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:859
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:861
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:942
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:862
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:856
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:901
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:850
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:857
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:863
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:865
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:849
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:851
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:854
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:706
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:707
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:719
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:711
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:720
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:718
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:780
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:714
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:703
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:708
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:717
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:821
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:799
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:713
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:716
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:722
Definition: warp_gemm_attribute_mfma.hpp:552
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:565
static constexpr CK_TILE_DEVICE auto get_cwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:589
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:564
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:570
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:559
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:572
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:557
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:567
static constexpr CK_TILE_DEVICE auto get_awarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:577
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:569
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:655
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:568
decltype(get_awarp_dstr_encoding()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:627
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:677
decltype(get_bwarp_dstr_encoding()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:629
static constexpr CK_TILE_DEVICE auto get_bwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:583
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:631
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:553
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:554
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:562
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:635
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:558
Definition: warp_gemm_attribute_mfma.hpp:107
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:111
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma.hpp:128
decltype(get_bwarp_dstr_encoding()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:298
static constexpr CK_TILE_DEVICE auto get_cwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:258
static constexpr CK_TILE_DEVICE auto get_bwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:196
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:122
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:115
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:119
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:323
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:300
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:121
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:345
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:130
static constexpr CK_TILE_DEVICE auto get_awarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:135
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:116
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:304
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:114
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:124
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:126
static constexpr auto AttrNumAccessV
Definition: warp_gemm_attribute_mfma.hpp:112
decltype(get_awarp_dstr_encoding()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:296
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:125
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:110
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:127
Definition: warp_gemm_attribute_mfma.hpp:456
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:468
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:471
typename Impl::BVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:463
typename Impl::AVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:464
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:460
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:541
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:457
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:531
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:470
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:461
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:467
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:459
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:473
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:469
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:465
Definition: warp_gemm_attribute_mfma.hpp:372
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:447
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma.hpp:389
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:377
typename Impl::AVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:382
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:388
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:391
decltype(get_warp_dstr_encoding< Impl::kAMLane >()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:424
typename Impl::BVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:381
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:387
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:378
static constexpr auto AttrNumAccessV
Definition: warp_gemm_attribute_mfma.hpp:375
static constexpr auto get_warp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:397
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:379
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:374
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:383
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:386
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:373
decltype(get_warp_dstr_encoding< Impl::kBNLane >()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:423
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:385
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:437
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: debug.hpp:67
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192