/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_base.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_base.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_base.hpp Source File
device_base.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
#include <cassert>
#include <optional>
#include <regex>
#include <sstream>
#include <string>
#include <typeinfo>

#include "ck/stream_config.hpp"
#endif
#include "ck/utility/get_id.hpp"
15 
16 namespace ck {
17 namespace tensor_operation {
18 namespace device {
19 
20 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
// Expands to an override of `GetObjectName() const` that derives the enclosing class name from
// __PRETTY_FUNCTION__. It captures everything between "<std::string> " and "::GetObjectName"
// and appends ';'. If the pretty-function string does not match (e.g. the compiler spells the
// return type differently than "std::optional<std::string>"), the whole pretty-function string
// is returned instead.
#define GET_OBJECT_NAME_IMLP                                                     \
    std::optional<std::string> GetObjectName() const override                   \
    {                                                                           \
        std::string str = __PRETTY_FUNCTION__;                                  \
        static const std::regex obj_name_expr{"<std::string> (.*)::GetObjectName"}; \
        std::smatch match;                                                      \
        if(!std::regex_search(str, match, obj_name_expr))                       \
        {                                                                       \
            return str;                                                         \
        }                                                                       \
        return std::string(match[1]) + ';';                                     \
    }

// Correctly spelled alias, consistent with GET_TEMPLATE_INFO_IMPL / GET_NXDL_PER_WAVE_IMPL.
// GET_OBJECT_NAME_IMLP is kept unchanged for existing users.
#define GET_OBJECT_NAME_IMPL GET_OBJECT_NAME_IMLP
33 
// Expands to an override of `GetTemplateInfo() const`. It returns the text inside the outermost
// "[...]" of __PRETTY_FUNCTION__ (the greedy regex spans from the first '[' to the last ']'),
// or std::nullopt when the pretty-function string contains no bracketed section.
#define GET_TEMPLATE_INFO_IMPL                                                  \
    std::optional<std::string> GetTemplateInfo() const override                \
    {                                                                          \
        const std::string pretty = __PRETTY_FUNCTION__;                        \
        static std::regex bracket_expr{"\\[(.*)\\]"};                          \
        std::smatch hit;                                                       \
        const bool found = std::regex_search(pretty, hit, bracket_expr);       \
        return found ? std::optional<std::string>(std::string(hit[1]))         \
                     : std::nullopt;                                           \
    }
46 
// Convenience macro: expands to both the GetObjectName() and GetTemplateInfo() overrides.
#define REGISTER_EXTRA_PRINTING_METHODS \
    GET_OBJECT_NAME_IMLP                \
    GET_TEMPLATE_INFO_IMPL
48 #endif
49 
50 template <index_t BlockSize_,
51  index_t MPerBlock_,
52  index_t NPerBlock_,
53  index_t MPerXDL_,
54  index_t NPerXDL_,
55  index_t MXdlPerWave_,
56  bool IsWave64>
57 static constexpr auto GetNXdlPerWave2()
58 {
59  constexpr index_t Waves = IsWave64 ? BlockSize_ / 64 : BlockSize_ / 32;
60  constexpr index_t MWaves = MPerBlock_ / (MXdlPerWave_ * MPerXDL_);
61  static_assert(MWaves > 0);
62 
63  constexpr index_t NWaves = Waves / MWaves;
64  if constexpr(NWaves == 0)
65  {
66  return 0;
67  }
68  else
69  {
70  if constexpr(NPerBlock_ % (NPerXDL_ * NWaves) == 0)
71  {
72  return NPerBlock_ / (NWaves * NPerXDL_);
73  }
74  else
75  {
76  return 0;
77  }
78  }
79 }
80 
// Expands to a static member function template GetNXdlPerWave<IsWave64>(). It forwards the
// enclosing class's BlockSize, MPerBlock, NPerBlock, MPerXDL, NPerXDL and MXdlPerWave (which
// the class must define) to GetNXdlPerWave2. The call is namespace-qualified, as in
// IS_VALID_COMPILATION_PARAMETER_IMPL, so the macro also resolves in classes declared outside
// ck::tensor_operation::device.
#define GET_NXDL_PER_WAVE_IMPL                                                \
    template <bool IsWave64>                                                  \
    static constexpr auto GetNXdlPerWave()                                    \
    {                                                                         \
        return ck::tensor_operation::device::GetNXdlPerWave2<BlockSize,       \
                                                             MPerBlock,       \
                                                             NPerBlock,       \
                                                             MPerXDL,         \
                                                             NPerXDL,         \
                                                             MXdlPerWave,     \
                                                             IsWave64>();     \
    }
93 
// Expands to Invoker::Run(arg, stream_config), which dispatches on the runtime warp size:
//   warp size != 64 -> RunImp<GridwiseGemm32>, only if NXdlPerWave32 > 0 at compile time;
//   warp size == 64 -> RunImp<GridwiseGemm64>, only if NXdlPerWave64 > 0 at compile time.
// Returns 0 when no variant is compiled in for the detected warp size. The enclosing class must
// provide Argument, RunImp, GridwiseGemm32/64 and NXdlPerWave32/64.
#define INVOKER_RUN_IMPL                                                                 \
    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})  \
    {                                                                                    \
        if(get_warp_size() != 64)                                                        \
        {                                                                                \
            if constexpr(NXdlPerWave32 > 0)                                              \
            {                                                                            \
                return RunImp<GridwiseGemm32>(arg, stream_config);                       \
            }                                                                            \
        }                                                                                \
        else if constexpr(NXdlPerWave64 > 0)                                             \
        {                                                                                \
            return RunImp<GridwiseGemm64>(arg, stream_config);                           \
        }                                                                                \
        return 0.0f;                                                                     \
    }
113 
// Same dispatch as INVOKER_RUN_IMPL, except that the wave32 path reinterprets `arg` as
// `const GridwiseGemm32::Argument&` before calling RunImp.
// NOTE(review): this assumes Argument is layout-compatible with GridwiseGemm32::Argument in the
// classes that use this macro — confirm at each use site.
// Returns 0 when no variant is compiled in for the detected warp size.
#define INVOKER_RUN3_IMPL                                                                \
    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})  \
    {                                                                                    \
        if(get_warp_size() != 64)                                                        \
        {                                                                                \
            if constexpr(NXdlPerWave32 > 0)                                              \
            {                                                                            \
                const auto& arg32 =                                                      \
                    reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg);     \
                return RunImp<GridwiseGemm32>(arg32, stream_config);                     \
            }                                                                            \
        }                                                                                \
        else if constexpr(NXdlPerWave64 > 0)                                             \
        {                                                                                \
            return RunImp<GridwiseGemm64>(arg, stream_config);                           \
        }                                                                                \
        return 0.0f;                                                                     \
    }
