/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp Source File#
streamk_gemm_kernel.hpp
Go to the documentation of this file.
Definition: cluster_descriptor.hpp:13
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:2834
CK_TILE_HOST void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:13
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
The Stream K GEMM kernel host arguments.
Definition: streamk_gemm_kernel.hpp:20
ck_tile::StreamKReductionStrategy reduction_strategy
Definition: streamk_gemm_kernel.hpp:49
CK_TILE_HOST StreamKHostArgs(const void *a_ptr_, const void *b_ptr_, void *c_ptr_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_, StreamKReductionStrategy reduction_strategy_, uint32_t num_sk_blocks_=0xffffffff)
Definition: streamk_gemm_kernel.hpp:21
ALayout and ADataType are expected to be scalars, not a tuple.
Definition: streamk_gemm_kernel.hpp:93
StreamKReductionStrategy reduction_strategy
The strategy used by work groups to compute final results in C tensor.
Definition: streamk_gemm_kernel.hpp:95
uint32_t num_sk_blocks
The number of stream k blocks.
Definition: streamk_gemm_kernel.hpp:97
void * workspace_ptr
A pointer to a buffer in device memory for accumulating partial via reduction strategy.
Definition: streamk_gemm_kernel.hpp:100
TilePartitioner tile_partitioner
An instance of the TilePartioner class for assisting with mapping workgroups to the C tensor.
Definition: streamk_gemm_kernel.hpp:103
Definition: streamk_gemm_kernel.hpp:55
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Inject the UniversalGemmKernel base class to support execution of all necessary functions.
Definition: streamk_gemm_kernel.hpp:59
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Specify the layout configurations for A, B, and C.
Definition: streamk_gemm_kernel.hpp:68
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Specify the data type configurations for A, B, and C.
Definition: streamk_gemm_kernel.hpp:73
static CK_TILE_HOST auto GridSize(const TilePartitioner &tile_partitioner) -> dim3
Compute the grid size for the Stream K kernel using the tile_partitioner.
Definition: streamk_gemm_kernel.hpp:125
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition: streamk_gemm_kernel.hpp:69
static CK_TILE_HOST StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs &host_args, int num_cu=NumCU(), int occupancy=Occupancy())
Constructs kernel arguments for the Stream-K kernel.
Definition: streamk_gemm_kernel.hpp:151
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition: streamk_gemm_kernel.hpp:75
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: streamk_gemm_kernel.hpp:63
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: streamk_gemm_kernel.hpp:65
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, UniversalGemmKernel::NumATensor > &as_ptr, const std::array< const BDataType *, UniversalGemmKernel::NumBTensor > &bs_ptr, const std::array< const void *, UniversalGemmKernel::NumDTensor > &ds_ptr, CDataType *c_ptr, void *smem_ptr_0, const typename UniversalGemmKernel::KernelArgs &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t k_size)
Definition: streamk_gemm_kernel.hpp:182
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: streamk_gemm_kernel.hpp:134
static CK_TILE_HOST void SetWorkSpacePointer(StreamKKernelArgs &kargs, void *workspace_ptr)
Sets the kargs' current workspace_ptr to the given workspace_ptr.
Definition: streamk_gemm_kernel.hpp:260
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition: streamk_gemm_kernel.hpp:74
static constexpr index_t kBlockSize
Definition: streamk_gemm_kernel.hpp:61
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: streamk_gemm_kernel.hpp:64
static CK_TILE_HOST const std::string GetName()
Definition: streamk_gemm_kernel.hpp:109
static constexpr CK_TILE_HOST auto BlockSize() -> dim3
Definition: streamk_gemm_kernel.hpp:139
static CK_TILE_HOST bool IsSupportedArgument(const StreamKKernelArgs &kargs)
Definition: streamk_gemm_kernel.hpp:230
CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const
Entry point for the Stream-K Kernel, performing the main Stream-K loop.
Definition: streamk_gemm_kernel.hpp:266
static CK_TILE_HOST uint32_t GetWorkSpaceSize(const StreamKKernelArgs &kargs)
Computes the buffer size needed to store accumulation results for Stream K.
Definition: streamk_gemm_kernel.hpp:245
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition: streamk_gemm_kernel.hpp:70
The Universal GEMM kernel host arguments.
Definition: universal_gemm_kernel.hpp:32
const std::array< index_t, NumDTensor > stride_Ds
Definition: universal_gemm_kernel.hpp:73
const std::array< index_t, NumBTensor > stride_Bs
Definition: universal_gemm_kernel.hpp:72
const std::array< const void *, NumDTensor > ds_ptr
Definition: universal_gemm_kernel.hpp:62
const std::array< const void *, NumATensor > as_ptr
Definition: universal_gemm_kernel.hpp:60
const std::array< index_t, NumATensor > stride_As
Definition: universal_gemm_kernel.hpp:71
const std::array< const void *, NumBTensor > bs_ptr
Definition: universal_gemm_kernel.hpp:61
The GEMM kernel device arguments.
Definition: universal_gemm_kernel.hpp:86
void * e_ptr
The E output tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:94
const std::array< const void *, NumATensor > as_ptr
The As input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:88
const std::array< const void *, NumBTensor > bs_ptr
The Bs input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:90
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition: universal_gemm_kernel.hpp:853
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: universal_gemm_kernel.hpp:754
static CK_TILE_HOST auto BlockSize()
Definition: universal_gemm_kernel.hpp:290
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: universal_gemm_kernel.hpp:278
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition: universal_gemm_kernel.hpp:373
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: universal_gemm_kernel.hpp:319
static constexpr index_t kBlockSize
Definition: universal_gemm_kernel.hpp:202
Definition: integral_constant.hpp:13
Definition: stream_config.hpp:30