FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference

FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference

Composable Kernel: ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference
ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference

#include <fmha_fwd_kernel.hpp>

Classes

struct  BlockIndices
 
struct  FmhaFwdAlibiKargs
 
struct  FmhaFwdBatchModeBiasKargs
 
struct  FmhaFwdBatchModeDropoutKargs
 
struct  FmhaFwdBatchModeKargs
 
struct  FmhaFwdCommonBiasKargs
 
struct  FmhaFwdCommonDropoutKargs
 
struct  FmhaFwdCommonKargs
 
struct  FmhaFwdCommonLSEKargs
 
struct  FmhaFwdDropoutSeedOffset
 
struct  FmhaFwdEmptyKargs
 
struct  FmhaFwdFp8StaticQuantKargs
 
struct  FmhaFwdGroupModeKargs
 
struct  FmhaFwdLogitsSoftCapKargs
 
struct  FmhaFwdMaskKargs
 
struct  FmhaFwdSkipMinSeqlenQKargs
 
struct  t2s
 
struct  t2s< ck_tile::bf16_t >
 
struct  t2s< ck_tile::bf8_t >
 
struct  t2s< ck_tile::fp16_t >
 
struct  t2s< ck_tile::fp8_t >
 
struct  t2s< float >
 

Public Types

using FmhaPipeline = ck_tile::remove_cvref_t< FmhaPipeline_ >
 
using EpiloguePipeline = ck_tile::remove_cvref_t< EpiloguePipeline_ >
 
using QDataType = ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType >
 
using KDataType = ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType >
 
using VDataType = ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType >
 
using BiasDataType = ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType >
 
using RandValOutputDataType = ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType >
 
using LSEDataType = ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType >
 
using ODataType = ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType >
 
using SaccDataType = ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType >
 
using VLayout = ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout >
 
using AttentionVariant = ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant >
 
using FmhaMask = ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask >
 
using Kargs = std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs >
 

Public Member Functions

CK_TILE_DEVICE void operator() (Kargs kargs) const
 
CK_TILE_DEVICE void run_ (Kargs kargs) const
 

Static Public Member Functions

static CK_TILE_HOST std::string GetName ()
 
template<bool Cond = !kIsGroupMode>
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
 
template<bool Cond = !kIsGroupMode>
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
 
template<bool Cond = !kIsGroupMode>
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset)
 
template<bool Cond = kIsGroupMode>
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
 
template<bool Cond = kIsGroupMode>
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
 
template<bool Cond = kIsGroupMode>
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset)
 
static constexpr CK_TILE_HOST auto GridSize (ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k=false)
 
static constexpr CK_TILE_DEVICE auto GetTileIndex (const Kargs &kargs)
 
static constexpr CK_TILE_HOST auto BlockSize ()
 
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize ()
 

Static Public Attributes

static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize
 
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu
 
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu
 
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode
 
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ
 
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK
 
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ
 
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV
 
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap
 
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum
 
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE
 
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout
 
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant
 
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ
 
static constexpr bool kHasMask = FmhaMask::IsMasking
 
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy
 
static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad
 
static constexpr bool kIsAvailable = !kUseTrLoad
 
static constexpr std::string_view kPipelineName = FmhaPipeline::name
 

Member Typedef Documentation

◆ AttentionVariant

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>

◆ BiasDataType

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>

◆ EpiloguePipeline

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>

◆ FmhaMask

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>

◆ FmhaPipeline

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>

◆ Kargs

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>

◆ KDataType

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>

◆ LSEDataType

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>

◆ ODataType

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>

◆ QDataType

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>

◆ RandValOutputDataType

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::RandValOutputDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>

◆ SaccDataType

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>

◆ VDataType

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>

◆ VLayout

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>

Member Function Documentation

◆ BlockSize()

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
static constexpr CK_TILE_HOST auto ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::BlockSize ( )
inline static constexpr

◆ GetName()

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST std::string ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::GetName ( )
inline static

◆ GetSmemSize()

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::GetSmemSize ( )
inline static constexpr

◆ GetTileIndex()

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
static constexpr CK_TILE_DEVICE auto ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::GetTileIndex ( const Kargs &  kargs)
inline static constexpr

◆ GridSize()

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
static constexpr CK_TILE_HOST auto ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::GridSize ( ck_tile::index_t  batch_size_,
ck_tile::index_t  nhead_,
ck_tile::index_t  seqlen_q_,
ck_tile::index_t  hdim_v_,
bool  has_padded_seqlen_k = false 
)
inline static constexpr

