/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp Source File
device_grouped_gemm.hpp
Go to the documentation of this file.
141 //----------------------------------------------------------------------------------------------
159 //----------------------------------------------------------------------------------------------
Definition: ck.hpp:267
Definition: device_base.hpp:51
Definition: device_base.hpp:77
virtual std::string GetTypeString() const
Definition: device_base.hpp:83
Definition: device_grouped_gemm.hpp:99
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const
Sets the device kernel arguments pointer and may copy data to device.
Definition: device_grouped_gemm.hpp:148
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_a, std::vector< const void * > &p_b, std::vector< std::array< const void *, NumDTensor >> &p_ds, std::vector< void * > &p_e, std::vector< GemmDesc > &gemm_desc, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args, const void *p_host_kernel_args) const
Sets the device kernel arguments pointer and may copy data to device.
Definition: device_grouped_gemm.hpp:127
static constexpr index_t NumDTensor
Definition: device_grouped_gemm.hpp:100
virtual size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const
Gets the device kernel argument size.
Definition: device_grouped_gemm.hpp:166
Definition: device_grouped_gemm.hpp:80
ck::index_t stride_C_
Definition: device_grouped_gemm.hpp:82
std::vector< ck::index_t > stride_Ds_
Definition: device_grouped_gemm.hpp:84
ck::index_t stride_A_
Definition: device_grouped_gemm.hpp:82
ck::index_t stride_B_
Definition: device_grouped_gemm.hpp:82
Structure representing the arguments of a single GEMM problem.
Definition: device_grouped_gemm.hpp:29
void Print() const
Definition: device_grouped_gemm.hpp:67
index_t StrideB
Definition: device_grouped_gemm.hpp:63
void * p_e_grid
Definition: device_grouped_gemm.hpp:58
index_t StrideE
Definition: device_grouped_gemm.hpp:65
index_t N
Definition: device_grouped_gemm.hpp:60
const void * p_a_grid
Definition: device_grouped_gemm.hpp:55
index_t K
Definition: device_grouped_gemm.hpp:61
std::array< index_t, NumDTensor > StrideDs
Definition: device_grouped_gemm.hpp:64
index_t StrideA
Definition: device_grouped_gemm.hpp:62
index_t M
Definition: device_grouped_gemm.hpp:59
__host__ __device__ GroupedGemmKernelArgument(const void *p_a_grid_, const void *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, void *p_e_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_)
Definition: device_grouped_gemm.hpp:30
const void * p_b_grid
Definition: device_grouped_gemm.hpp:56
std::array< const void *, NumDTensor > p_ds_grid
Definition: device_grouped_gemm.hpp:57