6 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) 
   17 namespace tensor_operation {
 
   20 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) 
   21 #define GET_OBJECT_NAME_IMLP                                                  \ 
   22     std::optional<std::string> GetObjectName() const override                 \ 
   24         std::string str = __PRETTY_FUNCTION__;                                \ 
   25         static std::regex obj_name_expr{"<std::string> (.*)::GetObjectName"}; \
 
   27         if(!std::regex_search(str, match, obj_name_expr))                     \
 
   31         return std::string(match[1]) + ';';                                   \
 
   34 #define GET_TEMPLATE_INFO_IMPL                                  \ 
   35     std::optional<std::string> GetTemplateInfo() const override \ 
   37         std::string str = __PRETTY_FUNCTION__;                  \ 
   38         static std::regex template_expr{"\\[(.*)\\]"};          \
 
   40         if(!std::regex_search(str, match, template_expr))       \
 
   42             return std::nullopt;                                \
 
   44         return std::string(match[1]);                           \
 
   47 #define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL 
   57 static constexpr 
auto GetNXdlPerWave2()
 
   59     constexpr 
index_t Waves  = IsWave64 ? BlockSize_ / 64 : BlockSize_ / 32;
 
   60     constexpr 
index_t MWaves = MPerBlock_ / (MXdlPerWave_ * MPerXDL_);
 
   61     static_assert(MWaves > 0);
 
   63     constexpr 
index_t NWaves = Waves / MWaves;
 
   64     if constexpr(NWaves == 0)
 
   70         if constexpr(NPerBlock_ % (NPerXDL_ * NWaves) == 0)
 
   72             return NPerBlock_ / (NWaves * NPerXDL_);
 
   81 #define GET_NXDL_PER_WAVE_IMPL              \ 
   82     template <bool IsWave64>                \ 
   83     static constexpr auto GetNXdlPerWave()  \ 
   85         return GetNXdlPerWave2<BlockSize,   \ 
   94 #define INVOKER_RUN_IMPL                                                               \ 
   95     float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \ 
   97         if(get_warp_size() == 64)                                                      \ 
   99             if constexpr(NXdlPerWave64 > 0)                                            \ 
  101                 return RunImp<GridwiseGemm64>(arg, stream_config);                     \ 
  106             if constexpr(NXdlPerWave32 > 0)                                            \ 
  108                 return RunImp<GridwiseGemm32>(arg, stream_config);                     \ 
  114 #define INVOKER_RUN3_IMPL                                                              \ 
  115     float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \ 
  117         if(get_warp_size() == 64)                                                      \ 
  119             if constexpr(NXdlPerWave64 > 0)                                            \ 
  121                 return RunImp<GridwiseGemm64>(arg, stream_config);                     \ 
  126             if constexpr(NXdlPerWave32 > 0)                                            \ 
  128                 return RunImp<GridwiseGemm32>(                                         \ 
  129                     reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg),   \ 
  145 __device__ 
static bool constexpr IsValidGemmCompilationParameter()
 
  147 #if defined(__gfx11__) || defined(__gfx12__) 
  148     if constexpr(MPerXdl != 16 || NPerXdl != 16)
 
  154 #if defined(__gfx11__) 
  157     constexpr 
bool SupportMemOp =
 
  160     if constexpr(SupportMemOp == 
false)
 
  165     if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
 
  167         constexpr 
index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
 
  168         constexpr 
index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
 
  169         if constexpr(MWaves > 0 && NWaves > 0)
 
  171             constexpr 
index_t WaveSize = BlockSize / (MWaves * NWaves);
 
  178 #define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)                       \ 
  179     template <InMemoryDataOperationEnum CGlobalMemoryDataOperation_ =         \ 
  180                   InMemoryDataOperationEnum::Set>                             \ 
  181     __device__ static bool constexpr IsValidCompilationParameter()            \ 
  183         return ck::tensor_operation::device::IsValidGemmCompilationParameter< \ 
  192             CGlobalMemoryDataOperation_>();                                   \ 
  195 #ifndef CK_CODE_GEN_RTC 
  227 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) 
  234     virtual std::optional<std::string> 
GetObjectName()
 const { 
return std::nullopt; }
 
  240         std::ostringstream oss;
 
  242         oss << std::hex << 
typeid(*this).hash_code();
 
InMemoryDataOperationEnum
Definition: ck.hpp:277
 
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
 
int32_t index_t
Definition: ck.hpp:299
 
Definition: stream_config.hpp:10
 
Definition: device_base.hpp:197
 
BaseArgument & operator=(const BaseArgument &)=default
 
BaseArgument(const BaseArgument &)=default
 
virtual ~BaseArgument()
Definition: device_base.hpp:202
 
void * p_workspace_
Definition: device_base.hpp:204
 
Definition: device_base.hpp:208
 
virtual ~BaseInvoker()
Definition: device_base.hpp:218
 
BaseInvoker & operator=(const BaseInvoker &)=default
 
virtual float Run(const BaseArgument *, const StreamConfig &=StreamConfig{})
Definition: device_base.hpp:213
 
BaseInvoker(const BaseInvoker &)=default
 
Definition: device_base.hpp:223
 
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition: device_base.hpp:249
 
virtual std::string GetInstanceString() const
Definition: device_base.hpp:230
 
virtual bool IsSupportedArgument(const BaseArgument *)
Definition: device_base.hpp:228
 
virtual size_t GetWorkSpaceSize(const BaseArgument *) const
Definition: device_base.hpp:247
 
virtual std::optional< std::string > GetTemplateInfo() const
Definition: device_base.hpp:236
 
virtual std::string GetTypeString() const
Definition: device_base.hpp:229
 
BaseOperator(const BaseOperator &)=default
 
virtual std::string GetTypeIdHashCode() const
Definition: device_base.hpp:238
 
virtual std::optional< std::string > GetObjectName() const
Definition: device_base.hpp:234
 
BaseOperator & operator=(const BaseOperator &)=default
 
virtual std::string GetTypeIdName() const
Definition: device_base.hpp:232
 
virtual ~BaseOperator()
Definition: device_base.hpp:257