include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File
fmha_fwd_kernel.hpp
Go to the documentation of this file.
89 "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
90 "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
91 "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
92 "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
93 (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
94 "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
95 (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
96 (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:63
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
@ ELEMENTWISE_BIAS
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:461
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:480
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1615
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
@ MASK_FROM_TOP_LEFT
@ FROM_BOTTOM_RIGHT
__host__ __device__ composes(Ts &&...) -> composes< remove_cvref_t< Ts >... >
FIXME: create a macro that replaces '__host__ __device__' and nothing more.
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_dropout.hpp:26
Definition: block_position_encoding.hpp:137
Definition: fmha_fwd_kernel.hpp:153
ck_tile::index_t alibi_slope_stride
Definition: fmha_fwd_kernel.hpp:156
const void * alibi_slope_ptr
Definition: fmha_fwd_kernel.hpp:155
Definition: fmha_fwd_kernel.hpp:148
ck_tile::index_t batch_stride_bias
Definition: fmha_fwd_kernel.hpp:149
Definition: fmha_fwd_kernel.hpp:229
ck_tile::index_t batch_stride_randval
Definition: fmha_fwd_kernel.hpp:230
Definition: fmha_fwd_kernel.hpp:244
ck_tile::index_t batch_stride_o
Definition: fmha_fwd_kernel.hpp:248
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_kernel.hpp:245
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_kernel.hpp:246
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_kernel.hpp:247
Definition: fmha_fwd_kernel.hpp:141
const void * bias_ptr
Definition: fmha_fwd_kernel.hpp:142
ck_tile::index_t stride_bias
Definition: fmha_fwd_kernel.hpp:143
ck_tile::index_t nhead_stride_bias
Definition: fmha_fwd_kernel.hpp:144
Definition: fmha_fwd_kernel.hpp:194
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr)
Definition: fmha_fwd_kernel.hpp:207
float rp_undrop
Definition: fmha_fwd_kernel.hpp:219
ck_tile::index_t stride_randval
Definition: fmha_fwd_kernel.hpp:224
ck_tile::index_t nhead_stride_randval
Definition: fmha_fwd_kernel.hpp:225
void * rand_val_ptr
Definition: fmha_fwd_kernel.hpp:222
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
Definition: fmha_fwd_kernel.hpp:195
bool is_store_randval
Definition: fmha_fwd_kernel.hpp:221
uint8_t p_undrop_in_uint8_t
Definition: fmha_fwd_kernel.hpp:220
Definition: fmha_fwd_kernel.hpp:112
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_kernel.hpp:135
ck_tile::index_t seqlen_k
Definition: fmha_fwd_kernel.hpp:119
ck_tile::index_t nhead_stride_o
Definition: fmha_fwd_kernel.hpp:137
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_kernel.hpp:126
ck_tile::index_t num_head_q
Definition: fmha_fwd_kernel.hpp:123
ck_tile::index_t hdim_q
Definition: fmha_fwd_kernel.hpp:120
const void * v_ptr
Definition: fmha_fwd_kernel.hpp:115
const void * k_ptr
Definition: fmha_fwd_kernel.hpp:114
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_kernel.hpp:134
ck_tile::index_t stride_k
Definition: fmha_fwd_kernel.hpp:130
ck_tile::index_t stride_o
Definition: fmha_fwd_kernel.hpp:132
ck_tile::index_t stride_v
Definition: fmha_fwd_kernel.hpp:131
ck_tile::index_t hdim_v
Definition: fmha_fwd_kernel.hpp:121
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_kernel.hpp:136
const void * q_ptr
Definition: fmha_fwd_kernel.hpp:113
ck_tile::index_t seqlen_q
Definition: fmha_fwd_kernel.hpp:118
ck_tile::index_t stride_q
Definition: fmha_fwd_kernel.hpp:129
Definition: fmha_fwd_kernel.hpp:173
ck_tile::index_t batch_stride_lse
Definition: fmha_fwd_kernel.hpp:176
void * lse_ptr
Definition: fmha_fwd_kernel.hpp:174
ck_tile::index_t nhead_stride_lse
Definition: fmha_fwd_kernel.hpp:175
Definition: fmha_fwd_kernel.hpp:180
bool is_drop_seed_offset_from_host
Definition: fmha_fwd_kernel.hpp:190
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_fwd_kernel.hpp:188
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_fwd_kernel.hpp:189
Definition: fmha_fwd_kernel.hpp:105
Definition: fmha_fwd_kernel.hpp:167
float scale_o
Definition: fmha_fwd_kernel.hpp:169
float scale_p
Definition: fmha_fwd_kernel.hpp:168
Definition: fmha_fwd_kernel.hpp:262
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_kernel.hpp:263
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_kernel.hpp:265
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_kernel.hpp:264
Definition: fmha_fwd_kernel.hpp:160
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_kernel.hpp:163
ck_tile::index_t window_size_right
Definition: fmha_fwd_kernel.hpp:162
ck_tile::index_t window_size_left
Definition: fmha_fwd_kernel.hpp:162
Definition: fmha_fwd_kernel.hpp:58
Definition: fmha_fwd_kernel.hpp:25
static constexpr bool kHasDropout
Definition: fmha_fwd_kernel.hpp:52
static CK_TILE_HOST std::string GetName()
Definition: fmha_fwd_kernel.hpp:66
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:711
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:272
static constexpr bool kDoFp8StaticQuant
Definition: fmha_fwd_kernel.hpp:53
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:790
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:406
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_kernel.hpp:34
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_fwd_kernel.hpp:268
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:497
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_kernel.hpp:29
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_kernel.hpp:40
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_kernel.hpp:43
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:587
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_kernel.hpp:28
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_fwd_kernel.hpp:36
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k=false)
Definition: fmha_fwd_kernel.hpp:866
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_fwd_kernel.hpp:940
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_kernel.hpp:35
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_kernel.hpp:31
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_kernel.hpp:49
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_kernel.hpp:891
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_kernel.hpp:39
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_kernel.hpp:33
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_kernel.hpp:942
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_kernel.hpp:26
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_kernel.hpp:48
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_kernel.hpp:41
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_kernel.hpp:46
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_kernel.hpp:54
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_fwd_kernel.hpp:38
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_kernel.hpp:27
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_kernel.hpp:47
static constexpr bool kIsGroupMode
Definition: fmha_fwd_kernel.hpp:45
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_kernel.hpp:947
Definition: block_dropout.hpp:12
Definition: integral_constant.hpp:13
Definition: functional.hpp:62
Definition: coordinate_transform.hpp:1443
Definition: unary_element_function.hpp:56
Definition: math.hpp:28
Definition: sequence.hpp:52
const T * ptr
Definition: fmha_bwd_kernel.hpp:206
T val
Definition: fmha_bwd_kernel.hpp:205
Definition: fmha_fwd_kernel.hpp:183
T val
Definition: fmha_fwd_kernel.hpp:184
const T * ptr
Definition: fmha_fwd_kernel.hpp:185