135 
136 template <index_t BlockSize,
137  index_t MPerBlock,
138  index_t NPerBlock,
139  index_t MPerXdl,
140  index_t NPerXdl,
141  index_t MXdlPerWave,
142  index_t NXdlPerWave,
143  typename CDataType,
144  InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
145 __device__ static bool constexpr IsValidGemmCompilationParameter()
146 {
147 #if defined(__gfx11__) || defined(__gfx12__)
148  if constexpr(MPerXdl != 16 || NPerXdl != 16)
149  {
150  return false;
151  }
152 #endif
153 
154 #if defined(__gfx11__)
155  constexpr bool SupportMemOp = CGlobalMemoryDataOperation_ == InMemoryDataOperationEnum::Set;
156 #else
157  constexpr bool SupportMemOp =
158  sizeof(CDataType) >= 2 || (CGlobalMemoryDataOperation_ == InMemoryDataOperationEnum::Set);
159 #endif
160  if constexpr(SupportMemOp == false)
161  {
162  return false;
163  }
164 
165  if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
166  {
167  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
168  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
169  if constexpr(MWaves > 0 && NWaves > 0)
170  {
171  constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
172  return WaveSize == get_warp_size();
173  }
174  }
175  return false;
176 }
177 
// Expands to a static member function template
// IsValidCompilationParameter<CGlobalMemoryDataOperation_ = Set>() that forwards the enclosing
// class's BlockSize, MPerBlock, NPerBlock, MPerXdl, NPerXdl, MXdlPerWave and NXdlPerWave (which
// the class must define), plus the macro argument as the C data type, to
// IsValidGemmCompilationParameter.
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataTypeArg_)                               \
    template <InMemoryDataOperationEnum CGlobalMemoryDataOperation_ =                    \
                  InMemoryDataOperationEnum::Set>                                        \
    __device__ static bool constexpr IsValidCompilationParameter()                       \
    {                                                                                    \
        return ck::tensor_operation::device::IsValidGemmCompilationParameter<            \
            BlockSize, MPerBlock, NPerBlock,                                             \
            MPerXdl, NPerXdl,                                                            \
            MXdlPerWave, NXdlPerWave,                                                    \
            CDataTypeArg_, CGlobalMemoryDataOperation_>();                               \
    }
