FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference
Classes |
Public Types |
Public Member Functions |
Static Public Member Functions |
Static Public Attributes |
List of all members
ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference
#include <fmha_fwd_kernel.hpp>
Classes | |
struct | BlockIndices |
struct | FmhaFwdAlibiKargs |
struct | FmhaFwdBatchModeBiasKargs |
struct | FmhaFwdBatchModeDropoutKargs |
struct | FmhaFwdBatchModeKargs |
struct | FmhaFwdCommonBiasKargs |
struct | FmhaFwdCommonDropoutKargs |
struct | FmhaFwdCommonKargs |
struct | FmhaFwdCommonLSEKargs |
struct | FmhaFwdDropoutSeedOffset |
struct | FmhaFwdEmptyKargs |
struct | FmhaFwdFp8StaticQuantKargs |
struct | FmhaFwdGroupModeKargs |
struct | FmhaFwdLogitsSoftCapKargs |
struct | FmhaFwdMaskKargs |
struct | FmhaFwdSkipMinSeqlenQKargs |
struct | t2s |
struct | t2s< ck_tile::bf16_t > |
struct | t2s< ck_tile::bf8_t > |
struct | t2s< ck_tile::fp16_t > |
struct | t2s< ck_tile::fp8_t > |
struct | t2s< float > |
Public Types | |
using | FmhaPipeline = ck_tile::remove_cvref_t< FmhaPipeline_ > |
using | EpiloguePipeline = ck_tile::remove_cvref_t< EpiloguePipeline_ > |
using | QDataType = ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > |
using | KDataType = ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > |
using | VDataType = ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > |
using | BiasDataType = ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > |
using | RandValOutputDataType = ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > |
using | LSEDataType = ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > |
using | ODataType = ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > |
using | SaccDataType = ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > |
using | VLayout = ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > |
using | AttentionVariant = ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > |
using | FmhaMask = ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > |
using | Kargs = std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > |
Public Member Functions | |
CK_TILE_DEVICE void | operator() (Kargs kargs) const |
CK_TILE_DEVICE void | run_ (Kargs kargs) const |
Static Public Member Functions | |
static CK_TILE_HOST std::string | GetName () |
template<bool Cond = !kIsGroupMode> | |
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > | MakeKargsImpl (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset) |
template<bool Cond = !kIsGroupMode> | |
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > | MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset) |
template<bool Cond = !kIsGroupMode> | |
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > | MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset) |
template<bool Cond = kIsGroupMode> | |
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > | MakeKargsImpl (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset) |
template<bool Cond = kIsGroupMode> | |
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > | MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset) |
template<bool Cond = kIsGroupMode> | |
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > | MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset) |
static constexpr CK_TILE_HOST auto | GridSize (ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k=false) |
static constexpr CK_TILE_DEVICE auto | GetTileIndex (const Kargs &kargs) |
static constexpr CK_TILE_HOST auto | BlockSize () |
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t | GetSmemSize () |
Static Public Attributes | |
static constexpr ck_tile::index_t | kBlockSize = FmhaPipeline::kBlockSize |
static constexpr ck_tile::index_t | kBlockPerCu = FmhaPipeline::kBlockPerCu |
static constexpr ck_tile::index_t | kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu |
static constexpr bool | kIsGroupMode = FmhaPipeline::kIsGroupMode |
static constexpr bool | kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ |
static constexpr bool | kPadSeqLenK = FmhaPipeline::kPadSeqLenK |
static constexpr bool | kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ |
static constexpr bool | kPadHeadDimV = FmhaPipeline::kPadHeadDimV |
static constexpr bool | kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap |
static constexpr auto | BiasEnum = FmhaPipeline::BiasEnum |
static constexpr bool | kStoreLSE = FmhaPipeline::kStoreLSE |
static constexpr bool | kHasDropout = FmhaPipeline::kHasDropout |
static constexpr bool | kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant |
static constexpr bool | kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ |
static constexpr bool | kHasMask = FmhaMask::IsMasking |
static constexpr bool | kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy |
static constexpr bool | kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad |
static constexpr bool | kIsAvailable = !kUseTrLoad |
static constexpr std::string_view | kPipelineName = FmhaPipeline::name |
Member Typedef Documentation
◆ AttentionVariant
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant> |
◆ BiasDataType
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType> |
◆ EpiloguePipeline
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_> |
◆ FmhaMask
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask> |
◆ FmhaPipeline
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_> |
◆ Kargs
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs> |
◆ KDataType
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType> |
◆ LSEDataType
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType> |
◆ ODataType
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType> |
◆ QDataType
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType> |
◆ RandValOutputDataType
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::RandValOutputDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType> |
◆ SaccDataType
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType> |
◆ VDataType
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType> |
◆ VLayout
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout> |
Member Function Documentation
◆ BlockSize()
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
inline static constexpr |
◆ GetName()
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
inline static |
◆ GetSmemSize()
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
inline static constexpr |
◆ GetTileIndex()
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
inline static constexpr |
◆ GridSize()
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
inline static constexpr |
◆ MakeKargs() [1/4]
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
template<bool Cond = !kIsGroupMode>
|
inline static constexpr |
◆ MakeKargs() [2/4]
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
template<bool Cond = !kIsGroupMode>
|
inline static constexpr |
◆ MakeKargs() [3/4]
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
template<bool Cond = kIsGroupMode>
|
inline static constexpr |
◆ MakeKargs() [4/4]
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
template<bool Cond = kIsGroupMode>
|
inline static constexpr |
◆ MakeKargsImpl() [1/2]
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
template<bool Cond = !kIsGroupMode>
|
inline static constexpr |
◆ MakeKargsImpl() [2/2]
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
template<bool Cond = kIsGroupMode>
|
inline static constexpr |
◆ operator()()
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
inline |
◆ run_()
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
inline |
FIXME: Before C++20, capturing structured binding variables in a lambda is not supported. Remove the following copy capture of 'i_nhead' once the code is built as C++20.
FIXME: Before C++20, capturing structured binding variables in a lambda is not supported. Remove the following copy capture of 'i_nhead' once the code is built as C++20.
Member Data Documentation
◆ BiasEnum
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kBlockPerCu
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kBlockPerCuInput
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kBlockSize
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kDoFp8StaticQuant
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kHasDropout
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kHasLogitsSoftCap
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kHasMask
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kIsAvailable
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kIsGroupMode
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kPadHeadDimQ
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kPadHeadDimV
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kPadSeqLenK
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kPadSeqLenQ
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kPipelineName
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kSkipMinSeqlenQ
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kStoreLSE
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kUseAsyncCopy
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
◆ kUseTrLoad
template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
|
static constexpr |
The documentation for this struct was generated from the following file:
- /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp