/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp Source File
warp_gemm_attribute_wmma.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 // TODO: currently only 16-bit input is supported, which means only tr16_b128 is supported;
13 // ADataType will be used to determine the layout in the future
14 template <typename Impl>
16 {
23  typename Impl::kABYs2RHsMajor,
24  typename Impl::kABYs2RHsMinor>;
25 };
26 
27 template <typename Impl>
29 {
36  typename Impl::kABYs2RHsMajor,
37  typename Impl::kABYs2RHsMinor>;
38 };
39 
40 template <typename Impl>
42 {
44  sequence<>,
49  typename Impl::kCYs2RHsMajor,
50  typename Impl::kCYs2RHsMinor>;
51 };
52 
53 template <typename Impl>
55 {
57  sequence<>,
62  typename Impl::kCTYs2RHsMajor,
63  typename Impl::kCTYs2RHsMinor>;
64 };
65 
// Warp-level GEMM attribute wrapper parameterized on a WMMA implementation type.
// It re-exports the implementation's element types, vector types and tile sizes, and
// forwards both operator() overloads to a default-constructed Impl. When kTransC is
// true the A and B operands are passed to Impl in swapped order.
// NOTE(review): this listing omits the struct declaration line and the `Impl` alias;
// the page's cross-reference index gives Impl as
// remove_cvref_t<WarpGemmAttributeWmmaImpl_> -- confirm against the real header.
66 template <typename WarpGemmAttributeWmmaImpl_, bool kTransC = false>
68 {
70 
    // Element types taken directly from the implementation.
71  using ADataType = typename Impl::ADataType;
72  using BDataType = typename Impl::BDataType;
73  using CDataType = typename Impl::CDataType;
74 
    // Vector types accepted/returned by operator() below.
75  using AVecType = typename Impl::AVecType;
76  using BVecType = typename Impl::BVecType;
77  using CVecType = typename Impl::CVecType;
78 
    // Tile extents are copied from Impl; kKPerThread is the product of the
    // implementation's kABK0PerLane and kABK1PerLane.
79  static constexpr index_t kM = Impl::kM;
80  static constexpr index_t kN = Impl::kN;
81  static constexpr index_t kK = Impl::kK;
82  static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane;
83 
    // Always returns 1.
84  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
85 
86  // 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2
87  // 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4
90 
91  // kCM0PerLane = 1, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16
    // NOTE(review): the alias below is truncated in this listing; per the index it is
    // CWarpDstrEncoding, selecting the transposed C encoding trait when kTransC is true
    // and the regular C encoding trait otherwise.
93  std::conditional_t<kTransC,
96 
    // Accumulating form: c_vec is passed through to Impl together with the two input
    // vectors, and post_nop_ is forwarded as a bool_constant. With kTransC the inputs
    // are handed to Impl as (b_vec, a_vec) instead of (a_vec, b_vec) -- presumably so
    // that Impl produces the transposed C tile; confirm against Impl.
97  // c_vec += a_vec * b_vec
98  template <bool post_nop_ = false>
100  const AVecType& a_vec,
101  const BVecType& b_vec,
102  bool_constant<post_nop_> = {}) const
103  {
104  if constexpr(kTransC)
105  {
106  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
107  }
108  else
109  {
110  Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
111  }
112  }
113 
    // Value-returning form: returns whatever Impl returns for the two inputs, again
    // with the operand order swapped when kTransC is true.
114  // c_vec = a_vec * b_vec
115  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
116  {
117  if constexpr(kTransC)
118  {
119  return Impl{}(b_vec, a_vec);
120  }
121  else
122  {
123  return Impl{}(a_vec, b_vec);
124  }
125  }
126 };
127 
// Reports whether WMMA traits exist for the given combination of A/B/accumulator data
// types and M/N/K warp tile sizes.
// The six template parameters are forwarded unchanged to has_wmma_traits_v; only the
// architecture tag differs between branches:
//   - is_gfx12_supported() is true -> result of has_wmma_traits_v<gfx12_t, ...>
//   - else is_gfx11_supported() is true -> result of has_wmma_traits_v<gfx11_t, ...>
//   - otherwise -> false
// gfx12 is checked first, so it takes precedence if both queries return true.
// NOTE(review): the signature line is missing from this listing; the page's
// cross-reference index gives it as `CK_TILE_HOST bool check_wmma_supported()`.
128 template <typename ADataType,
129  typename BDataType,
130  typename AccDataType,
131  index_t M_Warp_Tile,
132  index_t N_Warp_Tile,
133  index_t K_Warp_Tile>
135 {
136  if(is_gfx12_supported())
137  {
138  return has_wmma_traits_v<gfx12_t,
139  ADataType,
140  BDataType,
141  AccDataType,
142  M_Warp_Tile,
143  N_Warp_Tile,
144  K_Warp_Tile>;
145  }
146  else if(is_gfx11_supported())
147  {
148  return has_wmma_traits_v<gfx11_t,
149  ADataType,
150  BDataType,
151  AccDataType,
152  M_Warp_Tile,
153  N_Warp_Tile,
154  K_Warp_Tile>;
155  }
156  else
157  {
158  // Neither gfx12 nor gfx11 was reported as supported.
159  return false;
160  }
161 }
161 
162 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr bool has_wmma_traits_v
Definition: warp_gemm_attribute_wmma_impl.hpp:138
CK_TILE_HOST bool check_wmma_supported()
Definition: warp_gemm_attribute_wmma.hpp:134
bool is_gfx12_supported()
Definition: device_prop.hpp:63
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
bool is_gfx11_supported()
Definition: device_prop.hpp:55
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: warp_gemm_attribute_wmma.hpp:16
Definition: warp_gemm_attribute_wmma.hpp:29
Definition: warp_gemm_attribute_wmma.hpp:55
Definition: warp_gemm_attribute_wmma.hpp:42
Definition: warp_gemm_attribute_wmma.hpp:68
remove_cvref_t< WarpGemmAttributeWmmaImpl_ > Impl
Definition: warp_gemm_attribute_wmma.hpp:69
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_wmma.hpp:99
std::conditional_t< kTransC, typename CTransposedWarpDstrEncodingTrait< Impl >::type, typename CWarpDstrEncodingTrait< Impl >::type > CWarpDstrEncoding
Definition: warp_gemm_attribute_wmma.hpp:95
static constexpr index_t kN
Definition: warp_gemm_attribute_wmma.hpp:80
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_wmma.hpp:82
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_wmma.hpp:71
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_wmma.hpp:73
typename Impl::BVecType BVecType
Definition: warp_gemm_attribute_wmma.hpp:76
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_wmma.hpp:84
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_wmma.hpp:72
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_wmma.hpp:77
typename Impl::AVecType AVecType
Definition: warp_gemm_attribute_wmma.hpp:75
static constexpr index_t kK
Definition: warp_gemm_attribute_wmma.hpp:81
static constexpr index_t kM
Definition: warp_gemm_attribute_wmma.hpp:79
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_wmma.hpp:115
typename AWarpDstrEncodingTrait< Impl >::type AWarpDstrEncoding
Definition: warp_gemm_attribute_wmma.hpp:88
typename BWarpDstrEncodingTrait< Impl >::type BWarpDstrEncoding
Definition: warp_gemm_attribute_wmma.hpp:89
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192
Definition: arch.hpp:326
Definition: arch.hpp:329