◆ MakeKargs() [1/4]

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
template<bool Cond = !kIsGroupMode>
static constexpr CK_TILE_HOST std::enable_if_t<Cond, Kargs> ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs ( const void *  q_ptr,
const void *  k_ptr,
const void *  v_ptr,
const void *  bias_ptr,
void *  rand_val_ptr,
void *  lse_ptr,
void *  o_ptr,
ck_tile::index_t  seqlen_q,
ck_tile::index_t  seqlen_k,
ck_tile::index_t  hdim_q,
ck_tile::index_t  hdim_v,
ck_tile::index_t  num_head_q,
ck_tile::index_t  nhead_ratio_qk,
float  scale_s,
float  scale_p,
float  scale_o,
float  logits_soft_cap,
ck_tile::index_t  stride_q,
ck_tile::index_t  stride_k,
ck_tile::index_t  stride_v,
ck_tile::index_t  stride_bias,
ck_tile::index_t  stride_randval,
ck_tile::index_t  stride_o,
ck_tile::index_t  nhead_stride_q,
ck_tile::index_t  nhead_stride_k,
ck_tile::index_t  nhead_stride_v,
ck_tile::index_t  nhead_stride_bias,
ck_tile::index_t  nhead_stride_randval,
ck_tile::index_t  nhead_stride_lse,
ck_tile::index_t  nhead_stride_o,
ck_tile::index_t  batch_stride_q,
ck_tile::index_t  batch_stride_k,
ck_tile::index_t  batch_stride_v,
ck_tile::index_t  batch_stride_bias,
ck_tile::index_t  batch_stride_randval,
ck_tile::index_t  batch_stride_lse,
ck_tile::index_t  batch_stride_o,
ck_tile::index_t  window_size_left,
ck_tile::index_t  window_size_right,
ck_tile::index_t  mask_type,
float  p_drop,
bool  s_randval,
const std::tuple< const void *, const void * > &  drop_seed_offset 
)
inline static constexpr

◆ MakeKargs() [2/4]

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
template<bool Cond = !kIsGroupMode>
static constexpr CK_TILE_HOST std::enable_if_t<Cond, Kargs> ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs ( const void *  q_ptr,
const void *  k_ptr,
const void *  v_ptr,
const void *  bias_ptr,
void *  rand_val_ptr,
void *  lse_ptr,
void *  o_ptr,
ck_tile::index_t  seqlen_q,
ck_tile::index_t  seqlen_k,
ck_tile::index_t  hdim_q,
ck_tile::index_t  hdim_v,
ck_tile::index_t  num_head_q,
ck_tile::index_t  nhead_ratio_qk,
float  scale_s,
float  scale_p,
float  scale_o,
float  logits_soft_cap,
ck_tile::index_t  stride_q,
ck_tile::index_t  stride_k,
ck_tile::index_t  stride_v,
ck_tile::index_t  stride_bias,
ck_tile::index_t  stride_randval,
ck_tile::index_t  stride_o,
ck_tile::index_t  nhead_stride_q,
ck_tile::index_t  nhead_stride_k,
ck_tile::index_t  nhead_stride_v,
ck_tile::index_t  nhead_stride_bias,
ck_tile::index_t  nhead_stride_randval,
ck_tile::index_t  nhead_stride_lse,
ck_tile::index_t  nhead_stride_o,
ck_tile::index_t  batch_stride_q,
ck_tile::index_t  batch_stride_k,
ck_tile::index_t  batch_stride_v,
ck_tile::index_t  batch_stride_bias,
ck_tile::index_t  batch_stride_randval,
ck_tile::index_t  batch_stride_lse,
ck_tile::index_t  batch_stride_o,
ck_tile::index_t  window_size_left,
ck_tile::index_t  window_size_right,
ck_tile::index_t  mask_type,
float  p_drop,
bool  s_randval,
const std::tuple< uint64_t, uint64_t > &  drop_seed_offset 
)
inline static constexpr

