/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File
warp_gemm_attribute_mfma.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 // Number of groups of consecutive elements to fill in an ABKLane
// (WGAttrNumAccessEnum). Each enumerator's value is the group count itself: the
// attribute structs in this file read it back as
// static_cast<index_t>(AttrNumAccess) and divide Impl::kABKPerLane (times kKIter
// in the K-iterating variants) by it when building the A/B warp encodings.
13 {
14  Single = 1,
15  Double = 2,
16  Quad = 4,
17  Invalid = -1 // negative, so it is not usable as a group count
18 };
19 
// Warp-level GEMM attribute around one MFMA implementation type (Impl).
// It re-exports Impl's data/vector types and tile sizes unchanged, builds the
// A/B/C warp distribution encodings, and forwards both operator() overloads to
// a single Impl call.
// NOTE(review): this is a rendered listing with source lines missing (see the
// gaps in the line numbers: the struct name, part of the template parameter
// list, the Impl alias and parts of each encoding are not shown).
20 template <typename WarpGemmAttributeMfmaImpl_,
23 {
25  static constexpr auto AttrNumAccess = AttrNumAccess_;
    // Integer value of the access-count enum, used in the encoding arithmetic below.
26  static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
27 
28  using ADataType = typename Impl::ADataType;
29  using BDataType = typename Impl::BDataType;
30  using CDataType = typename Impl::CDataType;
31 
32  using AVecType = typename Impl::AVecType;
33  using BVecType = typename Impl::BVecType;
34  using CVecType = typename Impl::CVecType;
35 
    // Tile sizes are Impl's own; kKPerThread mirrors Impl::kABKPerLane.
36  static constexpr index_t kM = Impl::kM;
37  static constexpr index_t kN = Impl::kN;
38  static constexpr index_t kK = Impl::kK;
39  static constexpr index_t kKPerThread = Impl::kABKPerLane;
40  static constexpr index_t kCMLane = Impl::kCMLane;
41 
    // One Impl call per operator() invocation.
42  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
43 
    // Only single-block Impl variants are accepted by this attribute.
44  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
45  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
46 
    // Builds the A or B warp distribution encoding. kMNLane is Impl::kAMLane for
    // A and Impl::kBNLane for B (see the two aliases after this function).
    // When AttrNumAccessV > 1 the per-lane K elements (Impl::kABKPerLane) are
    // split into AttrNumAccessV groups, so kKPerThread must be divisible by
    // AttrNumAccessV (checked by the static_assert).
47  template <index_t kMNLane>
48  static constexpr auto get_warp_dstr_encoding()
49  {
50  if constexpr(AttrNumAccessV == 1)
51  {
53  sequence<>,
58  sequence<1>>{};
59  }
60  else
61  {
62  static_assert(kKPerThread % AttrNumAccessV == 0,
63  "kKPerThread must be divisible by NumAccess");
65  sequence<>,
67  sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
71  sequence<0, 2>>{};
72  }
73  }
74  using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
75  using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
76 
    // C warp distribution encoding; its declaration is only partly present in this
    // listing (presumably the CWarpDstrEncoding alias — confirm against the header).
78  sequence<>,
85 
86  // c_vec += a_vec * b_vec
    // post_nop_ is forwarded unchanged to Impl.
87  template <bool post_nop_ = false>
89  const AVecType& a_vec,
90  const BVecType& b_vec,
91  bool_constant<post_nop_> = {}) const
92  {
93  Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
94  }
95 
96  // c_vec = a_vec * b_vec
97  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
98  {
99  return Impl{}(a_vec, b_vec);
100  }
101 };
102 
// Warp-level GEMM attribute that issues kKIter calls of the wrapped MFMA
// implementation type (Impl) along K, giving an effective tile of
// kM x kN x (Impl::kK * kKIter). Multi-block Impl variants are accepted along M
// or along N, but not both (see the static_assert below).
// NOTE(review): this is a rendered listing with source lines missing (gaps in the
// line numbers), including the struct name, part of the template parameter list,
// the AVecType/BVecType initializers and the buf_a/buf_b declarations.
103 template <typename WarpGemmAttributeMfmaImpl_,
104  index_t kKIter,
107 {
    // At least one K iteration is required; the product operator() always reads slice 0.
108  static_assert(kKIter > 0, "wrong!");
109 
111  static constexpr auto AttrNumAccess = AttrNumAccess_;
112  static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
113 
114  using ADataType = typename Impl::ADataType;
115  using BDataType = typename Impl::BDataType;
116  using CDataType = typename Impl::CDataType;
117 
    // Per the generated reference, AVecType/BVecType are ext_vector_t of
    // ADataType/BDataType with kKIter times Impl's vector size, i.e. room for
    // kKIter Impl operands each.
118  using AVecType =
120  using BVecType =
122  using CVecType = typename Impl::CVecType;
123 
    // M/N come from Impl; K and the per-thread K count scale with kKIter.
124  static constexpr index_t kM = Impl::kM;
125  static constexpr index_t kN = Impl::kN;
126  static constexpr index_t kK = Impl::kK * kKIter;
127  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
128 
    // One Impl call per K iteration.
129  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
130 
131  static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
132  "Multi-block on both M & N directions is not supported");
133 
    // A warp distribution encoding. The branch is chosen by Impl's block layout
    // (kAMBlock, kBNBlock). Only the single-block case supports AttrNumAccessV > 1,
    // where the K elements per lane across all iterations
    // (Impl::kABKPerLane * kKIter) are split into AttrNumAccessV groups.
