/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp Source File#
fmha_fwd_v3_kernel.hpp
Go to the documentation of this file.
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:145
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
@ MASK_FROM_TOP_LEFT
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: fmha_fwd_v3_kernel.hpp:158
ck_tile::index_t qo_head_idx
Definition: fmha_fwd_v3_kernel.hpp:160
ck_tile::index_t kv_head_idx
Definition: fmha_fwd_v3_kernel.hpp:161
ck_tile::index_t batch_idx
Definition: fmha_fwd_v3_kernel.hpp:159
Definition: fmha_fwd_v3_kernel.hpp:126
const ck_tile::index_t * cu_seqlen_k_ptr
Definition: fmha_fwd_v3_kernel.hpp:135
ck_tile::index_t batch_stride_o
Definition: fmha_fwd_v3_kernel.hpp:130
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_v3_kernel.hpp:127
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_v3_kernel.hpp:128
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_v3_kernel.hpp:129
const ck_tile::index_t * cu_seqlen_q_ptr
Definition: fmha_fwd_v3_kernel.hpp:134
Definition: fmha_fwd_v3_kernel.hpp:56
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_v3_kernel.hpp:80
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_v3_kernel.hpp:78
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_v3_kernel.hpp:70
ck_tile::index_t stride_o
Definition: fmha_fwd_v3_kernel.hpp:76
ck_tile::index_t seqlen_k
Definition: fmha_fwd_v3_kernel.hpp:63
ck_tile::index_t nhead_stride_o
Definition: fmha_fwd_v3_kernel.hpp:81
ck_tile::index_t hdim_q
Definition: fmha_fwd_v3_kernel.hpp:64
ck_tile::index_t seqlen_q
Definition: fmha_fwd_v3_kernel.hpp:62
ck_tile::index_t stride_v
Definition: fmha_fwd_v3_kernel.hpp:75
const void * q_ptr
Definition: fmha_fwd_v3_kernel.hpp:57
float scale_s
Definition: fmha_fwd_v3_kernel.hpp:71
const void * v_ptr
Definition: fmha_fwd_v3_kernel.hpp:59
void * o_ptr
Definition: fmha_fwd_v3_kernel.hpp:60
ck_tile::index_t stride_q
Definition: fmha_fwd_v3_kernel.hpp:73
ck_tile::index_t num_head_q
Definition: fmha_fwd_v3_kernel.hpp:67
const void * k_ptr
Definition: fmha_fwd_v3_kernel.hpp:58
ck_tile::index_t hdim_v
Definition: fmha_fwd_v3_kernel.hpp:65
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_v3_kernel.hpp:79
ck_tile::index_t stride_k
Definition: fmha_fwd_v3_kernel.hpp:74
Definition: fmha_fwd_v3_kernel.hpp:93
void * lse_ptr
Definition: fmha_fwd_v3_kernel.hpp:94
ck_tile::index_t nhead_stride_lse
Definition: fmha_fwd_v3_kernel.hpp:95
ck_tile::index_t batch_stride_lse
Definition: fmha_fwd_v3_kernel.hpp:96
Definition: fmha_fwd_v3_kernel.hpp:49
Definition: fmha_fwd_v3_kernel.hpp:143
const int32_t * cu_seqlen_q_ptr
Definition: fmha_fwd_v3_kernel.hpp:151
const int32_t * cu_seqlen_k_ptr
Definition: fmha_fwd_v3_kernel.hpp:152
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_v3_kernel.hpp:147
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_v3_kernel.hpp:145
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_v3_kernel.hpp:144
const int32_t * seqlen_q_ptr
Definition: fmha_fwd_v3_kernel.hpp:146
Definition: fmha_fwd_v3_kernel.hpp:100
float logits_soft_cap_rcp
Definition: fmha_fwd_v3_kernel.hpp:118
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_fwd_v3_kernel.hpp:103
FmhaFwdLogitsSoftCapKargs()=default
float logits_soft_cap
Definition: fmha_fwd_v3_kernel.hpp:117
Definition: fmha_fwd_v3_kernel.hpp:85
ck_tile::index_t window_size_left
Definition: fmha_fwd_v3_kernel.hpp:87
ck_tile::index_t remap_opt
Definition: fmha_fwd_v3_kernel.hpp:89
ck_tile::index_t window_size_right
Definition: fmha_fwd_v3_kernel.hpp:87
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_v3_kernel.hpp:88
Definition: fmha_fwd_v3_kernel.hpp:20
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_v3_kernel.hpp:35
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_v3_kernel.hpp:27
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_v3_kernel.hpp:436
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition: fmha_fwd_v3_kernel.hpp:166
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_v3_kernel.hpp:36
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_fwd_v3_kernel.hpp:42
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_v3_kernel.hpp:431
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_v3_kernel.hpp:23
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_v3_kernel.hpp:31
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_v3_kernel.hpp:28
static constexpr bool kHasLogitsSoftCap
Definition: fmha_fwd_v3_kernel.hpp:39
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_v3_kernel.hpp:38
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_fwd_v3_kernel.hpp:429
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_v3_kernel.hpp:21
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition: fmha_fwd_v3_kernel.hpp:252
static constexpr bool kHasMask
Definition: fmha_fwd_v3_kernel.hpp:44
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_fwd_v3_kernel.hpp:155
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &)
Definition: fmha_fwd_v3_kernel.hpp:389
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_v3_kernel.hpp:22
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_v3_kernel.hpp:37
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_v3_kernel.hpp:29
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v)
Definition: fmha_fwd_v3_kernel.hpp:332
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_v3_kernel.hpp:24
static constexpr bool kStoreLSE
Definition: fmha_fwd_v3_kernel.hpp:40
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_v3_kernel.hpp:32
static constexpr CK_TILE_DEVICE auto RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option)
Definition: fmha_fwd_v3_kernel.hpp:354
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_v3_kernel.hpp:30
static constexpr bool kIsGroupMode
Definition: fmha_fwd_v3_kernel.hpp:34
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_v3_kernel.hpp:43
Definition: variants.hpp:63
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49