/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_base.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_base.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_base.hpp Source File
device_base.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
#include <cassert>
#include <optional>
#include <regex>
#include <sstream>
#include <string>
#include <typeinfo>

#include "ck/stream_config.hpp"
#endif
14 
15 namespace ck {
16 namespace tensor_operation {
17 namespace device {
18 
19 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
// Expands to an override of `std::optional<std::string> GetObjectName() const`.
// The expanding class needs a base that declares that virtual, such as BaseOperator.
//
// The name is derived from __PRETTY_FUNCTION__:
//   - If the signature contains "<std::string> <name>::GetObjectName", the result
//     is the captured "<name>" followed by a ';'.
//   - Otherwise the whole pretty-printed signature is returned unchanged.
//
// NOTE(review): the pattern expects the return type to be printed as
// `std::optional<std::string>`. How that type is printed depends on the
// compiler and standard library (libstdc++ under GCC may print
// std::__cxx11::basic_string<char>). In that case the fallback path is taken.
#define GET_OBJECT_NAME_IMLP                                                  \
    std::optional<std::string> GetObjectName() const override                 \
    {                                                                         \
        std::string str = __PRETTY_FUNCTION__;                                \
        static std::regex obj_name_expr{"<std::string> (.*)::GetObjectName"}; \
        std::smatch match;                                                    \
        if(!std::regex_search(str, match, obj_name_expr))                     \
        {                                                                     \
            return str;                                                       \
        }                                                                     \
        return std::string(match[1]) + ';';                                   \
    }

// Correctly spelled alias, consistent with GET_TEMPLATE_INFO_IMPL.
// The misspelled GET_OBJECT_NAME_IMLP above stays defined for existing users.
#ifndef GET_OBJECT_NAME_IMPL
#define GET_OBJECT_NAME_IMPL GET_OBJECT_NAME_IMLP
#endif
32 
// Expands to an override of `std::optional<std::string> GetTemplateInfo() const`.
//
// It searches __PRETTY_FUNCTION__ for a bracketed section:
//   - The regex is greedy, so the match runs from the first '[' to the last ']'.
//   - On a match, the text between the brackets is returned.
//   - With no brackets, the result is std::nullopt.
//
// NOTE(review): compilers typically print template arguments in such brackets
// (e.g. "[with T = ...]" or "[T = ...]"). The exact text is compiler-specific.
#define GET_TEMPLATE_INFO_IMPL                                  \
    std::optional<std::string> GetTemplateInfo() const override \
    {                                                           \
        static std::regex bracket_expr{"\\[(.*)\\]"};           \
        const std::string pretty{__PRETTY_FUNCTION__};          \
        std::smatch groups;                                     \
        if(std::regex_search(pretty, groups, bracket_expr))     \
        {                                                       \
            return groups[1].str();                             \
        }                                                       \
        return std::nullopt;                                    \
    }
45 
// Expands to both printing overrides, GetObjectName() and GetTemplateInfo(),
// in one go. Both expansions are marked `override`, so the class using this
// macro needs a base that declares those virtuals (e.g. BaseOperator).
#define REGISTER_EXTRA_PRINTING_METHODS \
    GET_OBJECT_NAME_IMLP                \
    GET_TEMPLATE_INFO_IMPL
47 #endif
48 
49 #ifndef CK_CODE_GEN_RTC
/// Polymorphic base for device-operation argument objects.
///
/// Carries the workspace pointer that BaseOperator::SetWorkSpacePointer writes.
/// Derived argument types add operation-specific fields.
struct BaseArgument
{
    BaseArgument()                               = default;
    BaseArgument(const BaseArgument&)            = default;
    BaseArgument& operator=(const BaseArgument&) = default;

    /// Virtual so that derived arguments can be destroyed through a base pointer.
    virtual ~BaseArgument() = default;

    /// Workspace memory for the operation; nullptr until assigned.
    /// This struct never frees it.
    void* p_workspace_ = nullptr;
};
60 
62 {
63  BaseInvoker() = default;
64  BaseInvoker(const BaseInvoker&) = default;
65  BaseInvoker& operator=(const BaseInvoker&) = default;
66 
67  virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
68  {
69  return float{0};
70  }
71 
72  virtual ~BaseInvoker() {}
73 };
74 #endif
75 
77 {
78  BaseOperator() = default;
79  BaseOperator(const BaseOperator&) = default;
80  BaseOperator& operator=(const BaseOperator&) = default;
81 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
82  virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
83  virtual std::string GetTypeString() const { return ""; }
84 
85  virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
86 
87  virtual std::optional<std::string> GetObjectName() const { return std::nullopt; }
88 
89  virtual std::optional<std::string> GetTemplateInfo() const { return std::nullopt; }
90 
91  virtual std::string GetTypeIdHashCode() const
92  {
93  std::ostringstream oss;
94 
95  oss << std::hex << typeid(*this).hash_code();
96 
97  return oss.str();
98  };
99 
100  virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
101 
102  virtual void SetWorkSpacePointer(BaseArgument* p_arg,
103  void* p_workspace,
104  const StreamConfig& = StreamConfig{}) const
105  {
106  assert(p_arg);
107  p_arg->p_workspace_ = p_workspace;
108  }
109 #endif
110  virtual ~BaseOperator() {}
111 };
112 
113 } // namespace device
114 } // namespace tensor_operation
115 } // namespace ck
Definition: ck.hpp:267
Definition: stream_config.hpp:10
struct ck::tensor_operation::device::BaseArgument
Definition: device_base.hpp:51
BaseArgument & operator=(const BaseArgument &)=default
BaseArgument(const BaseArgument &)=default
virtual ~BaseArgument()
Definition: device_base.hpp:56
void * p_workspace_
Definition: device_base.hpp:58
struct ck::tensor_operation::device::BaseInvoker
Definition: device_base.hpp:62
virtual ~BaseInvoker()
Definition: device_base.hpp:72
BaseInvoker & operator=(const BaseInvoker &)=default
virtual float Run(const BaseArgument *, const StreamConfig &=StreamConfig{})
Definition: device_base.hpp:67
BaseInvoker(const BaseInvoker &)=default
struct ck::tensor_operation::device::BaseOperator
Definition: device_base.hpp:77
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition: device_base.hpp:102
virtual bool IsSupportedArgument(const BaseArgument *)
Definition: device_base.hpp:82
virtual size_t GetWorkSpaceSize(const BaseArgument *) const
Definition: device_base.hpp:100
virtual std::optional< std::string > GetTemplateInfo() const
Definition: device_base.hpp:89
virtual std::string GetTypeString() const
Definition: device_base.hpp:83
BaseOperator(const BaseOperator &)=default
virtual std::string GetTypeIdHashCode() const
Definition: device_base.hpp:91
virtual std::optional< std::string > GetObjectName() const
Definition: device_base.hpp:87
BaseOperator & operator=(const BaseOperator &)=default
virtual std::string GetTypeIdName() const
Definition: device_base.hpp:85
virtual ~BaseOperator()
Definition: device_base.hpp:110