20 namespace tensor_operation {
23 template <
typename ADataType,
30 typename AElementwiseOperation,
31 typename BElementwiseOperation,
32 typename CElementwiseOperation,
43 typename ABlockTransferThreadClusterLengths_K0_M_K1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
49 bool ABlockLdsAddExtraM,
50 typename BBlockTransferThreadClusterLengths_K0_N_K1,
51 typename BBlockTransferThreadClusterArrangeOrder,
52 typename BBlockTransferSrcAccessOrder,
56 bool BBlockLdsAddExtraN,
68 AElementwiseOperation,
69 BElementwiseOperation,
70 CElementwiseOperation>
83 template <index_t NXdlPerWave_>
93 AElementwiseOperation,
94 BElementwiseOperation,
95 CElementwiseOperation,
105 ABlockTransferThreadClusterLengths_K0_M_K1,
106 ABlockTransferThreadClusterArrangeOrder,
107 ABlockTransferSrcAccessOrder,
108 ABlockTransferSrcVectorDim,
109 ABlockTransferSrcScalarPerVector,
110 ABlockTransferDstScalarPerVector_K1,
113 BBlockTransferThreadClusterLengths_K0_N_K1,
114 BBlockTransferThreadClusterArrangeOrder,
115 BBlockTransferSrcAccessOrder,
116 BBlockTransferSrcVectorDim,
117 BBlockTransferSrcScalarPerVector,
118 BBlockTransferDstScalarPerVector_K1,
122 CThreadTransferSrcDstVectorDim,
123 CThreadTransferDstScalarPerVector,
130 using Argument = typename GridwiseGemm64::Argument;
135 template <typename GridwiseGemm>
136 float RunImp(const typename GridwiseGemm::Argument& karg,
139 if(stream_config.log_level_ > 0)
144 if(!GridwiseGemm::CheckValidity(karg))
146 throw std::runtime_error(
147 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting");
150 const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
154 if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
156 const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, true>;
159 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
163 const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, false>;
166 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
178 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
192 if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
193 is_same_v<AccDataType, int32_t>))
200 if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
201 is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
227 reinterpret_cast<const typename GridwiseGemm32::Argument&>(karg));
240 const BDataType* p_b,
248 AElementwiseOperation,
249 BElementwiseOperation,
250 CElementwiseOperation)
252 return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
267 AElementwiseOperation,
268 BElementwiseOperation,
269 CElementwiseOperation) override
271 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
272 static_cast<const BDataType*>(p_b),
273 static_cast<CDataType*>(p_c),
285 return std::make_unique<Invoker>(Invoker{});
291 auto str = std::stringstream();
293 std::map<LoopScheduler, std::string> LoopSchedToString{
296 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
300 str << "DeviceGemmXdl"
305 << K0PerBlock << ", "
309 << MXdlPerWave << ", "
310 << NXdlPerWave << ", "
311 << ABlockTransferSrcScalarPerVector << ", "
312 << ABlockTransferDstScalarPerVector_K1 << ", "
313 << BBlockTransferSrcScalarPerVector << ", "
314 << BBlockTransferDstScalarPerVector_K1
317 << NumPrefetch << ", "
319 << LoopSchedToString[LoopSched] << ", "
320 << "PipelineVersion: "
321 << PipelineVersionToString[PipelineVer];
#define INVOKER_RUN3_IMPL
Definition: device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
GemmSpecialization
Definition: gemm_specialization.hpp:11
bool is_lds_direct_load_supported()
Definition: device_prop.hpp:101
std::string get_device_name()
Definition: device_prop.hpp:19
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
LoopScheduler
Definition: loop_scheduler.hpp:15
int32_t index_t
Definition: ck.hpp:299
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:18
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdlops_v2r3.hpp:814
static constexpr __host__ bool CheckValidity(const Problem &problem)
Definition: gridwise_gemm_xdlops_v2r3.hpp:1003
Definition: sequence.hpp:43
Definition: integral_constant.hpp:20
Definition: device_base.hpp:197
Definition: device_base.hpp:208
Definition: device_gemm.hpp:22
Definition: device_gemm_xdl.hpp:134
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl.hpp:175
float RunImp(const typename GridwiseGemm::Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl.hpp:136
Definition: device_gemm_xdl.hpp:71
static bool IsSupportedArgument(const Argument &karg)
Definition: device_gemm_xdl.hpp:188
static constexpr auto K1Number
Definition: device_gemm_xdl.hpp:80
static constexpr auto I0
Definition: device_gemm_xdl.hpp:76
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl.hpp:234
typename GridwiseGemm64::Argument Argument
Definition: device_gemm_xdl.hpp:130
static auto MakeInvoker()
Definition: device_gemm_xdl.hpp:255
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl.hpp:283
std::string GetTypeString() const override
Definition: device_gemm_xdl.hpp:289
static constexpr auto I2
Definition: device_gemm_xdl.hpp:78
static constexpr auto I1
Definition: device_gemm_xdl.hpp:77
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_xdl.hpp:239
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl.hpp:182
static constexpr auto NXdlPerWave32
Definition: device_gemm_xdl.hpp:74
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_gemm_xdl.hpp:73
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_xdl.hpp:258