GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Struct Template Reference

GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Struct Template Reference

Composable Kernel: ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Struct Template Reference
ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Struct Template Reference

#include <grouped_gemm_kernel.hpp>

Public Types

using Base = UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >
 Inject the UniversalGemmKernel base class to support execution of all necessary functions. More...
 
using TilePartitioner = remove_cvref_t< TilePartitioner_ >
 
using GemmPipeline = remove_cvref_t< GemmPipeline_ >
 
using EpiloguePipeline = remove_cvref_t< EpiloguePipeline_ >
 
using ALayout = remove_cvref_t< typename GemmPipeline::ALayout >
 
using BLayout = remove_cvref_t< typename GemmPipeline::BLayout >
 
using CLayout = remove_cvref_t< typename GemmPipeline::CLayout >
 
using ADataType = remove_cvref_t< typename GemmPipeline::ADataType >
 Specify the data type configurations for A, B, C/E. More...
 
using BDataType = remove_cvref_t< typename GemmPipeline::BDataType >
 
using CDataType = remove_cvref_t< typename EpiloguePipeline::ODataType >
 
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner< TilePartitioner >
 ALayout and ADataType are expected to be scalars, not a tuple. More...
 
using Kernel = GroupedGemmKernel< TilePartitioner, GemmPipeline, EpiloguePipeline >
 

Public Member Functions

CK_TILE_DEVICE void Run (const UniversalGemmKernelArgs<> &kargs, const tuple< index_t, index_t > &block_idx_2d, const index_t block_idx_z) const
 
CK_TILE_DEVICE index_t FindGroupId (const GemmTransKernelArg *gemm_desc_ptr, index_t block_id, index_t group_count) const
 
template<bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
CK_TILE_DEVICE void operator() (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, index_t group_count) const
 
template<bool U = UsePersistentKernel, typename = std::enable_if_t<U>, typename = void>
CK_TILE_DEVICE void operator() (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count) const
 

Static Public Member Functions

static CK_TILE_HOST const std::string GetName ()
 
static CK_TILE_HOST auto GetWorkSpaceSize (const std::vector< GroupedGemmHostArgs > &gemm_descs) -> std::size_t
 
static CK_TILE_HOST auto GetWorkSpaceSize (index_t group_count) -> std::size_t
 
static CK_TILE_HOST auto BlockSize () -> dim3
 
static CK_TILE_HOST auto MaxOccupancyGridSize (const stream_config &s) -> dim3
 Get the maximum occupancy grid size for the persistent kernel on the current device. More...
 
static CK_TILE_HOST auto GridSize (const std::vector< GroupedGemmHostArgs > &gemm_descs)
 
static CK_TILE_HOST auto MakeKargs (const std::vector< GroupedGemmHostArgs > &gemm_descs) -> std::vector< GemmTransKernelArg >
 
static CK_TILE_HOST bool IsSupportedArgument (const std::vector< GemmTransKernelArg > &kargs)
 
static constexpr CK_TILE_HOST_DEVICE auto GetSmemSize () -> index_t
 
