/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Source File#
gridwise_gemm_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
@ Odd
Definition: block_to_ctile_map.hpp:270
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:367
CDataType * p_c_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:399
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t k_batch_, bool is_reduce_=false)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:368
bool is_reduce
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:400
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:387
const ADataType * p_a_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:397
const BDataType * p_b_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:398
__host__ __device__ bool IsAtomicAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:392
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:312
index_t M
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:348
index_t KPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:358
index_t NPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:356
index_t NBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:362
index_t K
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:350
__host__ void Print() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:338
index_t N
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:349
index_t AK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:359
index_t BK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:360
index_t KBatch
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:354
index_t MPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:355
index_t MBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:361
index_t StrideA
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:351
index_t StrideB
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:352
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t KBatch_)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:313
index_t StrideC
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:353
index_t KRead
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:357
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:404
index_t c_reduce_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:455
index_t b_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:454
index_t a_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:453
__device__ SplitKBatchOffset(Argument &karg)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:406
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:104
static constexpr auto I3
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:109
static constexpr auto BK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:117
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:128
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >())> BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:806
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:634
static constexpr auto I5
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:111
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:135
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:496
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:185
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:173
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:197
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:162
__host__ static constexpr __device__ auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:770
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:152
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:106
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:167
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:223
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:157
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:307
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:192
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:809
static constexpr auto I1
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:107
static constexpr auto I4
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:110
static constexpr auto I7
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:113
static constexpr auto AK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:118
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:434
static constexpr auto BK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:119
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:108
static constexpr auto AK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:116
static constexpr auto I6
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:112
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:222
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:128
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:135
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:185
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:173
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:197
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:162
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:152
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:106
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:167
typename Base::BlockwiseGemmPipe BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:458
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:223
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:157
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:306
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:307
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:192
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:809
static constexpr auto I1
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:107
static __device__ void Run(void *p_shared, const SplitKBatchOffset &splitk_batch_offset, const Argument &karg)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:534
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:434
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:108
static __device__ index_t GetKBlockPerScale()
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:465
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:470
Definition: device_base.hpp:51