134  CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
135  {
136  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
137  {
138  if constexpr(AttrNumAccessV == 1)
139  {
141  sequence<>,
146  sequence<2>,
147  sequence<1>>{};
148  }
149  else
150  {
151  static_assert(kKPerThread % AttrNumAccessV == 0,
152  "kKPerThread must be divisible by NumAccess");
154  sequence<>,
157  Impl::kABKLane,
158  Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
162  sequence<0, 2>>{};
163  }
164  }
165  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
166  {
167  static_assert(AttrNumAccessV == 1,
168  "Multiple access is not supported when using multi-block");
169  // each M blocks share the same data
176  sequence<2>,
177  sequence<1>>{};
178  }
179  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
180  {
181  static_assert(AttrNumAccessV == 1,
182  "Multiple access is not supported when using multi-block");
183  // single block to multi-block thread mapping
185  sequence<>,
190  sequence<2>,
191  sequence<1>>{};
192  }
193  }
194 
    // B warp distribution encoding; same three block layouts as the A encoding,
    // with the "shared data" and "multi-block mapping" roles of M and N exchanged.
195  CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
196  {
197  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
198  {
199  if constexpr(AttrNumAccessV == 1)
200  {
202  sequence<>,
207  sequence<2>,
208  sequence<1>>{};
209  }
210  else
211  {
212 
213  static_assert(kKPerThread % AttrNumAccessV == 0,
214  "kKPerThread must be divisible by NumAccess");
216  sequence<>,
219  Impl::kABKLane,
220  Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
224  sequence<0, 2>>{};
225  }
226  }
227  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
228  {
229  static_assert(AttrNumAccessV == 1,
230  "Multiple access is not supported when using multi-block");
231  // single block to multi-block thread mapping
233  sequence<>,
238  sequence<2>,
239  sequence<1>>{};
240  }
241  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
242  {
243  static_assert(AttrNumAccessV == 1,
244  "Multiple access is not supported when using multi-block");
245  // each N blocks share the same data
252  sequence<2>,
253  sequence<1>>{};
254  }
255  }
256 
    // C warp distribution encoding for the same three block layouts.
257  CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
258  {
259  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
260  {
262  sequence<>,
268  sequence<0, 2>>{};
269  }
270  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
271  {
273  sequence<>,
279  sequence<0, 2>>{};
280  }
281  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
282  {
284  sequence<>,
285  tuple<
291  sequence<0, 2>>{};
292  }
293  }
    // Per the generated reference, AWarpDstrEncoding, BWarpDstrEncoding and
    // CWarpDstrEncoding (decltype of the three getters above) are declared on the
    // source lines missing from this part of the listing.
294 
296 
298 
300 
301  // c_vec += a_vec * b_vec
    // Issues all kKIter K slices in order, accumulating into c_vec; post_nop_ is
    // forwarded to every Impl call.