static CK_TILE_DEVICE void RunGemmWithPipelineSelection (const ADataType *a_ptr, const BDataType *b_ptr, CDataType *c_ptr, void *smem_ptr_0, const UniversalGemmKernelArgs<> &kargs, const typename Base::SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
 Runs single GEMM problem cooperatively by whole workgroup. More...
 
static CK_TILE_DEVICE void RunGemmWithPipelineSelection2LDS (const ADataType *a_ptr, const BDataType *b_ptr, CDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const UniversalGemmKernelArgs<> &kargs, const typename Base::SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
 Runs single GEMM problem cooperatively by whole workgroup. More...
 

Static Public Attributes

static constexpr index_t kBlockSize = GemmPipeline::BlockSize
 
static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel
 

Member Typedef Documentation

◆ ADataType

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::ADataType = remove_cvref_t<typename GemmPipeline::ADataType>

Specify the data type configurations for A, B, C/E.

◆ ALayout

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::ALayout = remove_cvref_t<typename GemmPipeline::ALayout>

◆ Base

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::Base = UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>

Inject the UniversalGemmKernel base class to support execution of all necessary functions.

◆ BDataType

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BDataType = remove_cvref_t<typename GemmPipeline::BDataType>

◆ BLayout

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BLayout = remove_cvref_t<typename GemmPipeline::BLayout>

◆ CDataType

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>

◆ CLayout

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::CLayout = remove_cvref_t<typename GemmPipeline::CLayout>

◆ EpiloguePipeline

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>

◆ GemmPipeline

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GemmPipeline = remove_cvref_t<GemmPipeline_>

◆ Kernel

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::Kernel = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>

◆ OffsetTile1DPartitioner

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>

ALayout and ADataType are expected to be scalars, not a tuple.

BLayout and BDataType are expected to be scalars, not a tuple.

C/ELayout and C/EDataType are expected to be scalars, not a tuple.

◆ TilePartitioner

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::TilePartitioner = remove_cvref_t<TilePartitioner_>

Member Function Documentation

◆ BlockSize()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST auto ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BlockSize ( ) -> dim3
inline static

◆ FindGroupId()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
CK_TILE_DEVICE index_t ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::FindGroupId ( const GemmTransKernelArg * gemm_desc_ptr,
index_t  block_id,
index_t  group_count 
) const
inline

◆ GetName()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST const std::string ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GetName ( )
inline static

◆ GetSmemSize()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static constexpr CK_TILE_HOST_DEVICE auto ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GetSmemSize ( ) -> index_t
inline static constexpr

◆ GetWorkSpaceSize() [1/2]

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST auto ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GetWorkSpaceSize ( const std::vector< GroupedGemmHostArgs > &  gemm_descs) -> std::size_t
inline static

◆ GetWorkSpaceSize() [2/2]

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST auto ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GetWorkSpaceSize ( index_t  group_count) -> std::size_t
inline static

◆ GridSize()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST auto ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GridSize ( const std::vector< GroupedGemmHostArgs > &  gemm_descs)
inline static

◆ IsSupportedArgument()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST bool ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::IsSupportedArgument ( const std::vector< GemmTransKernelArg > &  kargs)
inline static

◆ MakeKargs()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST auto ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::MakeKargs ( const std::vector< GroupedGemmHostArgs > &  gemm_descs) -> std::vector<GemmTransKernelArg>
inline static

◆ MaxOccupancyGridSize()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST auto ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::MaxOccupancyGridSize ( const stream_config & s) -> dim3
inline static

Get the maximum occupancy grid size for the persistent kernel on the current device.

Returns
The maximum occupancy grid size.
Note
This function queries the maximum occupancy of the kernel using hipOccupancyMaxActiveBlocksPerMultiprocessor.

◆ operator()() [1/2]

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
template<bool U = UsePersistentKernel, typename = std::enable_if_t<U>, typename = void>
CK_TILE_DEVICE void ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::operator() ( const void CK_CONSTANT_ADDRESS_SPACE * gemm_descs_const,
const index_t  group_count 
) const
inline

◆ operator()() [2/2]

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
template<bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
CK_TILE_DEVICE void ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::operator() ( const void CK_CONSTANT_ADDRESS_SPACE * gemm_descs_const,
index_t  group_count 
) const
inline

◆ Run()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
CK_TILE_DEVICE void ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::Run ( const UniversalGemmKernelArgs<> &  kargs,
const tuple< index_t, index_t > &  block_idx_2d,
const index_t  block_idx_z 
) const
inline

◆ RunGemmWithPipelineSelection()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_DEVICE void ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::RunGemmWithPipelineSelection ( const ADataType * a_ptr,
const BDataType * b_ptr,
CDataType * c_ptr,
void *  smem_ptr_0,
const UniversalGemmKernelArgs<> &  kargs,
const typename Base::SplitKBatchOffset & splitk_batch_offset,
const index_t  block_idx_m,
const index_t  block_idx_n 
)
inline static

Runs a single GEMM problem cooperatively across the whole workgroup.

Note
The GEMM pipeline is selected in-kernel based on the number of K-loops and the tail number. This is needed for the persistent tile-loop, where the K dimension is not available on the host.
Parameters
a_ptr — input A pointer
b_ptr — input B pointer
c_ptr — output C pointer
smem_ptr_0 — The start memory pointer of the shared memory block.
kargs — GEMM kernel arguments
splitk_batch_offset — Utility structure used to calculate the k batch.
block_idx_m — The GEMM's output M dimension tile index processed by this workgroup.
block_idx_n — The GEMM's output N dimension tile index processed by this workgroup.

◆ RunGemmWithPipelineSelection2LDS()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_DEVICE void ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::RunGemmWithPipelineSelection2LDS ( const ADataType * a_ptr,
const BDataType * b_ptr,
CDataType * c_ptr,
void *__restrict__  smem_ptr_0,
void *__restrict__  smem_ptr_1,
const UniversalGemmKernelArgs<> &  kargs,
const typename Base::SplitKBatchOffset & splitk_batch_offset,
const index_t  block_idx_m,
const index_t  block_idx_n 
)
inline static

Runs a single GEMM problem cooperatively across the whole workgroup.

Note
The GEMM pipeline is selected in-kernel based on the number of K-loops and the tail number. This is needed for the persistent tile-loop, where the K dimension is not available on the host.
Parameters
a_ptr — input A pointer
b_ptr — input B pointer
c_ptr — output C pointer
smem_ptr_0 — The start memory pointer of the first shared memory block.
smem_ptr_1 — The start memory pointer of the second shared memory block.
kargs — GEMM kernel arguments
splitk_batch_offset — Utility structure used to calculate the k batch.
block_idx_m — The GEMM's output M dimension tile index processed by this workgroup.
block_idx_n — The GEMM's output N dimension tile index processed by this workgroup.

Member Data Documentation

◆ kBlockSize

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
constexpr index_t ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::kBlockSize = GemmPipeline::BlockSize
static constexpr

◆ UsePersistentKernel

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::UsePersistentKernel = GemmPipeline::UsePersistentKernel
static constexpr

The documentation for this struct was generated from the following file:
  • /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp