6 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
17 namespace tensor_operation {
20 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
21 #define GET_OBJECT_NAME_IMLP \
22 std::optional<std::string> GetObjectName() const override \
24 std::string str = __PRETTY_FUNCTION__; \
25 static std::regex obj_name_expr{"<std::string> (.*)::GetObjectName"}; \
27 if(!std::regex_search(str, match, obj_name_expr)) \
31 return std::string(match[1]) + ';'; \
34 #define GET_TEMPLATE_INFO_IMPL \
35 std::optional<std::string> GetTemplateInfo() const override \
37 std::string str = __PRETTY_FUNCTION__; \
38 static std::regex template_expr{"\\[(.*)\\]"}; \
40 if(!std::regex_search(str, match, template_expr)) \
42 return std::nullopt; \
44 return std::string(match[1]); \
47 #define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
57 static constexpr
auto GetNXdlPerWave2()
59 constexpr
index_t Waves = IsWave64 ? BlockSize_ / 64 : BlockSize_ / 32;
60 constexpr
index_t MWaves = MPerBlock_ / (MXdlPerWave_ * MPerXDL_);
61 static_assert(MWaves > 0);
63 constexpr
index_t NWaves = Waves / MWaves;
64 if constexpr(NWaves == 0)
70 if constexpr(NPerBlock_ % (NPerXDL_ * NWaves) == 0)
72 return NPerBlock_ / (NWaves * NPerXDL_);
81 #define GET_NXDL_PER_WAVE_IMPL \
82 template <bool IsWave64> \
83 static constexpr auto GetNXdlPerWave() \
85 return GetNXdlPerWave2<BlockSize, \
94 #define INVOKER_RUN_IMPL \
95 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
97 if(get_warp_size() == 64) \
99 if constexpr(NXdlPerWave64 > 0) \
101 return RunImp<GridwiseGemm64>(arg, stream_config); \
106 if constexpr(NXdlPerWave32 > 0) \
108 return RunImp<GridwiseGemm32>(arg, stream_config); \
114 #define INVOKER_RUN3_IMPL \
115 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
117 if(get_warp_size() == 64) \
119 if constexpr(NXdlPerWave64 > 0) \
121 return RunImp<GridwiseGemm64>(arg, stream_config); \
126 if constexpr(NXdlPerWave32 > 0) \
128 return RunImp<GridwiseGemm32>( \
129 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg), \
145 __device__
static bool constexpr IsValidGemmCompilationParameter()
147 #if defined(__gfx11__) || defined(__gfx12__)
148 if constexpr(MPerXdl != 16 || NPerXdl != 16)
154 #if defined(__gfx11__)
157 constexpr
bool SupportMemOp =
160 if constexpr(SupportMemOp ==
false)
165 if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
167 constexpr
index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
168 constexpr
index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
169 if constexpr(MWaves > 0 && NWaves > 0)
171 constexpr
index_t WaveSize = BlockSize / (MWaves * NWaves);
178 #define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_) \
179 template <InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = \
180 InMemoryDataOperationEnum::Set> \
181 __device__ static bool constexpr IsValidCompilationParameter() \
183 return ck::tensor_operation::device::IsValidGemmCompilationParameter< \
192 CGlobalMemoryDataOperation_>(); \
195 #ifndef CK_CODE_GEN_RTC
227 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
233 virtual std::optional<std::string>
GetObjectName()
const {
return std::nullopt; }
239 std::ostringstream oss;
241 oss << std::hex <<
typeid(*this).hash_code();
InMemoryDataOperationEnum
Definition: ck.hpp:277
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
int32_t index_t
Definition: ck.hpp:299
Definition: stream_config.hpp:10
Definition: device_base.hpp:197
BaseArgument & operator=(const BaseArgument &)=default
BaseArgument(const BaseArgument &)=default
virtual ~BaseArgument()
Definition: device_base.hpp:202
void * p_workspace_
Definition: device_base.hpp:204
Definition: device_base.hpp:208
virtual ~BaseInvoker()
Definition: device_base.hpp:218
BaseInvoker & operator=(const BaseInvoker &)=default
virtual float Run(const BaseArgument *, const StreamConfig &=StreamConfig{})
Definition: device_base.hpp:213
BaseInvoker(const BaseInvoker &)=default
Definition: device_base.hpp:223
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition: device_base.hpp:248
virtual bool IsSupportedArgument(const BaseArgument *)
Definition: device_base.hpp:228
virtual size_t GetWorkSpaceSize(const BaseArgument *) const
Definition: device_base.hpp:246
virtual std::optional< std::string > GetTemplateInfo() const
Definition: device_base.hpp:235
virtual std::string GetTypeString() const
Definition: device_base.hpp:229
BaseOperator(const BaseOperator &)=default
virtual std::string GetTypeIdHashCode() const
Definition: device_base.hpp:237
virtual std::optional< std::string > GetObjectName() const
Definition: device_base.hpp:233
BaseOperator & operator=(const BaseOperator &)=default
virtual std::string GetTypeIdName() const
Definition: device_base.hpp:231
virtual ~BaseOperator()
Definition: device_base.hpp:256