/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp Source File
warp_gemm_attribute_wmma_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 // Base traits for WMMA operations
12 template <typename Arch,
13  typename AType,
14  typename BType,
15  typename CType,
16  index_t M,
17  index_t N,
18  index_t K>
19 struct WmmaTraits;
20 
21 // Generic WMMA implementation using traits
22 template <typename Traits>
24 {
25  using ADataType = typename Traits::ADataType;
26  using BDataType = typename Traits::BDataType;
27  using CDataType = typename Traits::CDataType;
28 
29  using AVecType = typename Traits::AVecType;
30  using BVecType = typename Traits::BVecType;
31  using CVecType = typename Traits::CVecType;
32 
33  // Forward all static constants and type aliases
34  static constexpr index_t kM = Traits::kM;
35  static constexpr index_t kN = Traits::kN;
36  static constexpr index_t kK = Traits::kK;
37 
38  static constexpr index_t kAMBlock = Traits::kAMBlock;
39  static constexpr index_t kBNBlock = Traits::kBNBlock;
40 
41  static constexpr index_t kRepeat = Traits::kRepeat;
42  static constexpr index_t kAMLane = Traits::kAMLane;
43  static constexpr index_t kBNLane = Traits::kBNLane;
44  static constexpr index_t kABK0PerLane = Traits::kABK0PerLane;
45  static constexpr index_t kABKLane = Traits::kABKLane;
46  static constexpr index_t kABK1PerLane = Traits::kABK1PerLane;
47 
48  static constexpr index_t kCMLane = Traits::kCMLane;
49  static constexpr index_t kCNLane = Traits::kCNLane;
50  static constexpr index_t kCM0PerLane = Traits::kCM0PerLane;
51  static constexpr index_t kCM1PerLane = Traits::kCM1PerLane;
52 
53  using kABPs2RHssMajor = typename Traits::kABPs2RHssMajor;
54  using kABPs2RHssMinor = typename Traits::kABPs2RHssMinor;
55  using kABYs2RHsMajor = typename Traits::kABYs2RHsMajor;
56  using kABYs2RHsMinor = typename Traits::kABYs2RHsMinor;
57 
58  using kCPs2RHssMajor = typename Traits::kCPs2RHssMajor;
59  using kCPs2RHssMinor = typename Traits::kCPs2RHssMinor;
60  using kCYs2RHsMajor = typename Traits::kCYs2RHsMajor;
61  using kCYs2RHsMinor = typename Traits::kCYs2RHsMinor;
62 
63  using kCTPs2RHssMajor = typename Traits::kCTPs2RHssMajor;
64  using kCTPs2RHssMinor = typename Traits::kCTPs2RHssMinor;
65  using kCTYs2RHsMajor = typename Traits::kCTYs2RHsMajor;
66  using kCTYs2RHsMinor = typename Traits::kCTYs2RHsMinor;
67 
68  // c_vec += a_vec * b_vec
69  template <bool clamp = false, bool post_nop_ = false>
71  const AVecType& a_vec,
72  const BVecType& b_vec,
73  bool_constant<post_nop_> = {}) const
74  {
75  c_vec = Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, c_vec);
76  }
77 
78  // c_vec = a_vec * b_vec
79  template <bool clamp = false>
80  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
81  {
82  return bit_cast<CVecType>(
83  Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, CVecType{0.f}));
84  }
85 };
86 
87 using DeviceIp = remove_cvref_t<decltype(ck_tile::get_device_arch())>;
90 
93 
96 
99 
102 
105 
108 
109 template <typename Arch,
110  typename AType,
111  typename BType,
112  typename CType,
113  index_t warp_m,
114  index_t warp_n,
115  index_t warp_k>
117 {
118  template <typename T>
119  static auto
120  test(int) -> decltype(std::declval<
122  ADataType>(),
123  std::true_type{});
124 
125  template <typename>
126  static std::false_type test(...);
127 
128  static constexpr bool value = decltype(test<Arch>(0))::value;
129 };
130 
131 template <typename Arch,
132  typename AType,
133  typename BType,
134  typename CType,
135  index_t warp_m,
136  index_t warp_n,
137  index_t warp_k>
138 constexpr bool has_wmma_traits_v =
140 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
constexpr bool has_wmma_traits_v
Definition: warp_gemm_attribute_wmma_impl.hpp:138
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
remove_cvref_t< decltype(ck_tile::get_device_arch())> DeviceIp
Definition: warp_gemm_attribute_wmma_impl.hpp:87
bool_constant< false > false_type
Definition: integral_constant.hpp:63
bool_constant< true > true_type
Definition: integral_constant.hpp:62
Definition: warp_gemm_attribute_wmma_impl.hpp:24
static constexpr index_t kK
Definition: warp_gemm_attribute_wmma_impl.hpp:36
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_wmma_impl.hpp:80
typename Traits::CVecType CVecType
Definition: warp_gemm_attribute_wmma_impl.hpp:31
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_wmma_impl.hpp:42
typename Traits::kCPs2RHssMajor kCPs2RHssMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:58
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_wmma_impl.hpp:39
typename Traits::BVecType BVecType
Definition: warp_gemm_attribute_wmma_impl.hpp:30
typename Traits::kABYs2RHsMinor kABYs2RHsMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:56
typename Traits::kCYs2RHsMajor kCYs2RHsMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:60
typename Traits::kABPs2RHssMinor kABPs2RHssMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:54
typename Traits::BDataType BDataType
Definition: warp_gemm_attribute_wmma_impl.hpp:26
typename Traits::kCTYs2RHsMajor kCTYs2RHsMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:65
typename Traits::AVecType AVecType
Definition: warp_gemm_attribute_wmma_impl.hpp:29
typename Traits::kABPs2RHssMajor kABPs2RHssMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:53
typename Traits::kCTYs2RHsMinor kCTYs2RHsMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:66
typename Traits::kCYs2RHsMinor kCYs2RHsMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:61
typename Traits::kCTPs2RHssMinor kCTPs2RHssMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:64
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_wmma_impl.hpp:38
static constexpr index_t kM
Definition: warp_gemm_attribute_wmma_impl.hpp:34
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_wmma_impl.hpp:49
typename Traits::kCPs2RHssMinor kCPs2RHssMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:59
static constexpr index_t kABK0PerLane
Definition: warp_gemm_attribute_wmma_impl.hpp:44
typename Traits::kCTPs2RHssMajor kCTPs2RHssMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:63
static constexpr index_t kN
Definition: warp_gemm_attribute_wmma_impl.hpp:35
static constexpr index_t kRepeat
Definition: warp_gemm_attribute_wmma_impl.hpp:41
typename Traits::CDataType CDataType
Definition: warp_gemm_attribute_wmma_impl.hpp:27
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_wmma_impl.hpp:50
typename Traits::kABYs2RHsMajor kABYs2RHsMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:55
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_wmma_impl.hpp:43
typename Traits::ADataType ADataType
Definition: warp_gemm_attribute_wmma_impl.hpp:25
static constexpr index_t kABK1PerLane
Definition: warp_gemm_attribute_wmma_impl.hpp:46
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_wmma_impl.hpp:48
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_wmma_impl.hpp:70
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_wmma_impl.hpp:51
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_wmma_impl.hpp:45
Definition: warp_gemm_attribute_wmma_impl.hpp:19
Definition: integral_constant.hpp:13
Definition: warp_gemm_attribute_wmma_impl.hpp:117
static constexpr bool value
Definition: warp_gemm_attribute_wmma_impl.hpp:128
static auto test(int) -> decltype(std::declval< typename WmmaTraits< T, AType, BType, CType, warp_m, warp_n, warp_k >::ADataType >(), std::true_type{})
static std::false_type test(...)