include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp Source File#
fmha_bwd_kernel.hpp
Go to the documentation of this file.
29 template <typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_>
96 "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK2) + "x" + _TS_(bfs::kK3) + "x" +
98 "r" + _TS_(gbr0::at(ck_tile::number<0>{})) + "x" + _TS_(gbr0::at(ck_tile::number<1>{})) + "x" + _TS_(gbr0::at(ck_tile::number<2>{})) + "_" +
99 "r" + _TS_(gbr1::at(ck_tile::number<0>{})) + "x" + _TS_(gbr1::at(ck_tile::number<1>{})) + "x" + _TS_(gbr1::at(ck_tile::number<2>{})) + "_" +
100 "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
101 "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
102 "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
104 (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
105 (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) +
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:63
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_view(DataType *p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_view.hpp:424
@ atomic_add
@ ELEMENTWISE_BIAS
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:480
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
@ MASK_FROM_TOP_LEFT
@ FROM_BOTTOM_RIGHT
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_position_encoding.hpp:137
Definition: fmha_bwd_kernel.hpp:1919
ck_tile::index_t batch_stride_dq
Definition: fmha_bwd_kernel.hpp:1920
ck_tile::index_t batch_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1921
Definition: fmha_bwd_kernel.hpp:1895
ck_tile::index_t seqlen_q
Definition: fmha_bwd_kernel.hpp:1899
ck_tile::index_t stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1904
ck_tile::index_t nhead_stride_dq
Definition: fmha_bwd_kernel.hpp:1905
ck_tile::index_t hdim_q
Definition: fmha_bwd_kernel.hpp:1901
ck_tile::index_t stride_dq
Definition: fmha_bwd_kernel.hpp:1903
const void * dq_acc_ptr
Definition: fmha_bwd_kernel.hpp:1896
ck_tile::index_t seqlen_k
Definition: fmha_bwd_kernel.hpp:1900
void * dq_ptr
Definition: fmha_bwd_kernel.hpp:1897
ck_tile::index_t nhead_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1906
Definition: fmha_bwd_kernel.hpp:1910
ck_tile::index_t split_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1911
Definition: fmha_bwd_kernel.hpp:1888
Definition: fmha_bwd_kernel.hpp:1929
const int32_t * seqstart_k_ptr
Definition: fmha_bwd_kernel.hpp:1931
const int32_t * seqstart_q_ptr
Definition: fmha_bwd_kernel.hpp:1930
Definition: fmha_bwd_kernel.hpp:1859
Definition: fmha_bwd_kernel.hpp:1842
static constexpr bool kIsGroupMode
Definition: fmha_bwd_kernel.hpp:1853
static constexpr bool kPadHeadDimQ
Definition: fmha_bwd_kernel.hpp:1855
static constexpr bool kIsDeterministic
Definition: fmha_bwd_kernel.hpp:1856
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_bwd_kernel.hpp:1844
static constexpr bool kPadSeqLenQ
Definition: fmha_bwd_kernel.hpp:1854
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *dq_acc_ptr, void *dq_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t split_stride_dq_acc)
Definition: fmha_bwd_kernel.hpp:1976
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_bwd_kernel.hpp:1845
static constexpr CK_TILE_DEVICE auto GetTileIndex()
Definition: fmha_bwd_kernel.hpp:2014
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *dq_acc_ptr, void *dq_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t batch_stride_dq, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t split_stride_dq_acc)
Definition: fmha_bwd_kernel.hpp:1940
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_bwd_kernel.hpp:2025
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:2027
ck_tile::remove_cvref_t< typename FmhaBwdConvertQGrad::AccDataType > AccDataType
Definition: fmha_bwd_kernel.hpp:1850
static constexpr ck_tile::index_t kM0
Definition: fmha_bwd_kernel.hpp:1846
ck_tile::remove_cvref_t< typename FmhaBwdConvertQGrad::QGradDataType > QGradDataType
Definition: fmha_bwd_kernel.hpp:1851
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
Definition: fmha_bwd_kernel.hpp:2009
static constexpr ck_tile::index_t kN0
Definition: fmha_bwd_kernel.hpp:1847
ck_tile::remove_cvref_t< FmhaBwdConvertQGrad_ > FmhaBwdConvertQGrad
Definition: fmha_bwd_kernel.hpp:1843
std::conditional_t< kIsGroupMode, FmhaBwdConvertQGradGroupModeKargs, FmhaBwdConvertQGradBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:1936
static constexpr ck_tile::index_t kQKHeaddim
Definition: fmha_bwd_kernel.hpp:1848
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_bwd_kernel.hpp:2023
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:1864
Definition: fmha_bwd_kernel.hpp:176
ck_tile::index_t alibi_slope_stride
Definition: fmha_bwd_kernel.hpp:179
const void * alibi_slope_ptr
Definition: fmha_bwd_kernel.hpp:178
Definition: fmha_bwd_kernel.hpp:190
ck_tile::index_t batch_stride_dbias
Definition: fmha_bwd_kernel.hpp:191
Definition: fmha_bwd_kernel.hpp:171
ck_tile::index_t batch_stride_bias
Definition: fmha_bwd_kernel.hpp:172
Definition: fmha_bwd_kernel.hpp:255
ck_tile::index_t batch_stride_randval
Definition: fmha_bwd_kernel.hpp:256
Definition: fmha_bwd_kernel.hpp:275
ck_tile::index_t batch_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:281
ck_tile::index_t batch_stride_v
Definition: fmha_bwd_kernel.hpp:278
ck_tile::index_t batch_stride_k
Definition: fmha_bwd_kernel.hpp:277
ck_tile::index_t batch_stride_lsed
Definition: fmha_bwd_kernel.hpp:280
ck_tile::index_t batch_stride_dv
Definition: fmha_bwd_kernel.hpp:283
ck_tile::index_t batch_stride_dk
Definition: fmha_bwd_kernel.hpp:282
ck_tile::index_t batch_stride_do
Definition: fmha_bwd_kernel.hpp:279
ck_tile::index_t batch_stride_q
Definition: fmha_bwd_kernel.hpp:276
Definition: fmha_bwd_kernel.hpp:183
ck_tile::index_t stride_dbias
Definition: fmha_bwd_kernel.hpp:185
void * dbias_ptr
Definition: fmha_bwd_kernel.hpp:184
ck_tile::index_t nhead_stride_dbias
Definition: fmha_bwd_kernel.hpp:186
Definition: fmha_bwd_kernel.hpp:164
ck_tile::index_t stride_bias
Definition: fmha_bwd_kernel.hpp:166
const void * bias_ptr
Definition: fmha_bwd_kernel.hpp:165
ck_tile::index_t nhead_stride_bias
Definition: fmha_bwd_kernel.hpp:167
Definition: fmha_bwd_kernel.hpp:215
ck_tile::index_t stride_randval
Definition: fmha_bwd_kernel.hpp:250
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
Definition: fmha_bwd_kernel.hpp:216
float scale_rp_undrop
Definition: fmha_bwd_kernel.hpp:246
float rp_undrop
Definition: fmha_bwd_kernel.hpp:245
ck_tile::index_t nhead_stride_randval
Definition: fmha_bwd_kernel.hpp:251
uint8_t p_undrop_in_uint8_t
Definition: fmha_bwd_kernel.hpp:247
void * rand_val_ptr
Definition: fmha_bwd_kernel.hpp:248
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr, float raw_scale)
Definition: fmha_bwd_kernel.hpp:229
Definition: fmha_bwd_kernel.hpp:122
float raw_scale
Definition: fmha_bwd_kernel.hpp:142
void * dq_acc_ptr
Definition: fmha_bwd_kernel.hpp:129
ck_tile::index_t nhead_stride_do
Definition: fmha_bwd_kernel.hpp:156
ck_tile::index_t stride_dk
Definition: fmha_bwd_kernel.hpp:150
ck_tile::index_t hdim_v
Definition: fmha_bwd_kernel.hpp:136
void * dv_ptr
Definition: fmha_bwd_kernel.hpp:131
ck_tile::index_t nhead_stride_dv
Definition: fmha_bwd_kernel.hpp:160
ck_tile::index_t seqlen_q
Definition: fmha_bwd_kernel.hpp:133
ck_tile::index_t nhead_stride_lsed
Definition: fmha_bwd_kernel.hpp:157
ck_tile::index_t nhead_stride_q
Definition: fmha_bwd_kernel.hpp:153
ck_tile::index_t nhead_stride_dk
Definition: fmha_bwd_kernel.hpp:159
ck_tile::index_t stride_v
Definition: fmha_bwd_kernel.hpp:147
ck_tile::index_t stride_dq_acc
Definition: fmha_bwd_kernel.hpp:149
ck_tile::index_t num_head_q
Definition: fmha_bwd_kernel.hpp:140
ck_tile::index_t nhead_stride_k
Definition: fmha_bwd_kernel.hpp:154
ck_tile::index_t nhead_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:158
const void * lse_ptr
Definition: fmha_bwd_kernel.hpp:126
const void * d_ptr
Definition: fmha_bwd_kernel.hpp:128
ck_tile::index_t nhead_ratio_qk
Definition: fmha_bwd_kernel.hpp:141
ck_tile::index_t stride_k
Definition: fmha_bwd_kernel.hpp:146
const void * do_ptr
Definition: fmha_bwd_kernel.hpp:127
float scale
Definition: fmha_bwd_kernel.hpp:143
void * dk_ptr
Definition: fmha_bwd_kernel.hpp:130
const void * v_ptr
Definition: fmha_bwd_kernel.hpp:125
ck_tile::index_t stride_do
Definition: fmha_bwd_kernel.hpp:148
const void * k_ptr
Definition: fmha_bwd_kernel.hpp:124
const void * q_ptr
Definition: fmha_bwd_kernel.hpp:123
ck_tile::index_t nhead_stride_v
Definition: fmha_bwd_kernel.hpp:155
ck_tile::index_t stride_dv
Definition: fmha_bwd_kernel.hpp:151
ck_tile::index_t stride_q
Definition: fmha_bwd_kernel.hpp:145
ck_tile::index_t seqlen_k
Definition: fmha_bwd_kernel.hpp:134
ck_tile::index_t hdim_q
Definition: fmha_bwd_kernel.hpp:135
Definition: fmha_bwd_kernel.hpp:260
ck_tile::index_t split_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:261
Definition: fmha_bwd_kernel.hpp:201
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_bwd_kernel.hpp:209
bool is_drop_seed_offset_from_host
Definition: fmha_bwd_kernel.hpp:211
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_bwd_kernel.hpp:210
Definition: fmha_bwd_kernel.hpp:115
Definition: fmha_bwd_kernel.hpp:297
const int32_t * seqstart_k_ptr
Definition: fmha_bwd_kernel.hpp:299
const int32_t * seqlen_k_ptr
Definition: fmha_bwd_kernel.hpp:300
const int32_t * seqstart_q_ptr
Definition: fmha_bwd_kernel.hpp:298
Definition: fmha_bwd_kernel.hpp:195
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_bwd_kernel.hpp:197
ck_tile::index_t window_size_right
Definition: fmha_bwd_kernel.hpp:196
ck_tile::index_t window_size_left
Definition: fmha_bwd_kernel.hpp:196
Definition: fmha_bwd_kernel.hpp:69
Definition: fmha_bwd_kernel.hpp:31
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:307
static constexpr bool kHasDropout
Definition: fmha_bwd_kernel.hpp:64
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, const std::tuple< const void *, const void * > &drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:965
ck_tile::remove_cvref_t< typename FmhaPipeline::DDataType > DDataType
Definition: fmha_bwd_kernel.hpp:45
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_bwd_kernel.hpp:39
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasGradDataType > BiasGradDataType
Definition: fmha_bwd_kernel.hpp:52
static constexpr bool kHasBiasGrad
Definition: fmha_bwd_kernel.hpp:60
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_bwd_kernel.hpp:61
ck_tile::remove_cvref_t< typename FmhaPipeline::QGradDataType > QGradDataType
Definition: fmha_bwd_kernel.hpp:49
ck_tile::remove_cvref_t< typename FmhaPipeline::OGradDataType > OGradDataType
Definition: fmha_bwd_kernel.hpp:48
static constexpr bool kIsDeterministic
Definition: fmha_bwd_kernel.hpp:66
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, const std::tuple< const void *, const void * > &drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:597
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_bwd_kernel.hpp:36
static constexpr bool kPadSeqLenK
Definition: fmha_bwd_kernel.hpp:56
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_bwd_kernel.hpp:32
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:864
static constexpr bool kPadHeadDimV
Definition: fmha_bwd_kernel.hpp:58
ck_tile::remove_cvref_t< typename FmhaPipeline::KGradDataType > KGradDataType
Definition: fmha_bwd_kernel.hpp:50
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:476
std::conditional_t< kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:303
static constexpr bool kPadHeadDimQ
Definition: fmha_bwd_kernel.hpp:57
static constexpr CK_TILE_DEVICE auto GetTileIndex()
Definition: fmha_bwd_kernel.hpp:1070
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaDropout > FmhaDropout
Definition: fmha_bwd_kernel.hpp:62
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_bwd_kernel.hpp:43
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_bwd_kernel.hpp:38
ck_tile::remove_cvref_t< typename FmhaPipeline::AccDataType > AccDataType
Definition: fmha_bwd_kernel.hpp:44
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:74
ck_tile::remove_cvref_t< KGradEpiloguePipeline_ > KGradEpiloguePipeline
Definition: fmha_bwd_kernel.hpp:33
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:1088
ck_tile::remove_cvref_t< VGradEpiloguePipeline_ > VGradEpiloguePipeline
Definition: fmha_bwd_kernel.hpp:34
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:717
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_bwd_kernel.hpp:1079
static constexpr bool kHasMask
Definition: fmha_bwd_kernel.hpp:63
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
Definition: fmha_bwd_kernel.hpp:1064
static constexpr auto BiasEnum
Definition: fmha_bwd_kernel.hpp:59
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_bwd_kernel.hpp:47
ck_tile::remove_cvref_t< typename FmhaPipeline::GemmDataType > GemmDataType
Definition: fmha_bwd_kernel.hpp:42
static constexpr bool kPadSeqLenQ
Definition: fmha_bwd_kernel.hpp:55
ck_tile::remove_cvref_t< typename FmhaPipeline::VGradDataType > VGradDataType
Definition: fmha_bwd_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_bwd_kernel.hpp:41
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_bwd_kernel.hpp:35
static constexpr bool kIsStoreRandval
Definition: fmha_bwd_kernel.hpp:65
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_bwd_kernel.hpp:40
static constexpr bool kIsGroupMode
Definition: fmha_bwd_kernel.hpp:54
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_bwd_kernel.hpp:1081
Definition: fmha_bwd_kernel.hpp:1652
ck_tile::index_t batch_stride_o
Definition: fmha_bwd_kernel.hpp:1654
ck_tile::index_t batch_stride_do
Definition: fmha_bwd_kernel.hpp:1653
ck_tile::index_t batch_stride_d
Definition: fmha_bwd_kernel.hpp:1655
Definition: fmha_bwd_kernel.hpp:1633
void * d_ptr
Definition: fmha_bwd_kernel.hpp:1636
const void * o_ptr
Definition: fmha_bwd_kernel.hpp:1634
ck_tile::index_t hdim_v
Definition: fmha_bwd_kernel.hpp:1641
ck_tile::index_t nhead_stride_do
Definition: fmha_bwd_kernel.hpp:1646
ck_tile::index_t stride_o
Definition: fmha_bwd_kernel.hpp:1644
ck_tile::index_t nhead_stride_o
Definition: fmha_bwd_kernel.hpp:1647
const void * do_ptr
Definition: fmha_bwd_kernel.hpp:1635
ck_tile::index_t stride_do
Definition: fmha_bwd_kernel.hpp:1643
ck_tile::index_t seqlen_q
Definition: fmha_bwd_kernel.hpp:1640
float p_undrop
Definition: fmha_bwd_kernel.hpp:1638
ck_tile::index_t nhead_stride_d
Definition: fmha_bwd_kernel.hpp:1648
Definition: fmha_bwd_kernel.hpp:1659
const int32_t * seqstart_q_ptr
Definition: fmha_bwd_kernel.hpp:1660
Definition: fmha_bwd_kernel.hpp:1603
Definition: fmha_bwd_kernel.hpp:1587
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::ODataType > ODataType
Definition: fmha_bwd_kernel.hpp:1595
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_bwd_kernel.hpp:1748
ck_tile::remove_cvref_t< FmhaBwdOGradDotO_ > FmhaBwdOGradDotO
Definition: fmha_bwd_kernel.hpp:1588
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_o, ck_tile::index_t batch_stride_d)
Definition: fmha_bwd_kernel.hpp:1668
static constexpr ck_tile::index_t kVHeaddim
Definition: fmha_bwd_kernel.hpp:1592
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:1750
static constexpr bool kIsGroupMode
Definition: fmha_bwd_kernel.hpp:1598
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::OGradDataType > OGradDataType
Definition: fmha_bwd_kernel.hpp:1596
static constexpr ck_tile::index_t kM0
Definition: fmha_bwd_kernel.hpp:1591
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, const void *seqstart_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d)
Definition: fmha_bwd_kernel.hpp:1703
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_bwd_kernel.hpp:1589
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_bwd_kernel.hpp:1590
static constexpr CK_TILE_DEVICE auto GetTileIndex()
Definition: fmha_bwd_kernel.hpp:1737
static constexpr bool kPadSeqLenQ
Definition: fmha_bwd_kernel.hpp:1599
std::conditional_t< kIsGroupMode, FmhaBwdOGradDotOGroupModeKargs, FmhaBwdOGradDotOBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:1664
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:1608
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::DDataType > DDataType
Definition: fmha_bwd_kernel.hpp:1594
static constexpr bool kPadHeadDimV
Definition: fmha_bwd_kernel.hpp:1600
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_bwd_kernel.hpp:1746
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
Definition: fmha_bwd_kernel.hpp:1732
Definition: integral_constant.hpp:13
Definition: coordinate_transform.hpp:1443
Definition: sequence.hpp:52
Definition: fmha_bwd_kernel.hpp:204
const T * ptr
Definition: fmha_bwd_kernel.hpp:206
T val
Definition: fmha_bwd_kernel.hpp:205