/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp Source File
grouped_gemm_kernel.hpp
Go to the documentation of this file.
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition: arch.hpp:217
Definition: cluster_descriptor.hpp:13
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:2834
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
Definition: grouped_gemm_kernel.hpp:83
GemmTransKernelArg(UniversalGemmKernelArgs< 1, 1, NumDTensor > &&karg, index_t bl_start, index_t bl_end)
Definition: grouped_gemm_kernel.hpp:89
UniversalGemmKernelArgs< 1, 1, NumDTensor > group_karg
Definition: grouped_gemm_kernel.hpp:84
GemmTransKernelArg(UniversalGemmKernelArgs< 1, 1, NumDTensor > &&karg)
Definition: grouped_gemm_kernel.hpp:96
GemmTransKernelArg()=delete
ck_tile::index_t block_start
Definition: grouped_gemm_kernel.hpp:85
ck_tile::index_t block_end
Definition: grouped_gemm_kernel.hpp:86
The Grouped GEMM kernel host arguments.
Definition: grouped_gemm_kernel.hpp:29
CK_TILE_HOST GroupedGemmHostArgs(const void *a_ptr_, const void *b_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_E_)
Definition: grouped_gemm_kernel.hpp:30
const std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_gemm_kernel.hpp:59
const std::array< index_t, NumDTensor > stride_Ds
Definition: grouped_gemm_kernel.hpp:71
Definition: grouped_gemm_kernel.hpp:104
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_gemm_kernel.hpp:109
static constexpr index_t NumDTensor_
Definition: grouped_gemm_kernel.hpp:124
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count) const
Definition: grouped_gemm_kernel.hpp:544
static CK_TILE_HOST auto GetWorkSpaceSize(index_t group_count) -> std::size_t
Definition: grouped_gemm_kernel.hpp:168
static CK_TILE_HOST bool IsSupportedArgument(const std::vector< GemmTransKernelArg< NumDTensor_ >> &kargs)
Definition: grouped_gemm_kernel.hpp:267
static CK_TILE_DEVICE void RunGemmWithPipelineSelection(const ADataType *a_ptr, const BDataType *b_ptr, const std::array< const void *, NumDTensor_ > &ds_ptr, CDataType *c_ptr, void *smem_ptr_0, const UniversalGemmKernelArgs< 1, 1, NumDTensor_ > &kargs, const typename Base::SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition: grouped_gemm_kernel.hpp:374
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Definition: grouped_gemm_kernel.hpp:114
CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs< 1, 1, NumDTensor_ > &kargs, const tuple< index_t, index_t > &block_idx_2d, const index_t block_idx_z) const
Definition: grouped_gemm_kernel.hpp:284
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition: grouped_gemm_kernel.hpp:115
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition: grouped_gemm_kernel.hpp:120
static constexpr index_t kBlockSize
Definition: grouped_gemm_kernel.hpp:144
static constexpr CK_TILE_HOST_DEVICE auto GetSmemSize() -> index_t
Definition: grouped_gemm_kernel.hpp:279
static CK_TILE_HOST auto GridSize(const std::vector< GroupedGemmHostArgs< NumDTensor_ >> &gemm_descs)
Definition: grouped_gemm_kernel.hpp:203
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition: grouped_gemm_kernel.hpp:116
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_gemm_kernel.hpp:111
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg< NumDTensor_ > *gemm_desc_ptr, index_t block_id, index_t group_count) const
Definition: grouped_gemm_kernel.hpp:493
static CK_TILE_HOST auto BlockSize() -> dim3
Definition: grouped_gemm_kernel.hpp:173
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition: grouped_gemm_kernel.hpp:121
static CK_TILE_DEVICE void RunGemmWithPipelineSelection2LDS(const ADataType *a_ptr, const BDataType *b_ptr, CDataType *c_ptr, const std::array< const void *, NumDTensor_ > &ds_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const UniversalGemmKernelArgs< 1, 1, NumDTensor_ > &kargs, const typename Base::SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition: grouped_gemm_kernel.hpp:432
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: grouped_gemm_kernel.hpp:191
static CK_TILE_HOST const std::string GetName()
Definition: grouped_gemm_kernel.hpp:147
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_gemm_kernel.hpp:122
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_gemm_kernel.hpp:110
static CK_TILE_HOST auto MakeKargs(const std::vector< GroupedGemmHostArgs< NumDTensor_ >> &gemm_descs) -> std::vector< GemmTransKernelArg< NumDTensor_ >>
Definition: grouped_gemm_kernel.hpp:215
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Specify the data type configurations for A, B, C/E.
Definition: grouped_gemm_kernel.hpp:119
static CK_TILE_HOST auto GetWorkSpaceSize(const std::vector< GroupedGemmHostArgs<>> &gemm_descs) -> std::size_t
Definition: grouped_gemm_kernel.hpp:163
static constexpr bool UsePersistentKernel
Definition: grouped_gemm_kernel.hpp:145
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, index_t group_count) const
Definition: grouped_gemm_kernel.hpp:521
Struct used to calculate offset tile indices.
Definition: gemm_tile_partitioner.hpp:184
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from the 1D raw indices.
Definition: gemm_tile_partitioner.hpp:192
Definition: universal_gemm_kernel.hpp:325
std::array< index_t, NumATensor > as_k_split_offset
Definition: universal_gemm_kernel.hpp:368
index_t splitted_k
Definition: universal_gemm_kernel.hpp:370
std::array< index_t, NumBTensor > bs_k_split_offset
Definition: universal_gemm_kernel.hpp:369
The GEMM kernel device arguments.
Definition: universal_gemm_kernel.hpp:86
void * e_ptr
The E output tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:94
const std::array< const void *, NumDTensor > ds_ptr
The Ds input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:92
const std::array< const void *, NumATensor > as_ptr
The As input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:88
const std::array< const void *, NumBTensor > bs_ptr
The Bs input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:90
The Universal GEMM kernel template.
Definition: universal_gemm_kernel.hpp:154
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_0, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition: universal_gemm_kernel.hpp:955
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition: universal_gemm_kernel.hpp:853
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: universal_gemm_kernel.hpp:754
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition: universal_gemm_kernel.hpp:373
Definition: stream_config.hpp:30
Definition: tuple.hpp:192