/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp Source File
blockwise_gemm_xdlops.hpp
Go to the documentation of this file.
390 // the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition: ck.hpp:207
Definition: ck.hpp:266
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition: blockwise_gemm_xdlops.hpp:606
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
@ Default
@ Interwave
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
static constexpr index_t KPerBlock
Definition: blockwise_gemm_smfmac_xdlops.hpp:56
static constexpr index_t A_K1
Definition: blockwise_gemm_smfmac_xdlops.hpp:61
static constexpr auto c_thread_desc_
Definition: blockwise_gemm_smfmac_xdlops.hpp:426
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
Definition: blockwise_gemm_xdlops.hpp:165
static constexpr auto I2
Definition: blockwise_gemm_smfmac_xdlops.hpp:47
static __device__ auto CalculateBThreadOriginDataIndex()
Definition: blockwise_gemm_xdlops.hpp:104
static constexpr index_t WaveSize
Definition: blockwise_gemm_smfmac_xdlops.hpp:52
__host__ static constexpr __device__ auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition: blockwise_gemm_xdlops.hpp:251
static constexpr index_t KPerThread
Definition: blockwise_gemm_smfmac_xdlops.hpp:67
ck::BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: blockwise_gemm_xdlops.hpp:234
static constexpr index_t B_K1
Definition: blockwise_gemm_smfmac_xdlops.hpp:62
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_xdlops.hpp:178
static __device__ auto CalculateAThreadOriginDataIndex()
Definition: blockwise_gemm_xdlops.hpp:93
static constexpr index_t MPerBlock
Definition: blockwise_gemm_smfmac_xdlops.hpp:54
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, FloatAcc, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition: blockwise_gemm_smfmac_xdlops.hpp:77
static constexpr auto b_block_desc_n0_n1_n2_k
Definition: blockwise_gemm_smfmac_xdlops.hpp:294
static constexpr index_t NPerBlock
Definition: blockwise_gemm_smfmac_xdlops.hpp:55
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_xdlops.hpp:146
static constexpr auto I0
Definition: blockwise_gemm_smfmac_xdlops.hpp:45
ThreadwiseTensorSliceTransfer_v4< FloatA, FloatA, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPerThread >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 > AThreadCopy
Definition: blockwise_gemm_smfmac_xdlops.hpp:437
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_smfmac_xdlops.hpp:418
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_smfmac_xdlops.hpp:450
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: blockwise_gemm_smfmac_xdlops.hpp:50
__host__ static constexpr __device__ auto MakeBBlockDescriptor_N0_N1_N2_K()
Definition: blockwise_gemm_xdlops.hpp:281
__host__ static constexpr __device__ auto MakeABlockDescriptor_M0_M1_M2_K()
Definition: blockwise_gemm_xdlops.hpp:269
static constexpr auto a_block_desc_m0_m1_m2_k
Definition: blockwise_gemm_smfmac_xdlops.hpp:293
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_smfmac_xdlops.hpp:449
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_xdlops.hpp:79
__host__ static constexpr __device__ auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_xdlops.hpp:217
__host__ static constexpr __device__ auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_xdlops.hpp:191
static constexpr index_t NWaves
Definition: blockwise_gemm_smfmac_xdlops.hpp:70
static constexpr auto xdlops_gemm
Definition: blockwise_gemm_smfmac_xdlops.hpp:64
static constexpr index_t B_K0
Definition: blockwise_gemm_smfmac_xdlops.hpp:60
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_smfmac_xdlops.hpp:422
static constexpr index_t A_K0
Definition: blockwise_gemm_smfmac_xdlops.hpp:59
__device__ void Run(const ABlockBuffer &a_block_buf, const BBlockBuffer &b_block_buf, CThreadBuffer &c_thread_buf) const
Definition: blockwise_gemm_xdlops.hpp:297
ck::BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_xdlops.hpp:204
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_xdlops.hpp:117
ThreadwiseTensorSliceTransfer_v4< FloatB, ComputeTypeB, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, KPerThread >, Sequence< 0, 1, 2, 3 >, 3, B_K1, B_K1 > BThreadCopy
Definition: blockwise_gemm_smfmac_xdlops.hpp:447
static constexpr auto I3
Definition: blockwise_gemm_smfmac_xdlops.hpp:48
static constexpr auto I1
Definition: blockwise_gemm_smfmac_xdlops.hpp:46
static __device__ auto GetWaveIdx()
Definition: blockwise_gemm_xdlops.hpp:81
static constexpr index_t MWaves
Definition: blockwise_gemm_smfmac_xdlops.hpp:69
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_xdlops.hpp:818
__host__ static constexpr __device__ auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_xdlops.hpp:831
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_xdlops.hpp:707
static constexpr index_t A_K0
Definition: blockwise_gemm_xdlops.hpp:684
static constexpr auto xdlops_gemm
Definition: blockwise_gemm_xdlops.hpp:689
static constexpr index_t A_K1
Definition: blockwise_gemm_xdlops.hpp:686
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_xdlops.hpp:985
static __device__ auto GetWaveIdx()
Definition: blockwise_gemm_xdlops.hpp:709
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_xdlops.hpp:859
static constexpr index_t NWaves
Definition: blockwise_gemm_xdlops.hpp:695
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition: blockwise_gemm_xdlops.hpp:925
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition: blockwise_gemm_xdlops.hpp:845
__device__ void Run(const ABlockBuffer &a_block_buf, const BBlockBuffer &b_block_buf, CThreadBuffer &c_thread_buf) const
Definition: blockwise_gemm_xdlops.hpp:928
__host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Definition: blockwise_gemm_xdlops.hpp:789
static constexpr index_t B_K0
Definition: blockwise_gemm_xdlops.hpp:685
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_xdlops.hpp:745
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: blockwise_gemm_xdlops.hpp:889
__host__ static constexpr __device__ auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_xdlops.hpp:872
decltype(CalculateAThreadOriginDataIndex()) Tuple4
Definition: blockwise_gemm_xdlops.hpp:787
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_xdlops.hpp:981
static constexpr auto c_thread_desc_
Definition: blockwise_gemm_xdlops.hpp:989
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition: blockwise_gemm_xdlops.hpp:804
static constexpr index_t WaveSize
Definition: blockwise_gemm_xdlops.hpp:682
__host__ static constexpr __device__ auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition: blockwise_gemm_xdlops.hpp:906
static __device__ auto CalculateAThreadOriginDataIndex()
Definition: blockwise_gemm_xdlops.hpp:721
static __device__ auto CalculateBThreadOriginDataIndex()
Definition: blockwise_gemm_xdlops.hpp:732
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition: blockwise_gemm_xdlops.hpp:924
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: blockwise_gemm_xdlops.hpp:680
static constexpr index_t MWaves
Definition: blockwise_gemm_xdlops.hpp:694
static constexpr index_t B_K1
Definition: blockwise_gemm_xdlops.hpp:687
static constexpr index_t KPerThread
Definition: blockwise_gemm_xdlops.hpp:692
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_xdlops.hpp:1012
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, FloatAcc, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition: blockwise_gemm_xdlops.hpp:698
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_xdlops.hpp:774
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_xdlops.hpp:1013
Definition: blockwise_gemm_xdlops.hpp:420
Definition: sequence.hpp:43
Definition: static_buffer.hpp:75
static __device__ index_t GetThreadId()
Definition: thread_group.hpp:19
static constexpr __device__ index_t GetNumOfThread()
Definition: thread_group.hpp:15
Definition: threadwise_tensor_slice_transfer.hpp:1260
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition: threadwise_tensor_slice_transfer.hpp:1293
Definition: xdlops_gemm.hpp:1669
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:10