302  template <bool post_nop_ = false>
304  const AVecType& a_vec,
305  const BVecType& b_vec,
306  bool_constant<post_nop_> = {}) const
307  {
        // buf_a/buf_b are declared on lines missing from this listing; each is a view
        // reinterpret_cast over the whole input vector and indexed per K iteration.
310 
311  static_for<0, kKIter, 1>{}([&](auto iKIter) {
312  Impl{}(c_vec,
313  reinterpret_cast<const buf_a&>(a_vec)
314  .template get_as<typename Impl::AVecType>()[iKIter],
315  reinterpret_cast<const buf_b&>(b_vec)
316  .template get_as<typename Impl::BVecType>()[iKIter],
317  bool_constant<post_nop_>{});
318  });
319  }
320 
    // Issues only K slice iKIter (compile-time index; per the generated reference
    // the parameter missing from this listing is number<iKIter>). Presumably this
    // lets callers interleave the K iterations with other work — confirm at call sites.
321  template <index_t iKIter, bool post_nop_ = false>
323  const AVecType& a_vec,
324  const BVecType& b_vec,
326  bool_constant<post_nop_> = {}) const
327  {
330 
331  static_assert(iKIter < kKIter);
332 
333  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
334  Impl{}(c_vec,
335  reinterpret_cast<const buf_a&>(a_vec)
336  .template get_as<typename Impl::AVecType>()[iKIter],
337  reinterpret_cast<const buf_b&>(b_vec)
338  .template get_as<typename Impl::BVecType>()[iKIter],
339  bool_constant<post_nop_>{});
340  //});
341  }
342 
343  // c_vec = a_vec * b_vec
    // Slice 0 initializes c_vec through Impl's non-accumulating overload; slices
    // 1..kKIter-1 then accumulate into it.
344  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
345  {
346  constexpr auto I0 = number<0>{};
349 
350  // c = a * b
351  auto c_vec = Impl{}(
352  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
353  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
354 
355  // c += a * b
356  static_for<1, kKIter, 1>{}([&](auto iKIter) {
357  Impl{}(c_vec,
358  reinterpret_cast<const buf_a&>(a_vec)
359  .template get_as<typename Impl::AVecType>()[iKIter],
360  reinterpret_cast<const buf_b&>(b_vec)
361  .template get_as<typename Impl::BVecType>()[iKIter]);
362  });
363 
364  return c_vec;
365  }
366 };
367 
// Transposed-C variant of the single-MFMA attribute. This attribute's A operand is
// Impl's B operand and vice versa: data types, vector types and lane counts are
// swapped, kM = Impl::kN and kN = Impl::kM. Both operator() overloads therefore
// call Impl with b_vec first. Impl produces its own Impl::kM x Impl::kN tile,
// which is kN x kM from this attribute's point of view.
// NOTE(review): rendered listing with source lines missing (struct name, part of
// the template parameter list, the Impl alias and parts of the encodings).
368 template <typename WarpGemmAttributeMfmaImpl_,
371 {
373  static constexpr auto AttrNumAccess = AttrNumAccess_;
374  static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
375 
    // A and B are swapped relative to Impl.
376  using ADataType = typename Impl::BDataType;
377  using BDataType = typename Impl::ADataType;
378  using CDataType = typename Impl::CDataType;
379 
380  using AVecType = typename Impl::BVecType;
381  using BVecType = typename Impl::AVecType;
382  using CVecType = typename Impl::CVecType;
383 
384  static constexpr index_t kM = Impl::kN;
385  static constexpr index_t kN = Impl::kM;
386  static constexpr index_t kK = Impl::kK;
387  static constexpr index_t kKPerThread = Impl::kABKPerLane;
388  static constexpr index_t kCMLane = Impl::kCMLane;
389 
    // One Impl call per operator() invocation.
390  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
391 
392  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
393  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
394 
    // Builds the A or B warp distribution encoding. Because of the swap, kMNLane is
    // Impl::kBNLane for A and Impl::kAMLane for B (see the aliases after this
    // function). When AttrNumAccessV > 1 the per-lane K elements are split into
    // AttrNumAccessV groups, so kKPerThread must be divisible by AttrNumAccessV.
395  template <index_t kMNLane>
396  static constexpr auto get_warp_dstr_encoding()
397  {
398  if constexpr(AttrNumAccessV == 1)
399  {
401  sequence<>,
405  sequence<2>,
406  sequence<1>>{};
407  }
408  else
409  {
410  static_assert(kKPerThread % AttrNumAccessV == 0,
411  "kKPerThread must be divisible by NumAccess");
413  sequence<>,
415  sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
419  sequence<0, 2>>{};
420  }
421  }
422  using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
423  using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
424 
    // C warp distribution encoding; its declaration is only partly present in this
    // listing (presumably the CWarpDstrEncoding alias — confirm against the header).
426  sequence<>,
433 
434  // c_vec += a_vec * b_vec
    // post_nop_ is forwarded unchanged to Impl.
435  template <bool post_nop_ = false>
437  const AVecType& a_vec,
438  const BVecType& b_vec,
439  bool_constant<post_nop_> = {}) const
440  {
441  // swap A and B
442  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
443  }
444 
445  // c_vec = a_vec * b_vec
446  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
447  {
448  // swap A and B
449  return Impl{}(b_vec, a_vec);
450  }
451 };
452 
// Transposed-C single-MFMA attribute (A/B swapped when calling Impl,
// kM = Impl::kN, kN = Impl::kM) whose encodings are regrouped by SFactor.
// In the active #else branch, one encoding splits Impl::kAMLane as
//   [Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
//    Impl::kCMLane, SFactor, Impl::kCM1PerLane]
// and another uses
//   [Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor].
// The #if 0 branch is an older variant with the factor hard-coded to 2; it is
// compiled out.
// NOTE(review): no static_assert in the visible lines checks that these integer
// divisions are exact. Together with the "more test not only 32x32" TODO, this
// means other Impl shapes or SFactor values should be validated before use.
// NOTE(review): rendered listing with source lines missing (struct name, Impl
// alias, the encoding alias names and parts of each encoding).
453 template <typename WarpGemmAttributeMfmaImpl_, index_t SFactor_ = 2>
455 {
457 
    // A and B are swapped relative to Impl.
458  using ADataType = typename Impl::BDataType;
459  using BDataType = typename Impl::ADataType;
460  using CDataType = typename Impl::CDataType;
461 
462  using AVecType = typename Impl::BVecType;
463  using BVecType = typename Impl::AVecType;
464  using CVecType = typename Impl::CVecType;
465 
466  static constexpr index_t kM = Impl::kN;
467  static constexpr index_t kN = Impl::kM;
468  static constexpr index_t kK = Impl::kK;
469  static constexpr index_t kKPerThread = Impl::kABKPerLane;
470  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
471 
    // One Impl call per operator() invocation.
472  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
473 
474  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
475  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
476 
478  sequence<>,
482  sequence<2>,
483  sequence<1>>;
484 #if 0
486  sequence<>,
487  tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
488  Impl::kABKLane,
489  2,
490  Impl::kABKPerLane>,
494  sequence<2>,
495  sequence<1>>;
496 
498  sequence<>,
500  sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
505 #else
506  // TODO: more test not only 32x32
508  sequence<>,
509  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
510  Impl::kCMLane,
511  SFactor,
512  Impl::kCM1PerLane>,
516  sequence<2>,
517  sequence<1>>;
518 
520  sequence<>,
522  sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
527 #endif
528  template <bool post_nop_ = false>
529  // c_vec += a_vec * b_vec
531  const AVecType& a_vec,
532  const BVecType& b_vec,
533  bool_constant<post_nop_> = {}) const
534  {
535  // swap A and B
536  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
537  }
538 
539  // c_vec = a_vec * b_vec
540  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
541  {
542  // swap A and B
543  return Impl{}(b_vec, a_vec);
544  }
545 };
546 
// Warp-level GEMM attribute that issues kKIter Impl calls along K with A and B
// swapped when calling Impl (kM = Impl::kN, kN = Impl::kM,
// kK = Impl::kK * kKIter). Multi-block Impl variants are accepted along M or
// along N, but not both.
// NOTE(review): rendered listing with source lines missing, including the struct
// name, part of the template parameter list, the AVecType/BVecType initializers,
// the bodies of get_awarp_dstr_encoding()/get_bwarp_dstr_encoding() and the
// buf_a/buf_b declarations. In the visible lines AttrNumAccess is declared but
// never converted to AttrNumAccessV or used; confirm whether multiple access is
// honoured by the missing encoding bodies.
547 template <typename WarpGemmAttributeMfmaImpl_,
548  index_t kKIter,
551 {
553  static constexpr auto AttrNumAccess = AttrNumAccess_;
554 
555  // swap A and B
556  using ADataType = typename Impl::BDataType;
557  using BDataType = typename Impl::ADataType;
558  using CDataType = typename Impl::CDataType;
559 
    // Per the generated reference, AVecType is ext_vector_t<ADataType, Impl::AVecType
    // size * kKIter> and BVecType is ext_vector_t<BDataType, Impl::BVecType size *
    // kKIter>. The element types are swapped but the vector sizes are not.
560  using AVecType =
562  using BVecType =
564  using CVecType = typename Impl::CVecType;
565 
566  static constexpr index_t kM = Impl::kN;
567  static constexpr index_t kN = Impl::kM;
568  static constexpr index_t kK = Impl::kK * kKIter;
569  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
570 
    // One Impl call per K iteration.
571  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
572 
573  static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
574  "Multi-block on both M & N directions is not supported");
575 
576  CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
577  {
580  }
581 
582  CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
583  {
586  }
587 
    // C warp distribution encoding, selected by Impl's block layout (kAMBlock, kBNBlock).
