/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp Source File
warp_gemm_attribute_wmma_impl.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10
11 // Primary template for the per-architecture WMMA traits. It is only declared here
// (no definition), so every usable (Arch, AType, BType, CType, M, N, K) combination
// has to come from a specialization provided elsewhere.
12 template <typename Arch,
13  typename AType,
14  typename BType,
15  typename CType,
16  index_t M,
17  index_t N,
18  index_t K>
19 struct WmmaTraits;
20
21 // Generic WMMA implementation using traits: re-exports the types and constants of
// `Traits` (presumably a WmmaTraits specialization) and exposes Traits::wmma_intrinsic
// through two operator() overloads (accumulate-into and fresh-result).
22 template <typename Traits>
// NOTE(review): listing line 23 (the struct declarator, including the struct's name)
// is missing from this extraction, and the symbol index entry for line 24 omits the
// name as well -- recover it from the header before relying on this listing.
24 {
25  using TraitsType = Traits;
26  using ADataType = typename Traits::ADataType;
27  using BDataType = typename Traits::BDataType;
28  using CDataType = typename Traits::CDataType;
29
30  using AVecType = typename Traits::AVecType;
31  using BVecType = typename Traits::BVecType;
32  using CVecType = typename Traits::CVecType;
33
34  // Forward all static constants and type aliases
 // (re-exported unchanged from Traits, so code holding this wrapper can query them
 // without naming Traits directly).
35  static constexpr index_t kM = Traits::kM;
36  static constexpr index_t kN = Traits::kN;
37  static constexpr index_t kK = Traits::kK;
38
39  static constexpr index_t kAMBlock = Traits::kAMBlock;
40  static constexpr index_t kBNBlock = Traits::kBNBlock;
41
42  static constexpr index_t kRepeat = Traits::kRepeat;
43  static constexpr index_t kAMLane = Traits::kAMLane;
44  static constexpr index_t kBNLane = Traits::kBNLane;
45  static constexpr index_t kABK0PerLane = Traits::kABK0PerLane;
46  static constexpr index_t kABKLane = Traits::kABKLane;
47  static constexpr index_t kABK1PerLane = Traits::kABK1PerLane;
48
49  static constexpr index_t kCMLane = Traits::kCMLane;
50  static constexpr index_t kCNLane = Traits::kCNLane;
51  static constexpr index_t kCM0PerLane = Traits::kCM0PerLane;
52  static constexpr index_t kCM1PerLane = Traits::kCM1PerLane;
53
54  using kABPs2RHssMajor = typename Traits::kABPs2RHssMajor;
55  using kABPs2RHssMinor = typename Traits::kABPs2RHssMinor;
56  using kABYs2RHsMajor = typename Traits::kABYs2RHsMajor;
57  using kABYs2RHsMinor = typename Traits::kABYs2RHsMinor;
58
59  using kCPs2RHssMajor = typename Traits::kCPs2RHssMajor;
60  using kCPs2RHssMinor = typename Traits::kCPs2RHssMinor;
61  using kCYs2RHsMajor = typename Traits::kCYs2RHsMajor;
62  using kCYs2RHsMinor = typename Traits::kCYs2RHsMinor;
63
64  using kCTPs2RHssMajor = typename Traits::kCTPs2RHssMajor;
65  using kCTPs2RHssMinor = typename Traits::kCTPs2RHssMinor;
66  using kCTYs2RHsMajor = typename Traits::kCTYs2RHsMajor;
67  using kCTYs2RHsMinor = typename Traits::kCTYs2RHsMinor;
68
69  // c_vec += a_vec * b_vec
 // Accumulating form: calls Traits::wmma_intrinsic<clamp> with c_vec as the
 // accumulator operand and writes the result back into c_vec. The trailing
 // bool_constant<post_nop_> tag is accepted but not used by this body.
 // NOTE(review): unlike the overload below, the intrinsic's result is assigned
 // without bit_cast<CVecType>; confirm the intrinsic returns CVecType (or a type
 // implicitly convertible to it) for every Traits.
70  template <bool clamp = false, bool post_nop_ = false>
 // NOTE(review): listing line 71 is missing from this extraction; per the symbol
 // index it opens this signature as `CK_TILE_DEVICE void operator()(CVecType& c_vec,`.
72  const AVecType& a_vec,
73  const BVecType& b_vec,
74  bool_constant<post_nop_> = {}) const
75  {
76  c_vec = Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, c_vec);
77  }
78
79  // c_vec = a_vec * b_vec
 // Non-accumulating form: calls the same intrinsic with an accumulator built as
 // CVecType{0.f} (intended to be all zeros) and returns the result bit_cast to
 // CVecType.
 // NOTE(review): the zero is spelled as a float literal; if CVecType is ever an
 // integer vector this relies on the compiler accepting that element conversion --
 // confirm for integer WMMA traits.