◆ MakeKargs() [3/4]

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
template<bool Cond = kIsGroupMode>
static constexpr CK_TILE_HOST std::enable_if_t<Cond, Kargs> ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs ( const void *  q_ptr,
const void *  k_ptr,
const void *  v_ptr,
const void *  bias_ptr,
void *  rand_val_ptr,
void *  lse_ptr,
void *  o_ptr,
const void *  seqstart_q_ptr,
const void *  seqstart_k_ptr,
const void *  seqlen_k_ptr,
ck_tile::index_t  hdim_q,
ck_tile::index_t  hdim_v,
ck_tile::index_t  num_head_q,
ck_tile::index_t  nhead_ratio_qk,
float  scale_s,
float  scale_p,
float  scale_o,
float  logits_soft_cap,
ck_tile::index_t  stride_q,
ck_tile::index_t  stride_k,
ck_tile::index_t  stride_v,
ck_tile::index_t  stride_bias,
ck_tile::index_t  stride_randval,
ck_tile::index_t  stride_o,
ck_tile::index_t  nhead_stride_q,
ck_tile::index_t  nhead_stride_k,
ck_tile::index_t  nhead_stride_v,
ck_tile::index_t  nhead_stride_bias,
ck_tile::index_t  nhead_stride_randval,
ck_tile::index_t  nhead_stride_lse,
ck_tile::index_t  nhead_stride_o,
ck_tile::index_t  window_size_left,
ck_tile::index_t  window_size_right,
ck_tile::index_t  mask_type,
ck_tile::index_t  min_seqlen_q,
float  p_drop,
bool  s_randval,
const std::tuple< const void *, const void * > &  drop_seed_offset 
)
inline static constexpr

◆ MakeKargs() [4/4]

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
template<bool Cond = kIsGroupMode>
static constexpr CK_TILE_HOST std::enable_if_t<Cond, Kargs> ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs ( const void *  q_ptr,
const void *  k_ptr,
const void *  v_ptr,
const void *  bias_ptr,
void *  rand_val_ptr,
void *  lse_ptr,
void *  o_ptr,
const void *  seqstart_q_ptr,
const void *  seqstart_k_ptr,
const void *  seqlen_k_ptr,
ck_tile::index_t  hdim_q,
ck_tile::index_t  hdim_v,
ck_tile::index_t  num_head_q,
ck_tile::index_t  nhead_ratio_qk,
float  scale_s,
float  scale_p,
float  scale_o,
float  logits_soft_cap,
ck_tile::index_t  stride_q,
ck_tile::index_t  stride_k,
ck_tile::index_t  stride_v,
ck_tile::index_t  stride_bias,
ck_tile::index_t  stride_randval,
ck_tile::index_t  stride_o,
ck_tile::index_t  nhead_stride_q,
ck_tile::index_t  nhead_stride_k,
ck_tile::index_t  nhead_stride_v,
ck_tile::index_t  nhead_stride_bias,
ck_tile::index_t  nhead_stride_randval,
ck_tile::index_t  nhead_stride_lse,
ck_tile::index_t  nhead_stride_o,
ck_tile::index_t  window_size_left,
ck_tile::index_t  window_size_right,
ck_tile::index_t  mask_type,
ck_tile::index_t  min_seqlen_q,
float  p_drop,
bool  s_randval,
const std::tuple< uint64_t, uint64_t > &  drop_seed_offset 
)
inline static constexpr

◆ MakeKargsImpl() [1/2]

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
template<bool Cond = !kIsGroupMode>
static constexpr CK_TILE_HOST std::enable_if_t<Cond, Kargs> ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargsImpl ( const void *  q_ptr,
const void *  k_ptr,
const void *  v_ptr,
const void *  bias_ptr,
void *  rand_val_ptr,
void *  lse_ptr,
void *  o_ptr,
ck_tile::index_t  seqlen_q,
ck_tile::index_t  seqlen_k,
ck_tile::index_t  hdim_q,
ck_tile::index_t  hdim_v,
ck_tile::index_t  num_head_q,
ck_tile::index_t  nhead_ratio_qk,
float  scale_s,
float  scale_p,
float  scale_o,
float  logits_soft_cap,
ck_tile::index_t  stride_q,
ck_tile::index_t  stride_k,
ck_tile::index_t  stride_v,
ck_tile::index_t  stride_bias,
ck_tile::index_t  stride_randval,
ck_tile::index_t  stride_o,
ck_tile::index_t  nhead_stride_q,
ck_tile::index_t  nhead_stride_k,
ck_tile::index_t  nhead_stride_v,
ck_tile::index_t  nhead_stride_bias,
ck_tile::index_t  nhead_stride_randval,
ck_tile::index_t  nhead_stride_lse,
ck_tile::index_t  nhead_stride_o,
ck_tile::index_t  batch_stride_q,
ck_tile::index_t  batch_stride_k,
ck_tile::index_t  batch_stride_v,
ck_tile::index_t  batch_stride_bias,
ck_tile::index_t  batch_stride_randval,
ck_tile::index_t  batch_stride_lse,
ck_tile::index_t  batch_stride_o,
ck_tile::index_t  window_size_left,
ck_tile::index_t  window_size_right,
ck_tile::index_t  mask_type,
float  p_drop,
bool  s_randval,
std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >>  drop_seed_offset 
)
inline static constexpr

