/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp Source File#
device_gemm_wmma_cshuffle_v3r1.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:268
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_sequence_v2(F &&f, Number< N >)
Definition: sequence_helper.hpp:25
__global__ void kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:35
@ Intrawave
@ Interwave
Definition: stream_config.hpp:10
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:408
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:460
EDataType * p_e_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:473
__host__ void Print() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:362
index_t M
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:388
index_t K
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:390
index_t N
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:389
index_t KBatch
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:395
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:170
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:941
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:1154
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:231
Definition: multi_index_transform.hpp:13
Definition: sequence.hpp:43
Definition: tuple.hpp:186
Definition: tuple.hpp:117
Definition: integral_constant.hpp:20
static constexpr bool value
Definition: integral_constant.hpp:21
Definition: type.hpp:177
Definition: reduction_operator.hpp:37
Definition: device_base.hpp:197
void * p_workspace_
Definition: device_base.hpp:204
Definition: device_base.hpp:208
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:149
Argument(std::array< const void *, 1 > p_a_grid_, std::array< const void *, 1 > p_b_grid_, const ::std::array< const void *, NumDTensor > p_ds_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, std::array< index_t, 1 > StrideA_, std::array< index_t, 1 > StrideB_, const ::std::array< index_t, NumDTensor > stride_ds_, index_t StrideC_, index_t KBatch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:150
CDataType * p_c_grid
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:188
const ::std::array< const void *, NumDTensor > p_ds
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:190
CElementwiseOperation c_element_op
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:189
::std::array< index_t, NumDTensor > StrideDs
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:191
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:226
float RunReduce(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:227
float Run(const Argument &arg_, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:299
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:368
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:91
ck::tensor_operation::element_wise::PassThrough PassThrough
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:94
DeviceReduceThreadWiseMultiD< ReduceDataType, DsDataType, GemmAccDataType, CDataType, 3, 1, ReduceAdd, PassThrough, OutElementwiseOperation, 256, CShuffleBlockTransferScalarPerVector_NPerBlock, 1, 0, CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, decltype(DsVectorLengthSequence)> DeviceReduceInstance
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:223
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:375
static size_t GetSharedMemoryNumberOfByte()
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:412
static constexpr index_t NumDTensor
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:92
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, const ::std::array< const void *, NumDTensor > p_ds, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, const ::std::array< index_t, NumDTensor > stride_ds, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:417
static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:405
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:381
static constexpr auto DsVectorLengthSequence
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:197
::std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:453
static auto MakeInvoker()
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:450
GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, GemmAccDataType, ReduceDataType, Tuple<>, ReduceDataType, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, false, false > GridwiseGemm
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:146
ck::reduce::Add ReduceAdd
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:194
CElementwiseOperation OutElementwiseOperation
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:195
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:400
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:546
static constexpr index_t GetBlockSize()
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:410
::std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, ::std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, ::std::array< index_t, NumDTensor > DsStrides, index_t StrideC, index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:459
::std::string GetTypeString() const override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:492
Definition: device_gemm_v2.hpp:57
Definition: device_reduce_threadwise_multi_d.hpp:47
Definition: unary_element_wise_operation.hpp:334