588  CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
589  {
590  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
591  {
593  sequence<>,
599  sequence<0, 2>>{};
600  }
601  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
602  {
604  sequence<>,
610  sequence<0, 2>>{};
611  }
612  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
613  {
615  sequence<>,
616  tuple<
622  sequence<0, 2>>{};
623  }
624  }
    // Per the generated reference, AWarpDstrEncoding, BWarpDstrEncoding and
    // CWarpDstrEncoding (decltype of the three getters above) are declared on the
    // source lines missing from this part of the listing.
625 
627 
629 
631 
    // Issues all kKIter K slices, accumulating into c_vec; post_nop_ is forwarded.
    // NOTE(review): b_vec (this attribute's B, which carries Impl's A data) is viewed
    // as Impl::BVecType and passed as Impl's first operand, while a_vec is viewed as
    // Impl::AVecType and passed second. In the non-swapped attributes Impl's first
    // operand has type Impl::AVecType, so this is only type-consistent when
    // Impl::AVecType and Impl::BVecType coincide. Confirm for Impls whose A and B
    // data types differ.
632  template <bool post_nop_ = false>
633  // c_vec += a_vec * b_vec
635  const AVecType& a_vec,
636  const BVecType& b_vec,
637  bool_constant<post_nop_> = {}) const
638  {
641  // swap A and B, value and type
642  static_for<0, kKIter, 1>{}([&](auto iKIter) {
643  Impl{}(c_vec,
644  reinterpret_cast<const buf_b&>(b_vec)
645  .template get_as<typename Impl::BVecType>()[iKIter],
646  reinterpret_cast<const buf_a&>(a_vec)
647  .template get_as<typename Impl::AVecType>()[iKIter],
648  bool_constant<post_nop_>{});
649  });
650  }
651 
    // Issues only K slice iKIter (compile-time index; per the generated reference the
    // parameter missing from this listing is number<iKIter>), with the same A/B swap.
652  template <index_t iKIter, bool post_nop_ = false>
653  // c_vec += a_vec * b_vec
655  const AVecType& a_vec,
656  const BVecType& b_vec,
658  bool_constant<post_nop_> = {}) const
659  {
662 
663  static_assert(iKIter < kKIter);
664  // swap A and B, value and type
665  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
666  Impl{}(c_vec,
667  reinterpret_cast<const buf_b&>(b_vec)
668  .template get_as<typename Impl::BVecType>()[iKIter],
669  reinterpret_cast<const buf_a&>(a_vec)
670  .template get_as<typename Impl::AVecType>()[iKIter],
671  bool_constant<post_nop_>{});
672  //});
673  }
674 
675  // c_vec = a_vec * b_vec
    // Slice 0 initializes c_vec through Impl's non-accumulating overload; later
    // slices accumulate into it.
