include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp Source File

include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp Source File

Composable Kernel: include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp Source File
warp_gemm_attribute_wmma_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 // Base traits for WMMA operations. The primary template is only declared here (never defined), so every supported (Arch, AType, BType, CType, M, N, K) combination must come from a specialization; the SFINAE detector further below reports whether WmmaTraits<...>::ADataType is well-formed for a given combination.
12 template <typename Arch,
13  typename AType,
14  typename BType,
15  typename CType,
16  index_t M,
17  index_t N,
18  index_t K>
19 struct WmmaTraits;
20 
21 // Generic WMMA implementation using traits: re-exports the Traits' data/vector types and constants unchanged, and routes every multiply through Traits::wmma_intrinsic<clamp>(a_vec, b_vec, c_vec).
22 template <typename Traits>
24 {
25  using ADataType = typename Traits::ADataType;
26  using BDataType = typename Traits::BDataType;
27  using CDataType = typename Traits::CDataType;
28 
29  using AVecType = typename Traits::AVecType;
30  using BVecType = typename Traits::BVecType;
31  using CVecType = typename Traits::CVecType;
32 
33  // Forward all static constants and type aliases from Traits verbatim (kM/kN/kK, the *Block/*Lane/*PerLane counts, kRepeat, and the kAB*/kC* Ps2RHss/Ys2RHs aliases) so callers can query them on this wrapper directly.
34  static constexpr index_t kM = Traits::kM;
35  static constexpr index_t kN = Traits::kN;
36  static constexpr index_t kK = Traits::kK;
37 
38  static constexpr index_t kAMBlock = Traits::kAMBlock;
39  static constexpr index_t kBNBlock = Traits::kBNBlock;
40 
41  static constexpr index_t kRepeat = Traits::kRepeat;
42  static constexpr index_t kAMLane = Traits::kAMLane;
43  static constexpr index_t kBNLane = Traits::kBNLane;
44  static constexpr index_t kABK0PerLane = Traits::kABK0PerLane;
45  static constexpr index_t kABKLane = Traits::kABKLane;
46  static constexpr index_t kABK1PerLane = Traits::kABK1PerLane;
47 
48  static constexpr index_t kCMLane = Traits::kCMLane;
49  static constexpr index_t kCNLane = Traits::kCNLane;
50  static constexpr index_t kCM0PerLane = Traits::kCM0PerLane;
51  static constexpr index_t kCM1PerLane = Traits::kCM1PerLane;
52 
53  using kABPs2RHssMajor = typename Traits::kABPs2RHssMajor;
54  using kABPs2RHssMinor = typename Traits::kABPs2RHssMinor;
55  using kABYs2RHsMajor = typename Traits::kABYs2RHsMajor;
56  using kABYs2RHsMinor = typename Traits::kABYs2RHsMinor;
57 
58  using kCPs2RHssMajor = typename Traits::kCPs2RHssMajor;
59  using kCPs2RHssMinor = typename Traits::kCPs2RHssMinor;
60  using kCYs2RHsMajor = typename Traits::kCYs2RHsMajor;
61  using kCYs2RHsMinor = typename Traits::kCYs2RHsMinor;
62 
63  // c_vec += a_vec * b_vec, in place: c_vec is overwritten with Traits::wmma_intrinsic<clamp>(a_vec, b_vec, c_vec). The post_nop_ tag parameter is unnamed and not used by this body — NOTE(review): presumably kept for signature parity with other warp-gemm impls; confirm.
64  template <bool clamp = false, bool post_nop_ = false>
66  const AVecType& a_vec,
67  const BVecType& b_vec,
68  bool_constant<post_nop_> = {}) const
69  {
70  c_vec = Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, c_vec);
71  }
72 
73  // c_vec = a_vec * b_vec: returns Traits::wmma_intrinsic<clamp>(a_vec, b_vec, <zero CVecType>) bit_cast to CVecType. NOTE(review): the zero accumulator is built as CVecType{0.f} from a float literal — verify this is well-formed if any Traits uses an integer CVecType.
74  template <bool clamp = false>
75  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
76  {
77  return bit_cast<CVecType>(
78  Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, CVecType{0.f}));
79  }
80 };
81 
82 using DeviceIp = remove_cvref_t<decltype(ck_tile::get_device_arch())>; // the type returned by ck_tile::get_device_arch(), with cv/ref qualifiers stripped
85 
88 
91 
94 
97 
100 
103 
104 template <typename Arch,
105  typename AType,
106  typename BType,
107  typename CType,
108  index_t warp_m,
109  index_t warp_n,
110  index_t warp_k>
112 {
113  template <typename T> // viable only when WmmaTraits<T, AType, BType, CType, warp_m, warp_n, warp_k>::ADataType names a type (SFINAE)
114  static auto
115  test(int) -> decltype(std::declval<
117  ADataType>(),
118  std::true_type{});
119 
120  template <typename> // fallback: the ellipsis parameter makes this the worse match for test<Arch>(0), so it is picked only when the overload above is removed
121  static std::false_type test(...);
122 
123  static constexpr bool value = decltype(test<Arch>(0))::value; // true iff WmmaTraits<Arch, AType, BType, CType, warp_m, warp_n, warp_k>::ADataType is well-formed
124 };
125 
126 template <typename Arch,
127  typename AType,
128  typename BType,
129  typename CType,
130  index_t warp_m,
131  index_t warp_n,
132  index_t warp_k>
133 constexpr bool has_wmma_traits_v = // variable-template shorthand; presumably the detector's ::value for these arguments — the initializer line is not visible here, confirm against the header
135 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
constexpr bool has_wmma_traits_v
Definition: warp_gemm_attribute_wmma_impl.hpp:133
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
remove_cvref_t< decltype(ck_tile::get_device_arch())> DeviceIp
Definition: warp_gemm_attribute_wmma_impl.hpp:82
bool_constant< false > false_type
Definition: integral_constant.hpp:63
bool_constant< true > true_type
Definition: integral_constant.hpp:62
Definition: warp_gemm_attribute_wmma_impl.hpp:24
static constexpr index_t kK
Definition: warp_gemm_attribute_wmma_impl.hpp:36
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_wmma_impl.hpp:75
typename Traits::CVecType CVecType
Definition: warp_gemm_attribute_wmma_impl.hpp:31
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_wmma_impl.hpp:42
typename Traits::kCPs2RHssMajor kCPs2RHssMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:58
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_wmma_impl.hpp:39
typename Traits::BVecType BVecType
Definition: warp_gemm_attribute_wmma_impl.hpp:30
typename Traits::kABYs2RHsMinor kABYs2RHsMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:56
typename Traits::kCYs2RHsMajor kCYs2RHsMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:60
typename Traits::kABPs2RHssMinor kABPs2RHssMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:54
typename Traits::BDataType BDataType
Definition: warp_gemm_attribute_wmma_impl.hpp:26
typename Traits::AVecType AVecType
Definition: warp_gemm_attribute_wmma_impl.hpp:29
typename Traits::kABPs2RHssMajor kABPs2RHssMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:53
typename Traits::kCYs2RHsMinor kCYs2RHsMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:61
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_wmma_impl.hpp:38
static constexpr index_t kM
Definition: warp_gemm_attribute_wmma_impl.hpp:34
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_wmma_impl.hpp:49
typename Traits::kCPs2RHssMinor kCPs2RHssMinor
Definition: warp_gemm_attribute_wmma_impl.hpp:59
static constexpr index_t kABK0PerLane
Definition: warp_gemm_attribute_wmma_impl.hpp:44
static constexpr index_t kN
Definition: warp_gemm_attribute_wmma_impl.hpp:35
static constexpr index_t kRepeat
Definition: warp_gemm_attribute_wmma_impl.hpp:41
typename Traits::CDataType CDataType
Definition: warp_gemm_attribute_wmma_impl.hpp:27
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_wmma_impl.hpp:50
typename Traits::kABYs2RHsMajor kABYs2RHsMajor
Definition: warp_gemm_attribute_wmma_impl.hpp:55
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_wmma_impl.hpp:43
typename Traits::ADataType ADataType
Definition: warp_gemm_attribute_wmma_impl.hpp:25
static constexpr index_t kABK1PerLane
Definition: warp_gemm_attribute_wmma_impl.hpp:46
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_wmma_impl.hpp:48
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_wmma_impl.hpp:65
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_wmma_impl.hpp:51
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_wmma_impl.hpp:45
struct WmmaTraits
Definition: warp_gemm_attribute_wmma_impl.hpp:19
Definition: integral_constant.hpp:13
Definition: warp_gemm_attribute_wmma_impl.hpp:112
static constexpr bool value
Definition: warp_gemm_attribute_wmma_impl.hpp:123
static auto test(int) -> decltype(std::declval< typename WmmaTraits< T, AType, BType, CType, warp_m, warp_n, warp_k >::ADataType >(), std::true_type{})
static std::false_type test(...)