include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp Source File
mx_flatmm_kernel.hpp
Go to the documentation of this file.
66 return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, MXFlatmmPipeline::GetName());
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:145
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1690
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1633
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1684
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:203
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor_packed(const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition: tensor_descriptor.hpp:371
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
Definition: flatmm_kernel.hpp:232
Definition: flatmm_kernel.hpp:252
static constexpr CK_TILE_HOST auto BlockSize()
Definition: flatmm_kernel.hpp:333
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemPingSize()
Definition: flatmm_kernel.hpp:355
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemPongSize()
Definition: flatmm_kernel.hpp:359
Definition: mx_flatmm_kernel.hpp:18
static CK_TILE_DEVICE auto MakeABlockWindow(const ADataType *a_ptr, const KernelArgs &kargs, const index_t k_size, const index_t block_idx_m)
Definition: mx_flatmm_kernel.hpp:117
static CK_TILE_DEVICE auto MakeScaleBBlockWindow(const KernelArgs &kargs, const index_t block_idx_n)
Definition: mx_flatmm_kernel.hpp:342
static CK_TILE_HOST const std::string GetName()
Definition: mx_flatmm_kernel.hpp:63
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: mx_flatmm_kernel.hpp:25
static CK_TILE_DEVICE auto MakeScaleABlockWindow(const KernelArgs &kargs, const index_t block_idx_m)
Definition: mx_flatmm_kernel.hpp:312
static constexpr index_t NumDTensor
Definition: mx_flatmm_kernel.hpp:50
remove_cvref_t< typename MXFlatmmPipeline::BLayout > BLayout
Definition: mx_flatmm_kernel.hpp:27
remove_cvref_t< typename EpiloguePipeline::DsLayout > DsLayout
Definition: mx_flatmm_kernel.hpp:29
static constexpr int NThreadPerXdl
Definition: mx_flatmm_kernel.hpp:40
static constexpr CK_TILE_HOST auto GridSize(const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs)
Definition: mx_flatmm_kernel.hpp:72
static CK_TILE_DEVICE auto MakeEBlockWindow(EDataType *e_ptr, const KernelArgs &kargs, const index_t block_idx_m, const index_t block_idx_n)
Definition: mx_flatmm_kernel.hpp:259
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition: mx_flatmm_kernel.hpp:37
static constexpr bool UsePersistentKernel
Definition: mx_flatmm_kernel.hpp:32
remove_cvref_t< typename MXFlatmmPipeline::CLayout > ELayout
Definition: mx_flatmm_kernel.hpp:28
remove_cvref_t< typename MXFlatmmPipeline::ALayout > ALayout
Definition: mx_flatmm_kernel.hpp:26
remove_cvref_t< typename MXFlatmmPipeline::BDataType > BDataType
Definition: mx_flatmm_kernel.hpp:35
remove_cvref_t< MXFlatmmPipeline_ > MXFlatmmPipeline
Definition: mx_flatmm_kernel.hpp:22
remove_cvref_t< typename MXFlatmmPipeline_::BlockGemmShape > BlockGemmShape
Definition: mx_flatmm_kernel.hpp:24
static constexpr int MThreadPerXdl
Definition: mx_flatmm_kernel.hpp:39
remove_cvref_t< typename MXFlatmmPipeline::ADataType > ADataType
Definition: mx_flatmm_kernel.hpp:34
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: mx_flatmm_kernel.hpp:21
static CK_TILE_DEVICE auto MakeDBlockWindows(const std::array< const void *, NumDTensor > &ds_ptr, const KernelArgs &kargs, const index_t block_idx_m, const index_t block_idx_n)
Definition: mx_flatmm_kernel.hpp:184
static CK_TILE_DEVICE auto MakeBFlatBlockWindow(const BDataType *b_flat_ptr, const KernelArgs &kargs, const index_t block_idx_n)
Definition: mx_flatmm_kernel.hpp:148
typename Underlying::SplitKBatchOffset SplitKBatchOffset
Definition: mx_flatmm_kernel.hpp:114
CK_TILE_DEVICE void operator()(FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> kargs, int partition_idx=blockIdx.x) const
Definition: mx_flatmm_kernel.hpp:454
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: mx_flatmm_kernel.hpp:30
static constexpr index_t KernelBlockSize
Definition: mx_flatmm_kernel.hpp:31
static CK_TILE_DEVICE void RunFlatmm(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_ping, void *smem_ptr_pong, const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Definition: mx_flatmm_kernel.hpp:373
static constexpr int KThreadPerXdl
Definition: mx_flatmm_kernel.hpp:41
static constexpr int BPackedSize
Definition: mx_flatmm_kernel.hpp:44
static constexpr int APackedSize
Definition: mx_flatmm_kernel.hpp:43
Definition: integral_constant.hpp:13
Definition: numeric.hpp:81
Definition: sequence.hpp:49