/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp Source File
warp_gemm_attribute_wmma.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 // TODO: currently only 16-bit input is supported, which means only tr16_b128 is supported;
13 // ADataType will be used to determine the layout in the future
14 template <typename Impl>
16 {
23  typename Impl::kABYs2RHsMajor,
24  typename Impl::kABYs2RHsMinor>;
25 };
26 
27 template <typename Impl>
29 {
36  typename Impl::kABYs2RHsMajor,
37  typename Impl::kABYs2RHsMinor>;
38 };
39 
40 template <typename Impl>
42 {
44  sequence<>,
49  typename Impl::kCYs2RHsMajor,
50  typename Impl::kCYs2RHsMinor>;
51 };
52 
53 template <typename WarpGemmAttributeWmmaImpl_, bool kTransC = false>
55 {
57 
58  using ADataType = typename Impl::ADataType;
59  using BDataType = typename Impl::BDataType;
60  using CDataType = typename Impl::CDataType;
61 
62  using AVecType = typename Impl::AVecType;
63  using BVecType = typename Impl::BVecType;
64  using CVecType = typename Impl::CVecType;
65 
66  static constexpr index_t kM = Impl::kM;
67  static constexpr index_t kN = Impl::kN;
68  static constexpr index_t kK = Impl::kK;
69  static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane;
70 
71  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
72 
73  // 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2
74  // 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4
77 
78  // kCM0PerLane = 4, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16 for 16 bit input
79  // kCM0PerLane = 2, kCMLane = 2, kCM1PerLane = 4, kCNLane = 16 for 8 bit input
81 
82  // c_vec += a_vec * b_vec
// Accumulating overload: forwards c_vec, the two operand vectors and a
// bool_constant<post_nop_> tag to a default-constructed Impl.
// When kTransC is true the operands are passed to Impl as (b_vec, a_vec);
// otherwise they are passed as (a_vec, b_vec).
// NOTE(review): the declaration line (listing line 84) is absent from this
// extract; the index at the bottom of the page gives it as
// `CK_TILE_DEVICE void operator()(CVecType &c_vec, ...) const`.
// NOTE(review): the operand swap presumably produces the transposed result
// layout, and post_nop_ presumably controls a trailing no-op in Impl — confirm
// both against the Impl definition.
83  template <bool post_nop_ = false>
85  const AVecType& a_vec,
86  const BVecType& b_vec,
87  bool_constant<post_nop_> = {}) const
88  {
89  if constexpr(kTransC)
90  {
// Swapped operand order for the kTransC case.
91  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
92  }
93  else
94  {
95  Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
96  }
97  }
98 
99  // c_vec = a_vec * b_vec
// Value-returning overload: returns the CVecType produced by calling a
// default-constructed Impl with the two operand vectors.
// When kTransC is true the operands are passed to Impl as (b_vec, a_vec);
// otherwise they are passed as (a_vec, b_vec). The branch is selected at
// compile time via `if constexpr`.
// NOTE(review): the operand swap presumably produces the transposed result
// layout — confirm against the Impl definition.
100  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
101  {
102  if constexpr(kTransC)
103  {
// Swapped operand order for the kTransC case.
104  return Impl{}(b_vec, a_vec);
105  }
106  else
107  {
108  return Impl{}(a_vec, b_vec);
109  }
110  }
111 };
112 
// Returns has_wmma_traits_v<gfx12_t, ...> for the given data types and warp
// tile sizes when is_gfx12_supported() is true; otherwise returns
// has_wmma_traits_v<gfx11_t, ...> when is_gfx11_supported() is true;
// otherwise returns false.
// The gfx12 check is evaluated first, so it takes precedence over gfx11.
// The branches are ordinary run-time `if` statements, so both
// has_wmma_traits_v instantiations are named regardless of which one is taken.
// NOTE(review): the declaration line (listing line 119) is absent from this
// extract; the index at the bottom of the page gives it as
// `CK_TILE_HOST bool check_wmma_supported()`.
// NOTE(review): is_gfx12_supported()/is_gfx11_supported() presumably query the
// active device — confirm in device_prop.hpp.
113 template <typename ADataType,
114  typename BDataType,
115  typename AccDataType,
116  index_t M_Warp_Tile,
117  index_t N_Warp_Tile,
118  index_t K_Warp_Tile>
120 {
121  if(is_gfx12_supported())
122  {
123  return has_wmma_traits_v<gfx12_t,
124  ADataType,
125  BDataType,
126  AccDataType,
127  M_Warp_Tile,
128  N_Warp_Tile,
129  K_Warp_Tile>;
130  }
131  else if(is_gfx11_supported())
132  {
133  return has_wmma_traits_v<gfx11_t,
134  ADataType,
135  BDataType,
136  AccDataType,
137  M_Warp_Tile,
138  N_Warp_Tile,
139  K_Warp_Tile>;
140  }
141  else
142  {
// Neither is_gfx12_supported() nor is_gfx11_supported() returned true.
143  return false;
144  }
145 }
146 
147 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr bool has_wmma_traits_v
Definition: warp_gemm_attribute_wmma_impl.hpp:133
CK_TILE_HOST bool check_wmma_supported()
Definition: warp_gemm_attribute_wmma.hpp:119
bool is_gfx12_supported()
Definition: device_prop.hpp:63
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
bool is_gfx11_supported()
Definition: device_prop.hpp:55
Definition: warp_gemm_attribute_wmma.hpp:16
Definition: warp_gemm_attribute_wmma.hpp:29
Definition: warp_gemm_attribute_wmma.hpp:42
Definition: warp_gemm_attribute_wmma.hpp:55
remove_cvref_t< WarpGemmAttributeWmmaImpl_ > Impl
Definition: warp_gemm_attribute_wmma.hpp:56
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_wmma.hpp:84
static constexpr index_t kN
Definition: warp_gemm_attribute_wmma.hpp:67
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_wmma.hpp:69
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_wmma.hpp:58
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_wmma.hpp:60
typename CWarpDstrEncodingTrait< Impl >::type CWarpDstrEncoding
Definition: warp_gemm_attribute_wmma.hpp:80
typename Impl::BVecType BVecType
Definition: warp_gemm_attribute_wmma.hpp:63
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_wmma.hpp:71
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_wmma.hpp:59
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_wmma.hpp:64
typename Impl::AVecType AVecType
Definition: warp_gemm_attribute_wmma.hpp:62
static constexpr index_t kK
Definition: warp_gemm_attribute_wmma.hpp:68
static constexpr index_t kM
Definition: warp_gemm_attribute_wmma.hpp:66
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_wmma.hpp:100
typename AWarpDstrEncodingTrait< Impl >::type AWarpDstrEncoding
Definition: warp_gemm_attribute_wmma.hpp:75
typename BWarpDstrEncodingTrait< Impl >::type BWarpDstrEncoding
Definition: warp_gemm_attribute_wmma.hpp:76
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192
Definition: arch.hpp:262
Definition: arch.hpp:265