include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp Source File
grouped_flatmm_kernel.hpp
Go to the documentation of this file.
202 using UnderlyingGemmKernel = FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>;
Definition: cluster_descriptor.hpp:13
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
Definition: grouped_flatmm_kernel.hpp:79
void * c_ptr
Definition: grouped_flatmm_kernel.hpp:128
void * e_ptr
Definition: grouped_flatmm_kernel.hpp:127
const std::array< index_t, NumDTensor > stride_Ds
Definition: grouped_flatmm_kernel.hpp:124
const std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_flatmm_kernel.hpp:123
const void * b_shuffle_ptr
Definition: grouped_flatmm_kernel.hpp:121
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs(index_t *M_indices_, index_t M_, index_t N_, index_t K_, const void *a_ptr_, index_t stride_A_, const void *b_shuffle_ptr_, index_t stride_B_, const std::array< const void *, NumDTensor > &ds_ptr_, const std::array< index_t, NumDTensor > &stride_Ds_, void *c_ptr_, index_t stride_C_, index_t k_batch_, ScaleM scale_m_=nullptr, ScaleN scale_n_=nullptr)
Definition: grouped_flatmm_kernel.hpp:81
const void * a_ptr
Definition: grouped_flatmm_kernel.hpp:119
index_t k_batch
Definition: grouped_flatmm_kernel.hpp:131
ScaleM scale_m
Definition: grouped_flatmm_kernel.hpp:132
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs()=default
index_t group_count
Definition: grouped_flatmm_kernel.hpp:114
index_t stride_B
Definition: grouped_flatmm_kernel.hpp:122
index_t stride_A
Definition: grouped_flatmm_kernel.hpp:120
ScaleN scale_n
Definition: grouped_flatmm_kernel.hpp:133
index_t stride_C
Definition: grouped_flatmm_kernel.hpp:130
index_t * M_indices
Definition: grouped_flatmm_kernel.hpp:115
Definition: flatmm_kernel.hpp:229
Definition: flatmm_kernel.hpp:249
static constexpr CK_TILE_HOST auto BlockSize()
Definition: flatmm_kernel.hpp:330
remove_cvref_t< typename FlatmmPipeline::BlockGemmShape > BlockGemmShape
Definition: flatmm_kernel.hpp:253
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: flatmm_kernel.hpp:250
remove_cvref_t< typename EpiloguePipeline::DsLayout > DsLayout
Definition: flatmm_kernel.hpp:258
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: flatmm_kernel.hpp:259
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: flatmm_kernel.hpp:254
remove_cvref_t< FlatmmPipeline_ > FlatmmPipeline
Definition: flatmm_kernel.hpp:251
remove_cvref_t< typename FlatmmPipeline::ADataType > ADataType
Definition: flatmm_kernel.hpp:263
remove_cvref_t< typename FlatmmPipeline::BDataType > BDataType
Definition: flatmm_kernel.hpp:264
Definition: flatmm_kernel.hpp:33
Definition: grouped_flatmm_kernel.hpp:19
const void ** b_shuffle_ptr
Definition: grouped_flatmm_kernel.hpp:60
CK_TILE_HOST GroupedFlatmmHostArgs(index_t group_count_, index_t *M_, index_t *N_, index_t *K_, const void **a_ptr_, index_t *stride_A_, const void **b_shuffle_ptr_, index_t *stride_B_, const std::array< const void *, NumDTensor > &ds_ptr_, const std::array< index_t, NumDTensor > &stride_Ds_, void **c_ptr_, index_t *stride_C_, index_t k_batch_, ScaleM *scale_m_=nullptr, ScaleN *scale_n_=nullptr)
Definition: grouped_flatmm_kernel.hpp:21
const std::array< index_t, NumDTensor > stride_Ds
Definition: grouped_flatmm_kernel.hpp:63
const std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_flatmm_kernel.hpp:62
CK_TILE_HOST GroupedFlatmmHostArgs()=default
index_t group_count
Definition: grouped_flatmm_kernel.hpp:54
Definition: grouped_flatmm_kernel.hpp:201
static constexpr index_t NumDTensor
Definition: grouped_flatmm_kernel.hpp:217
static CK_TILE_HOST_DEVICE auto GridSize([[maybe_unused]] const GroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > &kernelArgs)
Definition: grouped_flatmm_kernel.hpp:238
static constexpr index_t kBlockSize
Definition: grouped_flatmm_kernel.hpp:218
static CK_TILE_HOST_DEVICE auto GridSize([[maybe_unused]] const MaskedGroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > &kernelArgs)
Definition: grouped_flatmm_kernel.hpp:304
static CK_TILE_HOST const std::string GetName()
Definition: grouped_flatmm_kernel.hpp:228
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition: grouped_flatmm_kernel.hpp:213
static constexpr CK_TILE_HOST auto MakeKernelArgs(const HostArgs &hostArgs)
Definition: grouped_flatmm_kernel.hpp:336
CK_TILE_DEVICE void operator()(ContiguousGroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > kargs) const
Definition: grouped_flatmm_kernel.hpp:398
CK_TILE_DEVICE void operator()(GroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > kargs) const
Definition: grouped_flatmm_kernel.hpp:354
static CK_TILE_HOST_DEVICE auto GridSize([[maybe_unused]] const ContiguousGroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > &kernelArgs)
Definition: grouped_flatmm_kernel.hpp:269
CK_TILE_DEVICE void operator()(MaskedGroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > kargs) const
Definition: grouped_flatmm_kernel.hpp:436
Definition: grouped_flatmm_kernel.hpp:140
index_t group_count
Definition: grouped_flatmm_kernel.hpp:178
ScaleM scale_m
Definition: grouped_flatmm_kernel.hpp:195
ScaleN scale_n
Definition: grouped_flatmm_kernel.hpp:196
CK_TILE_HOST MaskedGroupedFlatmmHostArgs(index_t *M_indices_, index_t group_count_, index_t Max_M_, index_t N_, index_t K_, const void *a_ptr_, index_t stride_A_, const void *b_shuffle_ptr_, index_t stride_B_, const std::array< const void *, NumDTensor > &ds_ptr_, const std::array< index_t, NumDTensor > &stride_Ds_, void *c_ptr_, index_t stride_C_, index_t k_batch_, ScaleM scale_m_=nullptr, ScaleN scale_n_=nullptr)
Definition: grouped_flatmm_kernel.hpp:142
CK_TILE_HOST MaskedGroupedFlatmmHostArgs()=default
index_t * M_indices
Definition: grouped_flatmm_kernel.hpp:177
const std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_flatmm_kernel.hpp:186
index_t k_batch
Definition: grouped_flatmm_kernel.hpp:194
const void * b_shuffle_ptr
Definition: grouped_flatmm_kernel.hpp:184
index_t stride_C
Definition: grouped_flatmm_kernel.hpp:193
const void * a_ptr
Definition: grouped_flatmm_kernel.hpp:182
index_t stride_A
Definition: grouped_flatmm_kernel.hpp:183
const std::array< index_t, NumDTensor > stride_Ds
Definition: grouped_flatmm_kernel.hpp:187
index_t stride_B
Definition: grouped_flatmm_kernel.hpp:185
Definition: integral_constant.hpp:13