194 
195 #ifndef CK_CODE_GEN_RTC
// Polymorphic base of all device-operation argument objects (copyable, virtual destructor).
// NOTE(review): this struct is compiled only when CK_CODE_GEN_RTC is not defined, while
// BaseOperator's member functions that take BaseArgument* are guarded by
// `!defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)`; confirm the two guards agree for
// every build configuration.
struct BaseArgument
{
    BaseArgument()                               = default;
    BaseArgument(const BaseArgument&)            = default;
    BaseArgument& operator=(const BaseArgument&) = default;

    virtual ~BaseArgument() = default;

    // Workspace buffer; null until BaseOperator::SetWorkSpacePointer assigns it.
    void* p_workspace_ = nullptr;
};
206 
208 {
209  BaseInvoker() = default;
210  BaseInvoker(const BaseInvoker&) = default;
211  BaseInvoker& operator=(const BaseInvoker&) = default;
212 
213  virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
214  {
215  return float{0};
216  }
217 
218  virtual ~BaseInvoker() {}
219 };
220 #endif
221 
223 {
224  BaseOperator() = default;
225  BaseOperator(const BaseOperator&) = default;
226  BaseOperator& operator=(const BaseOperator&) = default;
227 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
228  virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
229  virtual std::string GetTypeString() const { return ""; }
230 
231  virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
232 
233  virtual std::optional<std::string> GetObjectName() const { return std::nullopt; }
234 
235  virtual std::optional<std::string> GetTemplateInfo() const { return std::nullopt; }
236 
237  virtual std::string GetTypeIdHashCode() const
238  {
239  std::ostringstream oss;
240 
241  oss << std::hex << typeid(*this).hash_code();
242 
243  return oss.str();
244  };
245 
246  virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
247 
248  virtual void SetWorkSpacePointer(BaseArgument* p_arg,
249  void* p_workspace,
250  const StreamConfig& = StreamConfig{}) const
251  {
252  assert(p_arg);
253  p_arg->p_workspace_ = p_workspace;
254  }
255 #endif
256  virtual ~BaseOperator() {}
257 };
258 
259 } // namespace device
260 } // namespace tensor_operation
261 } // namespace ck
Definition: ck.hpp:268
InMemoryDataOperationEnum
Definition: ck.hpp:277
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
int32_t index_t
Definition: ck.hpp:299
StreamConfig
Definition: stream_config.hpp:10
BaseArgument
Definition: device_base.hpp:197
BaseArgument & operator=(const BaseArgument &)=default
BaseArgument(const BaseArgument &)=default
virtual ~BaseArgument()
Definition: device_base.hpp:202
void * p_workspace_
Definition: device_base.hpp:204
BaseInvoker
Definition: device_base.hpp:208
virtual ~BaseInvoker()
Definition: device_base.hpp:218
BaseInvoker & operator=(const BaseInvoker &)=default
virtual float Run(const BaseArgument *, const StreamConfig &=StreamConfig{})
Definition: device_base.hpp:213
BaseInvoker(const BaseInvoker &)=default
BaseOperator
Definition: device_base.hpp:223
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition: device_base.hpp:248
virtual bool IsSupportedArgument(const BaseArgument *)
Definition: device_base.hpp:228
virtual size_t GetWorkSpaceSize(const BaseArgument *) const
Definition: device_base.hpp:246
virtual std::optional< std::string > GetTemplateInfo() const
Definition: device_base.hpp:235
virtual std::string GetTypeString() const
Definition: device_base.hpp:229
BaseOperator(const BaseOperator &)=default
virtual std::string GetTypeIdHashCode() const
Definition: device_base.hpp:237
virtual std::optional< std::string > GetObjectName() const
Definition: device_base.hpp:233
BaseOperator & operator=(const BaseOperator &)=default
virtual std::string GetTypeIdName() const
Definition: device_base.hpp:231
virtual ~BaseOperator()
Definition: device_base.hpp:256