/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_base.hpp File Reference

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_base.hpp File Reference

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_base.hpp File Reference
device_base.hpp File Reference
#include <string>
#include <sstream>
#include <regex>
#include <optional>
#include "ck/stream_config.hpp"
#include "ck/utility/get_id.hpp"

Go to the source code of this file.

Classes

struct  ck::tensor_operation::device::BaseArgument
 
struct  ck::tensor_operation::device::BaseInvoker
 
struct  ck::tensor_operation::device::BaseOperator
 

Namespaces

 ck
 
 ck::tensor_operation
 
 ck::tensor_operation::device
 

Macros

#define GET_OBJECT_NAME_IMLP
 
#define GET_TEMPLATE_INFO_IMPL
 
#define REGISTER_EXTRA_PRINTING_METHODS   GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
 
#define GET_NXDL_PER_WAVE_IMPL
 
#define INVOKER_RUN_IMPL
 
#define INVOKER_RUN3_IMPL
 
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
 

Macro Definition Documentation

◆ GET_NXDL_PER_WAVE_IMPL

#define GET_NXDL_PER_WAVE_IMPL
Value:
template <bool IsWave64> /* IsWave64 is passed through unchanged as the last template argument below */ \
static constexpr auto GetNXdlPerWave() /* forwards to GetNXdlPerWave2; the return type is whatever that call yields */ \
{ \
return GetNXdlPerWave2<BlockSize, /* BlockSize .. MXdlPerWave are not declared by this macro: they must be visible where it is expanded */ \
MPerBlock, \
NPerBlock, \
MPerXDL, /* NOTE(review): spelled MPerXDL/NPerXDL here but MPerXdl/NPerXdl in IS_VALID_COMPILATION_PARAMETER_IMPL - confirm the expanding class provides this spelling */ \
NPerXDL, \
MXdlPerWave, \
IsWave64>(); \
}

◆ GET_OBJECT_NAME_IMLP

#define GET_OBJECT_NAME_IMLP
Value:
/* Extracts the enclosing object's name from the compiler-generated signature of \
 * this method (__PRETTY_FUNCTION__). \
 * Returns the text captured between "<std::string> " and "::GetObjectName" with \
 * a trailing ';' appended, or the complete signature when the pattern is absent. */ \
std::optional<std::string> GetObjectName() const override \
{ \
    /* Compiled once and reused by every later call. */ \
    static std::regex name_pattern{"<std::string> (.*)::GetObjectName"}; \
    std::string signature = __PRETTY_FUNCTION__; \
    std::smatch captured; \
    if(std::regex_search(signature, captured, name_pattern)) \
    { \
        /* Capture group 1 holds the qualified name that precedes the method name. */ \
        return std::string(captured[1]) + ';'; \
    } \
    /* Pattern not found: hand back the unmodified signature. */ \
    return signature; \
}

◆ GET_TEMPLATE_INFO_IMPL

#define GET_TEMPLATE_INFO_IMPL
Value:
/* Extracts the bracketed portion of this method's compiler-generated signature \
 * (__PRETTY_FUNCTION__). \
 * Returns the text between the first '[' and the last ']' (the capture is greedy), \
 * or std::nullopt when the signature contains no such bracket pair. */ \
std::optional<std::string> GetTemplateInfo() const override \
{ \
    /* Compiled once and reused by every later call. */ \
    static std::regex bracket_pattern{"\\[(.*)\\]"}; \
    std::string signature = __PRETTY_FUNCTION__; \
    std::smatch captured; \
    if(std::regex_search(signature, captured, bracket_pattern)) \
    { \
        /* Capture group 1 is the text inside the brackets, without the brackets. */ \
        return std::string(captured[1]); \
    } \
    return std::nullopt; \
}

◆ INVOKER_RUN3_IMPL

#define INVOKER_RUN3_IMPL
Value:
/* Dispatches to RunImp with the gridwise GEMM type that matches the warp size \
 * reported by get_warp_size(): GridwiseGemm64 when it equals 64, GridwiseGemm32 \
 * otherwise. On the GridwiseGemm32 path the argument is reinterpreted as \
 * GridwiseGemm32::Argument before being forwarded. \
 * Returns RunImp's result, or 0 when the NXdlPerWave constant for the selected \
 * warp size is not positive (that call is then compiled out). */ \
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
{ \
    const bool is_wave64 = (get_warp_size() == 64); \
    if(!is_wave64) \
    { \
        if constexpr(NXdlPerWave32 > 0) \
        { \
            using Wave32Argument = typename GridwiseGemm32::Argument; \
            return RunImp<GridwiseGemm32>(reinterpret_cast<const Wave32Argument&>(arg), \
                                          stream_config); \
        } \
    } \
    else \
    { \
        if constexpr(NXdlPerWave64 > 0) \
        { \
            return RunImp<GridwiseGemm64>(arg, stream_config); \
        } \
    } \
    /* Reached only when the matching branch above was discarded at compile time. */ \
    return 0; \
}
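References: get_warp_size() - constexpr __device__ index_t get_warp_size() (Definition: get_id.hpp:10)
References: StreamConfig (symbol name missing from the extracted text; presumably StreamConfig) (Definition: stream_config.hpp:10)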

◆ INVOKER_RUN_IMPL

#define INVOKER_RUN_IMPL
Value:
/* Dispatches to RunImp with the gridwise GEMM type that matches the warp size \
 * reported by get_warp_size(): GridwiseGemm64 when it equals 64, GridwiseGemm32 \
 * otherwise. The argument is forwarded unchanged on both paths. \
 * Returns RunImp's result, or 0 when the NXdlPerWave constant for the selected \
 * warp size is not positive (that call is then compiled out). */ \
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
{ \
    const bool is_wave64 = (get_warp_size() == 64); \
    if(!is_wave64) \
    { \
        if constexpr(NXdlPerWave32 > 0) \
        { \
            return RunImp<GridwiseGemm32>(arg, stream_config); \
        } \
    } \
    else \
    { \
        if constexpr(NXdlPerWave64 > 0) \
        { \
            return RunImp<GridwiseGemm64>(arg, stream_config); \
        } \
    } \
    /* Reached only when the matching branch above was discarded at compile time. */ \
    return 0; \
}

◆ IS_VALID_COMPILATION_PARAMETER_IMPL

#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Value:
template <InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = /* falls back to Set when the caller names no operation */ \
InMemoryDataOperationEnum::Set> \
__device__ static bool constexpr IsValidCompilationParameter() /* forwards all parameters to IsValidGemmCompilationParameter and returns its result */ \
{ \
return ck::tensor_operation::device::IsValidGemmCompilationParameter< \
BlockSize, /* BlockSize .. NXdlPerWave are not declared by this macro: they must be visible where it is expanded */ \
MPerBlock, \
NPerBlock, \
MPerXdl, \
NPerXdl, \
MXdlPerWave, \
NXdlPerWave, \
CDataType_, /* the macro's own parameter */ \
CGlobalMemoryDataOperation_>(); \
}
InMemoryDataOperationEnum
Definition: ck.hpp:277

◆ REGISTER_EXTRA_PRINTING_METHODS

#define REGISTER_EXTRA_PRINTING_METHODS   GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL