include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File

include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File#

Composable Kernel: include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File
warp_gemm_attribute_mfma.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 template <typename WarpGemmAttributeMfmaImpl_>
13 {
15 
16  using ADataType = typename Impl::ADataType;
17  using BDataType = typename Impl::BDataType;
18  using CDataType = typename Impl::CDataType;
19 
20  using AVecType = typename Impl::AVecType;
21  using BVecType = typename Impl::BVecType;
22  using CVecType = typename Impl::CVecType;
23 
24  static constexpr index_t kM = Impl::kM;
25  static constexpr index_t kN = Impl::kN;
26  static constexpr index_t kK = Impl::kK;
27  static constexpr index_t kKPerThread = Impl::kABKPerLane;
28 
// Compile-time constant: this attribute always reports an access count of 1.
29  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
30 
31  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
32  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
33 
35  sequence<>,
40  sequence<1>>;
41 
43  sequence<>,
48  sequence<1>>;
49 
51  sequence<>,
58 
// Accumulating form: c_vec += a_vec * b_vec.
// Forwards c_vec, a_vec and b_vec unchanged to a default-constructed Impl and
// passes post_nop_ through as a bool_constant tag (defaults to false).
// NOTE(review): listing line 61, which opens the signature, is missing from this
// extract; the symbol index at the end of the page gives it as
// `CK_TILE_DEVICE void operator()(CVecType &c_vec, ...`.
59  // c_vec += a_vec * b_vec
60  template <bool post_nop_ = false>
62  const AVecType& a_vec,
63  const BVecType& b_vec,
64  bool_constant<post_nop_> = {}) const
65  {
66  Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
67  }
68 
// Non-accumulating form: returns a fresh CVecType holding a_vec * b_vec.
// Delegates directly to Impl's two-argument call; no post_nop_ tag is passed.
69  // c_vec = a_vec * b_vec
70  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
71  {
72  return Impl{}(a_vec, b_vec);
73  }
74 };
75 
76 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
78 {
79  static_assert(kKIter > 0, "wrong!");
80 
82 
83  using ADataType = typename Impl::ADataType;
84  using BDataType = typename Impl::BDataType;
85  using CDataType = typename Impl::CDataType;
86 
87  using AVecType =
89  using BVecType =
91  using CVecType = typename Impl::CVecType;
92 
93  static constexpr index_t kM = Impl::kM;
94  static constexpr index_t kN = Impl::kN;
95  static constexpr index_t kK = Impl::kK * kKIter;
96  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
97 
// Compile-time constant: reports kKIter, which is also the number of Impl
// invocations the accumulating operator() below performs.
98  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
99 
100  static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
101  "Multi-block on both M & N directions is not supported");
102 
// Builds, at compile time, the encoding object used for the A operand.
// Exactly one of three `if constexpr` branches is instantiated, chosen by
// Impl::kAMBlock / Impl::kBNBlock:
//   - both equal to 1,
//   - kAMBlock == 1 with kBNBlock > 1,
//   - kAMBlock > 1 with kBNBlock == 1.
// The remaining combination (both > 1) is rejected by the static_assert that
// precedes this function, so no fall-through branch is needed.
// NOTE(review): the listing lines that name the returned type and most of its
// template arguments are missing from this extract; only the trailing
// sequence<...> arguments of each return expression are visible.
103  CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
104  {
105  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
106  {
108  sequence<>,
113  sequence<2>,
114  sequence<1>>{};
115  }
116  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
117  {
118  // each M blocks share the same data
125  sequence<2>,
126  sequence<1>>{};
127  }
128  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
129  {
130  // single block to multi-block thread mapping
132  sequence<>,
137  sequence<2>,
138  sequence<1>>{};
139  }
140  }
141 
// Builds, at compile time, the encoding object used for the B operand.
// Same three-way `if constexpr` selection on Impl::kAMBlock / Impl::kBNBlock as
// the A-operand counterpart; the both-greater-than-1 case is excluded by the
// static_assert earlier in the struct.
// NOTE(review): the lines naming the returned type and most of its template
// arguments are missing from this extract.
142  CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
143  {
144  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
145  {
147  sequence<>,
152  sequence<2>,
153  sequence<1>>{};
154  }
155  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
156  {
157  // single block to multi-block thread mapping
159  sequence<>,
164  sequence<2>,
165  sequence<1>>{};
166  }
167  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
168  {
169  // each N blocks share the same data
176  sequence<2>,
177  sequence<1>>{};
178  }
179  }
180 
// Builds, at compile time, the encoding object used for the C (accumulator) operand.
// One of three `if constexpr` branches is instantiated depending on whether
// Impl::kAMBlock and/or Impl::kBNBlock equal 1; the case where both exceed 1 is
// ruled out by the static_assert earlier in the struct.
// NOTE(review): most of each return expression is missing from this extract;
// only the leading sequence<> and the trailing sequence<0, 2> are visible.
181  CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
182  {
183  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
184  {
186  sequence<>,
192  sequence<0, 2>>{};
193  }
194  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
195  {
197  sequence<>,
203  sequence<0, 2>>{};
204  }
205  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
206  {
208  sequence<>,
209  tuple<
215  sequence<0, 2>>{};
216  }
217  }
218 
220 
222 
224 
// Accumulating form over all K sub-steps: c_vec += a_vec * b_vec.
// a_vec / b_vec each pack kKIter Impl-sized sub-vectors. The compile-time loop
// below invokes Impl once per sub-step, in index order 0..kKIter-1, always
// accumulating into the same c_vec and forwarding the post_nop_ tag each time.
// NOTE(review): buf_a / buf_b are declared on listing lines 232-233, which are
// missing from this extract (presumably thread_buffer aliases - TODO confirm).
225  // c_vec += a_vec * b_vec
226  template <bool post_nop_ = false>
228  const AVecType& a_vec,
229  const BVecType& b_vec,
230  bool_constant<post_nop_> = {}) const
231  {
234 
235  static_for<0, kKIter, 1>{}([&](auto iKIter) {
// Reinterpret the wide vectors as buffers, view them as arrays of
// Impl::AVecType / Impl::BVecType, and pick the iKIter-th element of each.
236  Impl{}(c_vec,
237  reinterpret_cast<const buf_a&>(a_vec)
238  .template get_as<typename Impl::AVecType>()[iKIter],
239  reinterpret_cast<const buf_b&>(b_vec)
240  .template get_as<typename Impl::BVecType>()[iKIter],
241  bool_constant<post_nop_>{});
242  });
243  }
244 
// Single-step accumulating form: performs only the iKIter-th K sub-step,
// c_vec += a_sub[iKIter] * b_sub[iKIter], instead of looping over all of them.
// iKIter is a template parameter and must be < kKIter (checked at compile time).
// NOTE(review): listing lines 246 and 249 (start of the signature and the
// number<iKIter> tag parameter, per the symbol index at the end of the page) and
// the buf_a / buf_b aliases on lines 252-253 are missing from this extract.
245  template <index_t iKIter, bool post_nop_ = false>
247  const AVecType& a_vec,
248  const BVecType& b_vec,
250  bool_constant<post_nop_> = {}) const
251  {
254 
255  static_assert(iKIter < kKIter);
256 
// The commented-out static_for shows the loop this overload is a single
// iteration of.
257  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
258  Impl{}(c_vec,
259  reinterpret_cast<const buf_a&>(a_vec)
260  .template get_as<typename Impl::AVecType>()[iKIter],
261  reinterpret_cast<const buf_b&>(b_vec)
262  .template get_as<typename Impl::BVecType>()[iKIter],
263  bool_constant<post_nop_>{});
264  //});
265  }
266 
// Non-accumulating form: returns a_vec * b_vec summed over all kKIter sub-steps.
// Sub-step 0 uses Impl's two-argument (value-returning) call to create c_vec;
// sub-steps 1..kKIter-1 then accumulate into it with the three-argument call.
// No post_nop_ tag is passed on this path.
// NOTE(review): the buf_a / buf_b aliases (listing lines 271-272) are missing
// from this extract.
267  // c_vec = a_vec * b_vec
268  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
269  {
270  constexpr auto I0 = number<0>{};
273 
274  // c = a * b
275  auto c_vec = Impl{}(
276  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
277  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
278 
279  // c += a * b
// Loop starts at 1 because sub-step 0 was consumed above.
280  static_for<1, kKIter, 1>{}([&](auto iKIter) {
281  Impl{}(c_vec,
282  reinterpret_cast<const buf_a&>(a_vec)
283  .template get_as<typename Impl::AVecType>()[iKIter],
284  reinterpret_cast<const buf_b&>(b_vec)
285  .template get_as<typename Impl::BVecType>()[iKIter]);
286  });
287 
288  return c_vec;
289  }
290 };
291 
292 template <typename WarpGemmAttributeMfmaImpl_>
294 {
296 
297  using ADataType = typename Impl::BDataType;
298  using BDataType = typename Impl::ADataType;
299  using CDataType = typename Impl::CDataType;
300 
301  using AVecType = typename Impl::BVecType;
302  using BVecType = typename Impl::AVecType;
303  using CVecType = typename Impl::CVecType;
304 
305  static constexpr index_t kM = Impl::kN;
306  static constexpr index_t kN = Impl::kM;
307  static constexpr index_t kK = Impl::kK;
308  static constexpr index_t kKPerThread = Impl::kABKPerLane;
309 
// Compile-time constant: this attribute always reports an access count of 1.
310  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
311 
312  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
313  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
314 
316  sequence<>,
320  sequence<2>,
321  sequence<1>>;
322 
324  sequence<>,
328  sequence<2>,
329  sequence<1>>;
330 
332  sequence<>,
339 
// Accumulating form with operands exchanged: the caller's b_vec is handed to
// Impl as its first input and the caller's a_vec as its second. This matches the
// aliases above, where this struct's AVecType is Impl::BVecType and vice versa.
// post_nop_ is forwarded to Impl as a bool_constant tag.
// NOTE(review): listing line 342 (start of the signature) is missing from this
// extract.
340  // c_vec += a_vec * b_vec
341  template <bool post_nop_ = false>
343  const AVecType& a_vec,
344  const BVecType& b_vec,
345  bool_constant<post_nop_> = {}) const
346  {
347  // swap A and B
348  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
349  }
350 
// Non-accumulating form with operands exchanged: returns Impl applied to
// (b_vec, a_vec). This struct's AVecType/BVecType are aliased to Impl's
// BVecType/AVecType, so the exchanged order is what Impl's parameters expect.
351  // c_vec = a_vec * b_vec
352  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
353  {
354  // swap A and B
355  return Impl{}(b_vec, a_vec);
356  }
357 };
358 
359 template <typename WarpGemmAttributeMfmaImpl_>
361 {
363 
364  using ADataType = typename Impl::BDataType;
365  using BDataType = typename Impl::ADataType;
366  using CDataType = typename Impl::CDataType;
367 
368  using AVecType = typename Impl::BVecType;
369  using BVecType = typename Impl::AVecType;
370  using CVecType = typename Impl::CVecType;
371 
372  static constexpr index_t kM = Impl::kN;
373  static constexpr index_t kN = Impl::kM;
374  static constexpr index_t kK = Impl::kK;
375  static constexpr index_t kKPerThread = Impl::kABKPerLane;
376 
// Compile-time constant: this attribute always reports an access count of 1.
377  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
378 
379  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
380  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
381 
383  sequence<>,
387  sequence<2>,
388  sequence<1>>;
389 
391  sequence<>,
392  tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
393  Impl::kABKLane,
394  2,
395  Impl::kABKPerLane>,
399  sequence<2>,
400  sequence<1>>;
401 
403  sequence<>,
405  sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
410 
// Accumulating form with operands exchanged: Impl receives the caller's b_vec
// as its first input and a_vec as its second, consistent with this struct's
// AVecType/BVecType being aliased to Impl::BVecType/Impl::AVecType.
// post_nop_ is forwarded to Impl as a bool_constant tag.
// NOTE(review): listing line 413 (start of the signature) is missing from this
// extract, which is why the summary comment below sits after the template line.
411  template <bool post_nop_ = false>
412  // c_vec += a_vec * b_vec
414  const AVecType& a_vec,
415  const BVecType& b_vec,
416  bool_constant<post_nop_> = {}) const
417  {
418  // swap A and B
419  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
420  }
421 
// Non-accumulating form with operands exchanged: returns Impl applied to
// (b_vec, a_vec), matching the swapped AVecType/BVecType aliases of this struct.
422  // c_vec = a_vec * b_vec
423  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
424  {
425  // swap A and B
426  return Impl{}(b_vec, a_vec);
427  }
428 };
429 
430 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
432 {
434 
435  // swap A and B
436  using ADataType = typename Impl::BDataType;
437  using BDataType = typename Impl::ADataType;
438  using CDataType = typename Impl::CDataType;
439 
440  using AVecType =
442  using BVecType =
444  using CVecType = typename Impl::CVecType;
445 
446  static constexpr index_t kM = Impl::kN;
447  static constexpr index_t kN = Impl::kM;
448  static constexpr index_t kK = Impl::kK * kKIter;
449  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
450 
// Compile-time constant: reports kKIter, which is also the number of Impl
// invocations the accumulating operator() below performs.
451  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
452 
453  static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
454  "Multi-block on both M & N directions is not supported");
455 
// Builds, at compile time, the encoding object used for this struct's A operand
// (which is Impl's B operand, since A and B are swapped in this struct).
// One of three `if constexpr` branches is instantiated depending on
// Impl::kAMBlock / Impl::kBNBlock; the case where both exceed 1 is rejected by
// the static_assert that precedes this function.
// NOTE(review): the lines naming the returned type and most of its template
// arguments are missing from this extract.
456  CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
457  {
458  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
459  {
461  sequence<>,
466  sequence<2>,
467  sequence<1>>{};
468  }
469  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
470  {
471  // single block to multi-block thread mapping
473  sequence<>,
478  sequence<2>,
479  sequence<1>>{};
480  }
481  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
482  {
483  // each N blocks share the same data
490  sequence<2>,
491  sequence<1>>{};
492  }
493  }
494 
// Builds, at compile time, the encoding object used for this struct's B operand
// (which is Impl's A operand, since A and B are swapped in this struct).
// Same three-way `if constexpr` selection on Impl::kAMBlock / Impl::kBNBlock;
// the both-greater-than-1 case is excluded by the struct's static_assert.
// NOTE(review): the lines naming the returned type and most of its template
// arguments are missing from this extract.
495  CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
496  {
497  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
498  {
500  sequence<>,
505  sequence<2>,
506  sequence<1>>{};
507  }
508  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
509  {
510  // each M blocks share the same data
517  sequence<2>,
518  sequence<1>>{};
519  }
520  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
521  {
522  // single block to multi-block thread mapping
524  sequence<>,
529  sequence<2>,
530  sequence<1>>{};
531  }
532  }
533 
// Builds, at compile time, the encoding object used for the C (accumulator) operand.
// One of three `if constexpr` branches is instantiated depending on whether
// Impl::kAMBlock and/or Impl::kBNBlock equal 1; the case where both exceed 1 is
// ruled out by the static_assert earlier in the struct.
// NOTE(review): most of each return expression is missing from this extract;
// only the leading sequence<> and the trailing sequence<0, 2> are visible.
534  CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
535  {
536  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
537  {
539  sequence<>,
545  sequence<0, 2>>{};
546  }
547  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
548  {
550  sequence<>,
556  sequence<0, 2>>{};
557  }
558  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
559  {
561  sequence<>,
562  tuple<
568  sequence<0, 2>>{};
569  }
570  }
571 
573 
575 
577 
// Accumulating form over all K sub-steps with operands exchanged.
// For each iKIter in 0..kKIter-1, Impl is invoked with the iKIter-th sub-vector
// of b_vec as its first input and the iKIter-th sub-vector of a_vec as its
// second, accumulating into c_vec and forwarding the post_nop_ tag.
// NOTE(review): b_vec is viewed as Impl::BVecType yet passed in Impl's first
// operand slot (and a_vec as Impl::AVecType in the second); this lines up only if
// Impl::AVecType and Impl::BVecType are interchangeable - confirm for mixed-type Impl.
// NOTE(review): listing line 580 (start of the signature) and the buf_a / buf_b
// aliases on lines 585-586 are missing from this extract.
578  template <bool post_nop_ = false>
579  // c_vec += a_vec * b_vec
581  const AVecType& a_vec,
582  const BVecType& b_vec,
583  bool_constant<post_nop_> = {}) const
584  {
587  // swap A and B, value and type
588  static_for<0, kKIter, 1>{}([&](auto iKIter) {
589  Impl{}(c_vec,
590  reinterpret_cast<const buf_b&>(b_vec)
591  .template get_as<typename Impl::BVecType>()[iKIter],
592  reinterpret_cast<const buf_a&>(a_vec)
593  .template get_as<typename Impl::AVecType>()[iKIter],
594  bool_constant<post_nop_>{});
595  });
596  }
597 
// Single-step accumulating form with operands exchanged: performs only the
// iKIter-th K sub-step (b sub-vector first, a sub-vector second) into c_vec.
// iKIter is a template parameter and must be < kKIter (checked at compile time).
// NOTE(review): listing lines 600 and 603 (start of the signature and the
// number<iKIter> tag parameter, per the symbol index at the end of the page) and
// the buf_a / buf_b aliases on lines 606-607 are missing from this extract.
598  template <index_t iKIter, bool post_nop_ = false>
599  // c_vec += a_vec * b_vec
601  const AVecType& a_vec,
602  const BVecType& b_vec,
604  bool_constant<post_nop_> = {}) const
605  {
608 
609  static_assert(iKIter < kKIter);
610  // swap A and B, value and type
// The commented-out static_for shows the loop this overload is a single
// iteration of.
611  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
612  Impl{}(c_vec,
613  reinterpret_cast<const buf_b&>(b_vec)
614  .template get_as<typename Impl::BVecType>()[iKIter],
615  reinterpret_cast<const buf_a&>(a_vec)
616  .template get_as<typename Impl::AVecType>()[iKIter],
617  bool_constant<post_nop_>{});
618  //});
619  }
620 
// Non-accumulating form with operands exchanged: returns the product summed over
// all kKIter sub-steps. Sub-step 0 creates c_vec via Impl's two-argument call;
// sub-steps 1..kKIter-1 accumulate into it with the three-argument call.
// In every call the b sub-vector is passed first and the a sub-vector second.
// No post_nop_ tag is passed on this path.
// NOTE(review): the buf_a / buf_b aliases (listing lines 625-626) are missing
// from this extract.
621  // c_vec = a_vec * b_vec
622  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
623  {
624  constexpr auto I0 = number<0>{};
627 
628  // swap A and B, value and type
629  auto c_vec = Impl{}(
630  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
631  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
632 
// Loop starts at 1 because sub-step 0 was consumed above.
633  static_for<1, kKIter, 1>{}([&](auto iKIter) {
634  Impl{}(c_vec,
635  reinterpret_cast<const buf_b&>(b_vec)
636  .template get_as<typename Impl::BVecType>()[iKIter],
637  reinterpret_cast<const buf_a&>(a_vec)
638  .template get_as<typename Impl::AVecType>()[iKIter]);
639  });
640 
641  return c_vec;
642  }
643 };
644 
645 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
647 {
649 
650  // swap A and B
651  using ADataType = typename Impl::BDataType;
652  using BDataType = typename Impl::ADataType;
653  using CDataType = typename Impl::CDataType;
654 
655  using AVecType =
657  using BVecType =
659  using CVecType = typename Impl::CVecType;
660 
661  static constexpr index_t kM = Impl::kN;
662  static constexpr index_t kN = Impl::kM;
663  static constexpr index_t kK = Impl::kK * kKIter;
664  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
665  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
666 
// Compile-time constant: reports kKIter, which is also the number of Impl
// invocations the accumulating operator() below performs.
667  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
668 
669  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
670  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
671 
673  sequence<>,
677  sequence<2>,
678  sequence<1>>;
679 #if 0
681  sequence<>,
682  tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
683  Impl::kABKLane,
684  2,
685  Impl::kABKPerLane>,
689  sequence<2>,
690  sequence<1>>;
691 
693  sequence<>,
695  sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
700 #else
701  // TODO: test more cases, not only 32x32
703  sequence<>,
704  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
705  Impl::kCMLane,
706  SFactor,
707  Impl::kCM1PerLane>,
711  sequence<2>,
712  sequence<1>>;
713 
715  sequence<>,
717  sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
722 #endif
// Accumulating form over all K sub-steps with operands exchanged.
// For each iKIter in 0..kKIter-1, Impl is invoked with the iKIter-th sub-vector
// of b_vec as its first input and the iKIter-th sub-vector of a_vec as its
// second, accumulating into c_vec and forwarding the post_nop_ tag.
// NOTE(review): b_vec is viewed as Impl::BVecType yet passed in Impl's first
// operand slot; this lines up only if Impl::AVecType and Impl::BVecType are
// interchangeable - confirm for mixed-type Impl.
// NOTE(review): listing line 725 (start of the signature) and the buf_a / buf_b
// aliases on lines 730-731 are missing from this extract.
723  // c_vec += a_vec * b_vec
724  template <bool post_nop_ = false>
726  const AVecType& a_vec,
727  const BVecType& b_vec,
728  bool_constant<post_nop_> = {}) const
729  {
732  // swap A and B, value and type
733  static_for<0, kKIter, 1>{}([&](auto iKIter) {
734  Impl{}(c_vec,
735  reinterpret_cast<const buf_b&>(b_vec)
736  .template get_as<typename Impl::BVecType>()[iKIter],
737  reinterpret_cast<const buf_a&>(a_vec)
738  .template get_as<typename Impl::AVecType>()[iKIter],
739  bool_constant<post_nop_>{});
740  });
741  }
742 
// Single-step accumulating form with operands exchanged: performs only the
// iKIter-th K sub-step (b sub-vector first, a sub-vector second) into c_vec.
// iKIter is a template parameter and must be < kKIter (checked at compile time).
// NOTE(review): listing lines 744 and 747 (start of the signature and the
// number<iKIter> tag parameter, per the symbol index at the end of the page) and
// the buf_a / buf_b aliases on lines 750-751 are missing from this extract.
743  template <index_t iKIter, bool post_nop_ = false>
745  const AVecType& a_vec,
746  const BVecType& b_vec,
748  bool_constant<post_nop_> = {}) const
749  {
752 
753  static_assert(iKIter < kKIter);
754  // swap A and B, value and type
// The commented-out static_for shows the loop this overload is a single
// iteration of.
755  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
756  Impl{}(c_vec,
757  reinterpret_cast<const buf_b&>(b_vec)
758  .template get_as<typename Impl::BVecType>()[iKIter],
759  reinterpret_cast<const buf_a&>(a_vec)
760  .template get_as<typename Impl::AVecType>()[iKIter],
761  bool_constant<post_nop_>{});
762  //});
763  }
764 
// Non-accumulating form with operands exchanged: returns the product summed over
// all kKIter sub-steps. Sub-step 0 creates c_vec via Impl's two-argument call;
// sub-steps 1..kKIter-1 accumulate into it with the three-argument call.
// In every call the b sub-vector is passed first and the a sub-vector second.
// No post_nop_ tag is passed on this path.
// NOTE(review): the buf_a / buf_b aliases (listing lines 768-769) are missing
// from this extract.
765  // c_vec = a_vec * b_vec
766  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
767  {
770  constexpr auto I0 = number<0>{};
771 
772  // swap A and B, value and type
773  auto c_vec = Impl{}(
774  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
775  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
776 
// Loop starts at 1 because sub-step 0 was consumed above.
777  static_for<1, kKIter, 1>{}([&](auto iKIter) {
778  Impl{}(c_vec,
779  reinterpret_cast<const buf_b&>(b_vec)
780  .template get_as<typename Impl::BVecType>()[iKIter],
781  reinterpret_cast<const buf_a&>(a_vec)
782  .template get_as<typename Impl::AVecType>()[iKIter]);
783  });
784 
785  return c_vec;
786  }
787 };
788 
789 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
791 {
793 
794  using ADataType = typename Impl::ADataType;
795  using BDataType = typename Impl::BDataType;
796  using CDataType = typename Impl::CDataType;
797 
798  using AVecType =
800  using BVecType =
802  using CVecType = typename Impl::CVecType;
803 
804  static constexpr index_t kM = Impl::kM;
805  static constexpr index_t kN = Impl::kN;
806  static constexpr index_t kK = Impl::kK * kKIter;
807  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
808  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
809 
// Compile-time constant: reports kKIter, which is also the number of Impl
// invocations the accumulating operator() below performs.
810  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
811 
812  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
813  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
814 
816  sequence<>,
817  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
818  Impl::kCMLane,
819  SFactor,
820  Impl::kCM1PerLane>,
824  sequence<2>,
825  sequence<1>>;
826 
828  sequence<>,
832  sequence<2>,
833  sequence<1>>;
834 
836  sequence<>,
837  tuple<sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>,
843 
// Accumulating form over all K sub-steps: c_vec += a_vec * b_vec.
// a_vec / b_vec each pack kKIter Impl-sized sub-vectors; the compile-time loop
// invokes Impl once per sub-step, in index order 0..kKIter-1, accumulating into
// the same c_vec and forwarding the post_nop_ tag each time. Operands are passed
// in their natural order (a first, b second).
// NOTE(review): listing line 846 (start of the signature) and the buf_a / buf_b
// aliases on lines 851-852 are missing from this extract.
844  // c_vec += a_vec * b_vec
845  template <bool post_nop_ = false>
847  const AVecType& a_vec,
848  const BVecType& b_vec,
849  bool_constant<post_nop_> = {}) const
850  {
853 
854  static_for<0, kKIter, 1>{}([&](auto iKIter) {
855  Impl{}(c_vec,
856  reinterpret_cast<const buf_a&>(a_vec)
857  .template get_as<typename Impl::AVecType>()[iKIter],
858  reinterpret_cast<const buf_b&>(b_vec)
859  .template get_as<typename Impl::BVecType>()[iKIter],
860  bool_constant<post_nop_>{});
861  });
862  }
863 
// Single-step accumulating form: performs only the iKIter-th K sub-step,
// c_vec += a_sub[iKIter] * b_sub[iKIter], instead of looping over all of them.
// iKIter is a template parameter and must be < kKIter (checked at compile time).
// NOTE(review): listing lines 865 and 868 (start of the signature and the
// number<iKIter> tag parameter, per the symbol index at the end of the page) and
// the buf_a / buf_b aliases on lines 871-872 are missing from this extract.
864  template <index_t iKIter, bool post_nop_ = false>
866  const AVecType& a_vec,
867  const BVecType& b_vec,
869  bool_constant<post_nop_> = {}) const
870  {
873 
874  static_assert(iKIter < kKIter);
875 
// The commented-out static_for shows the loop this overload is a single
// iteration of.
876  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
877  Impl{}(c_vec,
878  reinterpret_cast<const buf_a&>(a_vec)
879  .template get_as<typename Impl::AVecType>()[iKIter],
880  reinterpret_cast<const buf_b&>(b_vec)
881  .template get_as<typename Impl::BVecType>()[iKIter],
882  bool_constant<post_nop_>{});
883  //});
884  }
885 
// Non-accumulating form: returns a_vec * b_vec summed over all kKIter sub-steps.
// Sub-step 0 uses Impl's two-argument (value-returning) call to create c_vec;
// sub-steps 1..kKIter-1 then accumulate into it with the three-argument call.
// No post_nop_ tag is passed on this path.
// NOTE(review): the buf_a / buf_b aliases (listing lines 890-891) are missing
// from this extract.
886  // c_vec = a_vec * b_vec
887  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
888  {
889  constexpr auto I0 = number<0>{};
892 
893  auto c_vec = Impl{}(
894  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
895  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
896 
// Loop starts at 1 because sub-step 0 was consumed above.
897  static_for<1, kKIter, 1>{}([&](auto iKIter) {
898  Impl{}(c_vec,
899  reinterpret_cast<const buf_a&>(a_vec)
900  .template get_as<typename Impl::AVecType>()[iKIter],
901  reinterpret_cast<const buf_b&>(b_vec)
902  .template get_as<typename Impl::BVecType>()[iKIter]);
903  });
904 
905  return c_vec;
906  }
907 };
908 
909 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
tuple_array< T, N > thread_buffer
Definition: thread_buffer.hpp:14
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:54
Definition: warp_gemm_attribute_mfma.hpp:13
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:14
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:29
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:17
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:16
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:22
typename Impl::AVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:20
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:70
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:18
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:26
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:25
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:24
typename Impl::BVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:21
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:27
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:61
Definition: warp_gemm_attribute_mfma.hpp:791
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:804
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:810
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:887
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:795
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:792
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:808
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:801
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:846
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:807
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:802
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:806
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:794
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:796
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:799
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:805
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:865
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:725
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:648
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:664
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:665
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:667
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:658
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:661
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:663
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:659
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:662
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:652
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:766
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:656
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:653
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:744
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:651
Definition: warp_gemm_attribute_mfma.hpp:432
decltype(get_awarp_dstr_encoding()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:572
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:443
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:438
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:444
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:622
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:449
static constexpr CK_TILE_DEVICE auto get_cwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:534
static constexpr CK_TILE_DEVICE auto get_awarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:456
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:448
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:437
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:451
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:576
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:580
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:600
static constexpr CK_TILE_DEVICE auto get_bwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:495
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:441
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:436
decltype(get_bwarp_dstr_encoding()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:574
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:447
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:446
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:433
Definition: warp_gemm_attribute_mfma.hpp:78
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:88
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:223
static constexpr CK_TILE_DEVICE auto get_awarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:103
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:268
decltype(get_awarp_dstr_encoding()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:219
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:90
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:93
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:84
static constexpr CK_TILE_DEVICE auto get_cwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:181
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:227
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:81
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:246
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:96
static constexpr CK_TILE_DEVICE auto get_bwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:142
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:98
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:94
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:83
decltype(get_bwarp_dstr_encoding()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:221
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:85
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:91
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:95
Definition: warp_gemm_attribute_mfma.hpp:361
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:377
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:365
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:375
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:413
typename Impl::AVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:369
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:364
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:366
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:372
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:362
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:370
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:374
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:373
typename Impl::BVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:368
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:423
Definition: warp_gemm_attribute_mfma.hpp:294
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:305
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:298
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:303
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:308
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:307
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:342
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:299
typename Impl::BVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:301
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:306
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:295
typename Impl::AVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:302
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:310
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:352
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:297
Definition: integral_constant.hpp:13
Definition: sequence.hpp:52
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192