80  template <bool clamp = false>
81  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
82  {
83  return bit_cast<CVecType>(
84  Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, CVecType{0.f}));
85  }
86 };
87
// DeviceIp: the type of ck_tile::get_device_arch()'s result, with cv-qualifiers and
// references stripped by remove_cvref_t.
88 using DeviceIp = remove_cvref_t<decltype(ck_tile::get_device_arch())>;
// NOTE(review): listing lines 89-109 are largely missing from this extraction. Only
// the bare numbers of lines 91, 94, 97, 100, 103, 106 and 109 survive, with no
// content. What those lines contained cannot be recovered from this page; check the
// header.
91
94
97
100
103
106
109
// Detection trait: `value` is true when
// WmmaTraits<Arch, AType, BType, CType, warp_m, warp_n, warp_k> has a nested
// ADataType (i.e. a usable specialization appears to exist for that combination) and
// false otherwise. Only ADataType is probed.
110 template <typename Arch,
111  typename AType,
112  typename BType,
113  typename CType,
114  index_t warp_m,
115  index_t warp_n,
116  index_t warp_k>
// NOTE(review): listing line 117 (the struct declarator and name) is missing from
// this extraction; the symbol index entry for line 118 omits the name too.
118 {
119  template <typename T>
 // Viable only if the decltype operand is well-formed. Making it depend on T
 // (instantiated with Arch below) turns a missing member into a substitution
 // failure rather than a hard error; when well-formed the comma expression yields
 // std::true_type.
120  static auto
121  test(int) -> decltype(std::declval<
 // NOTE(review): listing line 122 is missing here; per the symbol index the full
 // operand is `typename WmmaTraits<T, AType, BType, CType, warp_m, warp_n, warp_k>::ADataType`.
123  ADataType>(),
124  std::true_type{});
125
 // Fallback selected when the overload above is removed by SFINAE.
126  template <typename>
127  static std::false_type test(...);
128
 // The int argument 0 prefers test(int) over the ellipsis overload whenever both
 // are viable.
129  static constexpr bool value = decltype(test<Arch>(0))::value;
130 };
131
// Variable template, presumably the `_v` shorthand for the detector above.
// NOTE(review): its initializer (listing line 140) is missing from this extraction;
// confirm what it reads against the header.
132 template <typename Arch,
133  typename AType,
134  typename BType,
135  typename CType,
136  index_t warp_m,
137  index_t warp_n,
138  index_t warp_k>
139 constexpr bool has_wmma_traits_v =
141 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
Definition: cluster_descriptor.hpp:13
constexpr bool has_wmma_traits_v
Definition: warp_gemm_attribute_wmma_impl.hpp:139
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
remove_cvref_t< decltype(ck_tile::get_device_arch())> DeviceIp
Definition: warp_gemm_attribute_wmma_impl.hpp:88
bool_constant< false > false_type
Definition: integral_constant.hpp:63
bool_constant< true > true_type
Definition: integral_constant.hpp:62
Definition: warp_gemm_attribute_wmma_impl.hpp:24
static constexpr index_t kK
Definition: warp_gemm_attribute_wmma_impl.hpp:37
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_wmma_impl.hpp:81
typename Traits::CVecType CVecType
Definition: warp_gemm_attribute_wmma_impl.hpp:32
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_wmma_impl.hpp:43
typename Traits::kCPs2RHssMajor kCPs2RHssMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:59
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_wmma_impl.hpp:40
typename Traits::BVecType BVecType
Definition: warp_gemm_attribute_wmma_impl.hpp:31
typename Traits::kABYs2RHsMinor kABYs2RHsMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:57
typename Traits::kCYs2RHsMajor kCYs2RHsMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:61
typename Traits::kABPs2RHssMinor kABPs2RHssMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:55
typename Traits::BDataType BDataType
Definition: warp_gemm_attribute_wmma_impl.hpp:27
typename Traits::kCTYs2RHsMajor kCTYs2RHsMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:66
typename Traits::AVecType AVecType
Definition: warp_gemm_attribute_wmma_impl.hpp:30
typename Traits::kABPs2RHssMajor kABPs2RHssMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:54
typename Traits::kCTYs2RHsMinor kCTYs2RHsMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:67
typename Traits::kCYs2RHsMinor kCYs2RHsMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:62
typename Traits::kCTPs2RHssMinor kCTPs2RHssMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:65
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_wmma_impl.hpp:39
static constexpr index_t kM
Definition: warp_gemm_attribute_wmma_impl.hpp:35
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_wmma_impl.hpp:50
typename Traits::kCPs2RHssMinor kCPs2RHssMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:60
static constexpr index_t kABK0PerLane
Definition: warp_gemm_attribute_wmma_impl.hpp:45
typename Traits::kCTPs2RHssMajor kCTPs2RHssMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:64
static constexpr index_t kN
Definition: warp_gemm_attribute_wmma_impl.hpp:36
static constexpr index_t kRepeat
Definition: warp_gemm_attribute_wmma_impl.hpp:42
typename Traits::CDataType CDataType
Definition: warp_gemm_attribute_wmma_impl.hpp:28
Traits TraitsType
Definition: warp_gemm_attribute_wmma_impl.hpp:25
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_wmma_impl.hpp:51
typename Traits::kABYs2RHsMajor kABYs2RHsMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:56
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_wmma_impl.hpp:44
typename Traits::ADataType ADataType
Definition: warp_gemm_attribute_wmma_impl.hpp:26
static constexpr index_t kABK1PerLane
Definition: warp_gemm_attribute_wmma_impl.hpp:47
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_wmma_impl.hpp:49
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_wmma_impl.hpp:71
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_wmma_impl.hpp:52
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_wmma_impl.hpp:46
struct WmmaTraits
Definition: warp_gemm_attribute_wmma_impl.hpp:19
Definition: integral_constant.hpp:13
Definition: warp_gemm_attribute_wmma_impl.hpp:118
static constexpr bool value
Definition: warp_gemm_attribute_wmma_impl.hpp:129
static auto test(int) -> decltype(std::declval< typename WmmaTraits< T, AType, BType, CType, warp_m, warp_n, warp_k >::ADataType >(), std::true_type{})
static std::false_type test(...)