676  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
677  {
678  constexpr auto I0 = number<0>{};
681 
682  // swap A and B, value and type
683  auto c_vec = Impl{}(
684  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
685  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
686 
687  static_for<1, kKIter, 1>{}([&](auto iKIter) {
688  Impl{}(c_vec,
689  reinterpret_cast<const buf_b&>(b_vec)
690  .template get_as<typename Impl::BVecType>()[iKIter],
691  reinterpret_cast<const buf_a&>(a_vec)
692  .template get_as<typename Impl::AVecType>()[iKIter]);
693  });
694 
695  return c_vec;
696  }
697 };
698 
// Warp-level GEMM attribute that issues kKIter Impl calls along K with A and B
// swapped when calling Impl (kM = Impl::kN, kN = Impl::kM,
// kK = Impl::kK * kKIter), and whose encodings are regrouped by SFactor.
// In the active #else branch, one encoding splits Impl::kAMLane as
//   [Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
//    Impl::kCMLane, SFactor, Impl::kCM1PerLane]
// and another uses
//   [Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor].
// The #if 0 branch (factor hard-coded to 2) is compiled out.
// NOTE(review): no static_assert in the visible lines checks that these integer
// divisions are exact; see also the "more test not only 32x32" TODO.
// NOTE(review): rendered listing with source lines missing (struct name, Impl
// alias, AVecType/BVecType initializers, encoding alias names, buf_a/buf_b).
699 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
701 {
703 
704  // swap A and B
705  using ADataType = typename Impl::BDataType;
706  using BDataType = typename Impl::ADataType;
707  using CDataType = typename Impl::CDataType;
708 
    // Per the generated reference, AVecType is ext_vector_t<ADataType, Impl::AVecType
    // size * kKIter> and BVecType is ext_vector_t<BDataType, Impl::BVecType size *
    // kKIter>. The element types are swapped but the vector sizes are not.
709  using AVecType =
711  using BVecType =
713  using CVecType = typename Impl::CVecType;
714 
715  static constexpr index_t kM = Impl::kN;
716  static constexpr index_t kN = Impl::kM;
717  static constexpr index_t kK = Impl::kK * kKIter;
718  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
719  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
720 
    // One Impl call per K iteration.
721  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
722 
723  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
724  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
725 
727  sequence<>,
731  sequence<2>,
732  sequence<1>>;
733 #if 0
735  sequence<>,
736  tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
737  Impl::kABKLane,
738  2,
739  Impl::kABKPerLane>,
743  sequence<2>,
744  sequence<1>>;
745 
747  sequence<>,
749  sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
754 #else
755  // TODO: more test not only 32x32
757  sequence<>,
758  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
759  Impl::kCMLane,
760  SFactor,
761  Impl::kCM1PerLane>,
765  sequence<2>,
766  sequence<1>>;
767 
769  sequence<>,
771  sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
776 #endif
777  // c_vec += a_vec * b_vec
    // Issues all kKIter K slices, accumulating into c_vec; post_nop_ is forwarded.
    // NOTE(review): b_vec (carrying Impl's A data) is viewed as Impl::BVecType and
    // passed as Impl's first operand, and a_vec is viewed as Impl::AVecType and
    // passed second. This is only type-consistent when Impl::AVecType and
    // Impl::BVecType coincide; confirm for Impls whose A and B data types differ.
778  template <bool post_nop_ = false>
780  const AVecType& a_vec,
781  const BVecType& b_vec,
782  bool_constant<post_nop_> = {}) const
783  {
786  // swap A and B, value and type
787  static_for<0, kKIter, 1>{}([&](auto iKIter) {
788  Impl{}(c_vec,
789  reinterpret_cast<const buf_b&>(b_vec)
790  .template get_as<typename Impl::BVecType>()[iKIter],
791  reinterpret_cast<const buf_a&>(a_vec)
792  .template get_as<typename Impl::AVecType>()[iKIter],
793  bool_constant<post_nop_>{});
794  });
795  }
796 
    // Issues only K slice iKIter (compile-time index; per the generated reference the
    // parameter missing from this listing is number<iKIter>), with the same A/B swap.
797  template <index_t iKIter, bool post_nop_ = false>
799  const AVecType& a_vec,
800  const BVecType& b_vec,
802  bool_constant<post_nop_> = {}) const
803  {
806 
807  static_assert(iKIter < kKIter);
808  // swap A and B, value and type
809  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
810  Impl{}(c_vec,
811  reinterpret_cast<const buf_b&>(b_vec)
812  .template get_as<typename Impl::BVecType>()[iKIter],
813  reinterpret_cast<const buf_a&>(a_vec)
814  .template get_as<typename Impl::AVecType>()[iKIter],
815  bool_constant<post_nop_>{});
816  //});
817  }
818 
819  // c_vec = a_vec * b_vec
    // Slice 0 initializes c_vec through Impl's non-accumulating overload; later
    // slices accumulate into it.
820  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
821  {
824  constexpr auto I0 = number<0>{};
825 
826  // swap A and B, value and type
827  auto c_vec = Impl{}(
828  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
829  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
830 
831  static_for<1, kKIter, 1>{}([&](auto iKIter) {
832  Impl{}(c_vec,
833  reinterpret_cast<const buf_b&>(b_vec)
834  .template get_as<typename Impl::BVecType>()[iKIter],
835  reinterpret_cast<const buf_a&>(a_vec)
836  .template get_as<typename Impl::AVecType>()[iKIter]);
837  });
838 
839  return c_vec;
840  }
841 };
842 
// Warp-level GEMM attribute that issues kKIter Impl calls along K without
// swapping A and B (kM = Impl::kM, kN = Impl::kN, kK = Impl::kK * kKIter).
// Its first encoding splits Impl::kAMLane as
//   [Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
//    Impl::kCMLane, SFactor, Impl::kCM1PerLane]
// and its last one regroups the C elements as
//   [Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor].
// NOTE(review): the encoding alias names are on lines missing from this listing;
// presumably the three encodings are A, B and C in that order — confirm.
// NOTE(review): no static_assert in the visible lines checks that these integer
// divisions are exact.
843 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
845 {
847 
848  using ADataType = typename Impl::ADataType;
849  using BDataType = typename Impl::BDataType;
850  using CDataType = typename Impl::CDataType;
851 
    // Per the generated reference, AVecType/BVecType are ext_vector_t of
    // ADataType/BDataType with kKIter times Impl's vector size.
852  using AVecType =
854  using BVecType =
856  using CVecType = typename Impl::CVecType;
857 
858  static constexpr index_t kM = Impl::kM;
859  static constexpr index_t kN = Impl::kN;
860  static constexpr index_t kK = Impl::kK * kKIter;
861  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
862  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
863 
    // One Impl call per K iteration.
864  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
865 
866  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
867  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
868 
870  sequence<>,
871  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
872  Impl::kCMLane,
873  SFactor,
874  Impl::kCM1PerLane>,
878  sequence<2>,
879  sequence<1>>;
880 
882  sequence<>,
886  sequence<2>,
887  sequence<1>>;
888 
890  sequence<>,
891  tuple<sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>,
897 
898  // c_vec += a_vec * b_vec
    // Issues all kKIter K slices in order, accumulating into c_vec; post_nop_ is
    // forwarded to every Impl call. buf_a/buf_b are declared on lines missing from
    // this listing.
899  template <bool post_nop_ = false>
901  const AVecType& a_vec,
902  const BVecType& b_vec,
903  bool_constant<post_nop_> = {}) const
904  {
907 
908  static_for<0, kKIter, 1>{}([&](auto iKIter) {
909  Impl{}(c_vec,
910  reinterpret_cast<const buf_a&>(a_vec)
911  .template get_as<typename Impl::AVecType>()[iKIter],
912  reinterpret_cast<const buf_b&>(b_vec)
913  .template get_as<typename Impl::BVecType>()[iKIter],
914  bool_constant<post_nop_>{});
915  });
916  }
917 
    // Issues only K slice iKIter (compile-time index; per the generated reference the
    // parameter missing from this listing is number<iKIter>).
918  template <index_t iKIter, bool post_nop_ = false>
920  const AVecType& a_vec,
921  const BVecType& b_vec,
923  bool_constant<post_nop_> = {}) const
924  {
927 
928  static_assert(iKIter < kKIter);
929 
930  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
931  Impl{}(c_vec,
932  reinterpret_cast<const buf_a&>(a_vec)
933  .template get_as<typename Impl::AVecType>()[iKIter],
934  reinterpret_cast<const buf_b&>(b_vec)
935  .template get_as<typename Impl::BVecType>()[iKIter],
936  bool_constant<post_nop_>{});
937  //});
938  }
939 
940  // c_vec = a_vec * b_vec
    // Slice 0 initializes c_vec through Impl's non-accumulating overload; later
    // slices accumulate into it.
941  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
942  {
943  constexpr auto I0 = number<0>{};
946 
947  auto c_vec = Impl{}(
948  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
949  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
950 
951  static_for<1, kKIter, 1>{}([&](auto iKIter) {
952  Impl{}(c_vec,
953  reinterpret_cast<const buf_a&>(a_vec)
954  .template get_as<typename Impl::AVecType>()[iKIter],
955  reinterpret_cast<const buf_b&>(b_vec)
956  .template get_as<typename Impl::BVecType>()[iKIter]);
957  });
958 
959  return c_vec;
960  }
961 };
962 
963 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
WGAttrNumAccessEnum
Definition: warp_gemm_attribute_mfma.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:83
Definition: warp_gemm_attribute_mfma.hpp:23
static constexpr auto get_warp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:48
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:38
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:29
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma.hpp:40
typename Impl::AVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:32
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:34
decltype(get_warp_dstr_encoding< Impl::kAMLane >()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:74
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:25
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:30
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:97
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:42
static constexpr auto AttrNumAccessV
Definition: warp_gemm_attribute_mfma.hpp:26
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:36
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:24
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:88
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:39
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:37
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:28
decltype(get_warp_dstr_encoding< Impl::kBNLane >()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:75
typename Impl::BVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:33
Definition: warp_gemm_attribute_mfma.hpp:845
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:846
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:919
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:859
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:858
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:860
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:941
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:861
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:855
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:900
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:849
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:856
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:862
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:864
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:848
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:850
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:853
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:705
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:706
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:718
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:710
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:719
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:717
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:779
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:713
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:702
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:707
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:716
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:820
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:798
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:712
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:715
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:721
Definition: warp_gemm_attribute_mfma.hpp:551
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:564
static constexpr CK_TILE_DEVICE auto get_cwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:588
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:563
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:569
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:558
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:571
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:556
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:566
static constexpr CK_TILE_DEVICE auto get_awarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:576
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:568
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:654
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:567
decltype(get_awarp_dstr_encoding()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:626
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:676
decltype(get_bwarp_dstr_encoding()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:628
static constexpr CK_TILE_DEVICE auto get_bwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:582
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:630
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:552
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:553
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:561
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:634
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:557
Definition: warp_gemm_attribute_mfma.hpp:107
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:111
decltype(get_bwarp_dstr_encoding()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:297
static constexpr CK_TILE_DEVICE auto get_cwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:257
static constexpr CK_TILE_DEVICE auto get_bwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:195
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:122
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:115
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:119
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:322
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:299
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:121
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:344
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:129
static constexpr CK_TILE_DEVICE auto get_awarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:134
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:116
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:303
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:114
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:124
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:126
static constexpr auto AttrNumAccessV
Definition: warp_gemm_attribute_mfma.hpp:112
decltype(get_awarp_dstr_encoding()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:295
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:125
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:110
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:127
Definition: warp_gemm_attribute_mfma.hpp:455
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:467
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:470
typename Impl::BVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:462
typename Impl::AVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:463
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:459
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:540
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:456
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:530
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:469
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:460
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:466
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:458
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:472
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:468
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:464
Definition: warp_gemm_attribute_mfma.hpp:371
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:446
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma.hpp:388
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:376
typename Impl::AVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:381
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:387
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:390
decltype(get_warp_dstr_encoding< Impl::kAMLane >()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:423
typename Impl::BVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:380
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:386
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:377
static constexpr auto AttrNumAccessV
Definition: warp_gemm_attribute_mfma.hpp:374
static constexpr auto get_warp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:396
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:378
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:373
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:382
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:385
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:372
decltype(get_warp_dstr_encoding< Impl::kBNLane >()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:422
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:384
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:436
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: debug.hpp:67
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192