◆ MakeKargsImpl() [2/2]

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
template<bool Cond = kIsGroupMode>
static constexpr CK_TILE_HOST std::enable_if_t<Cond, Kargs> ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargsImpl ( const void *  q_ptr,
const void *  k_ptr,
const void *  v_ptr,
const void *  bias_ptr,
void *  rand_val_ptr,
void *  lse_ptr,
void *  o_ptr,
const void *  seqstart_q_ptr,
const void *  seqstart_k_ptr,
const void *  seqlen_k_ptr,
ck_tile::index_t  hdim_q,
ck_tile::index_t  hdim_v,
ck_tile::index_t  num_head_q,
ck_tile::index_t  nhead_ratio_qk,
float  scale_s,
float  scale_p,
float  scale_o,
float  logits_soft_cap,
ck_tile::index_t  stride_q,
ck_tile::index_t  stride_k,
ck_tile::index_t  stride_v,
ck_tile::index_t  stride_bias,
ck_tile::index_t  stride_randval,
ck_tile::index_t  stride_o,
ck_tile::index_t  nhead_stride_q,
ck_tile::index_t  nhead_stride_k,
ck_tile::index_t  nhead_stride_v,
ck_tile::index_t  nhead_stride_bias,
ck_tile::index_t  nhead_stride_randval,
ck_tile::index_t  nhead_stride_lse,
ck_tile::index_t  nhead_stride_o,
ck_tile::index_t  window_size_left,
ck_tile::index_t  window_size_right,
ck_tile::index_t  mask_type,
ck_tile::index_t  min_seqlen_q,
float  p_drop,
bool  s_randval,
std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >>  drop_seed_offset 
)
inline static constexpr

◆ operator()()

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
CK_TILE_DEVICE void ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::operator() ( Kargs  kargs) const
inline

◆ run_()

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
CK_TILE_DEVICE void ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::run_ ( Kargs  kargs) const
inline

FIXME: Before C++20, capturing structured binding variables in a lambda is not supported. Remove the following copy capture of 'i_nhead' once C++20 can be assumed.

FIXME: Before C++20, capturing structured binding variables in a lambda is not supported. Remove the following copy capture of 'i_nhead' once C++20 can be assumed.

Member Data Documentation

◆ BiasEnum

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr auto ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::BiasEnum = FmhaPipeline::BiasEnum
static constexpr

◆ kBlockPerCu

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr ck_tile::index_t ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockPerCu = FmhaPipeline::kBlockPerCu
static constexpr

◆ kBlockPerCuInput

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr ck_tile::index_t ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu
static constexpr

◆ kBlockSize

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr ck_tile::index_t ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockSize = FmhaPipeline::kBlockSize
static constexpr

◆ kDoFp8StaticQuant

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant
static constexpr

◆ kHasDropout

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kHasDropout = FmhaPipeline::kHasDropout
static constexpr

◆ kHasLogitsSoftCap

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap
static constexpr

◆ kHasMask

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kHasMask = FmhaMask::IsMasking
static constexpr

◆ kIsAvailable

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kIsAvailable = !kUseTrLoad
static constexpr

◆ kIsGroupMode

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kIsGroupMode = FmhaPipeline::kIsGroupMode
static constexpr

◆ kPadHeadDimQ

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ
static constexpr

◆ kPadHeadDimV

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kPadHeadDimV = FmhaPipeline::kPadHeadDimV
static constexpr

◆ kPadSeqLenK

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kPadSeqLenK = FmhaPipeline::kPadSeqLenK
static constexpr

◆ kPadSeqLenQ

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ
static constexpr

◆ kPipelineName

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr std::string_view ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kPipelineName = FmhaPipeline::name
static constexpr

◆ kSkipMinSeqlenQ

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ
static constexpr

◆ kStoreLSE

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kStoreLSE = FmhaPipeline::kStoreLSE
static constexpr

◆ kUseAsyncCopy

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy
static constexpr

◆ kUseTrLoad

template<typename FmhaPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad
static constexpr

The documentation for this struct was generated from the following file:
  • /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp