/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp Source File
fmha_bwd_kernel.hpp
Go to the documentation of this file.
110 "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK2) + "x" + _TS_(bfs::kK3) + "x" +
112 "r" + _TS_(gbr0::at(ck_tile::number<0>{})) + "x" + _TS_(gbr0::at(ck_tile::number<1>{})) + "x" + _TS_(gbr0::at(ck_tile::number<2>{})) + "_" +
113 "r" + _TS_(gbr1::at(ck_tile::number<0>{})) + "x" + _TS_(gbr1::at(ck_tile::number<1>{})) + "x" + _TS_(gbr1::at(ck_tile::number<2>{})) + "_" +
114 "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
115 "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
116 "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
120 (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
121 (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? gwt0::at(ck_tile::number<0>{}) == 16? "_dropout_wg16":"_dropout_wg32" : "_ndropout" ) +
122 (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload");
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:2834
@ ELEMENTWISE_BIAS
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
@ MASK_FROM_TOP_LEFT
@ FROM_BOTTOM_RIGHT
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
__device__ X atomic_add(X *p_dst, const X &x)
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_position_encoding.hpp:137
Definition: fmha_bwd_kernel.hpp:1511
ck_tile::index_t batch_stride_dq
Definition: fmha_bwd_kernel.hpp:1512
ck_tile::index_t batch_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1513
Definition: fmha_bwd_kernel.hpp:1487
ck_tile::index_t seqlen_q
Definition: fmha_bwd_kernel.hpp:1491
ck_tile::index_t stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1496
ck_tile::index_t nhead_stride_dq
Definition: fmha_bwd_kernel.hpp:1497
ck_tile::index_t hdim_q
Definition: fmha_bwd_kernel.hpp:1493
ck_tile::index_t stride_dq
Definition: fmha_bwd_kernel.hpp:1495
const void * dq_acc_ptr
Definition: fmha_bwd_kernel.hpp:1488
ck_tile::index_t seqlen_k
Definition: fmha_bwd_kernel.hpp:1492
void * dq_ptr
Definition: fmha_bwd_kernel.hpp:1489
ck_tile::index_t nhead_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1498
Definition: fmha_bwd_kernel.hpp:1502
ck_tile::index_t split_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1503
Definition: fmha_bwd_kernel.hpp:1480
Definition: fmha_bwd_kernel.hpp:1521
const int32_t * seqstart_k_ptr
Definition: fmha_bwd_kernel.hpp:1523
const int32_t * seqstart_q_ptr
Definition: fmha_bwd_kernel.hpp:1522
Definition: fmha_bwd_kernel.hpp:1447
Definition: fmha_bwd_kernel.hpp:1430
static constexpr bool kIsGroupMode
Definition: fmha_bwd_kernel.hpp:1441
static constexpr bool kPadHeadDimQ
Definition: fmha_bwd_kernel.hpp:1443
static constexpr bool kIsDeterministic
Definition: fmha_bwd_kernel.hpp:1444
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_bwd_kernel.hpp:1432
static constexpr bool kPadSeqLenQ
Definition: fmha_bwd_kernel.hpp:1442
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *dq_acc_ptr, void *dq_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t split_stride_dq_acc)
Definition: fmha_bwd_kernel.hpp:1568
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_bwd_kernel.hpp:1433
static constexpr CK_TILE_DEVICE auto GetTileIndex()
Definition: fmha_bwd_kernel.hpp:1606
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *dq_acc_ptr, void *dq_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t batch_stride_dq, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t split_stride_dq_acc)
Definition: fmha_bwd_kernel.hpp:1532
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_bwd_kernel.hpp:1617
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:1619
ck_tile::remove_cvref_t< typename FmhaBwdConvertQGrad::AccDataType > AccDataType
Definition: fmha_bwd_kernel.hpp:1438
static constexpr ck_tile::index_t kM0
Definition: fmha_bwd_kernel.hpp:1434
ck_tile::remove_cvref_t< typename FmhaBwdConvertQGrad::QGradDataType > QGradDataType
Definition: fmha_bwd_kernel.hpp:1439
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
Definition: fmha_bwd_kernel.hpp:1601
static constexpr ck_tile::index_t kN0
Definition: fmha_bwd_kernel.hpp:1435
ck_tile::remove_cvref_t< FmhaBwdConvertQGrad_ > FmhaBwdConvertQGrad
Definition: fmha_bwd_kernel.hpp:1431
std::conditional_t< kIsGroupMode, FmhaBwdConvertQGradGroupModeKargs, FmhaBwdConvertQGradBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:1528
static constexpr ck_tile::index_t kQKHeaddim
Definition: fmha_bwd_kernel.hpp:1436
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_bwd_kernel.hpp:1615
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:1453
Definition: fmha_bwd_kernel.hpp:192
const void * alibi_slope_ptr
Definition: fmha_bwd_kernel.hpp:194
ck_tile::index_t alibi_slope_stride
Definition: fmha_bwd_kernel.hpp:195
Definition: fmha_bwd_kernel.hpp:206
ck_tile::index_t batch_stride_dbias
Definition: fmha_bwd_kernel.hpp:207
Definition: fmha_bwd_kernel.hpp:187
ck_tile::index_t batch_stride_bias
Definition: fmha_bwd_kernel.hpp:188
Definition: fmha_bwd_kernel.hpp:271
ck_tile::index_t batch_stride_randval
Definition: fmha_bwd_kernel.hpp:272
Definition: fmha_bwd_kernel.hpp:291
ck_tile::index_t batch_stride_v
Definition: fmha_bwd_kernel.hpp:294
ck_tile::index_t batch_stride_k
Definition: fmha_bwd_kernel.hpp:293
ck_tile::index_t batch_stride_q
Definition: fmha_bwd_kernel.hpp:292
ck_tile::index_t batch_stride_do
Definition: fmha_bwd_kernel.hpp:295
ck_tile::index_t batch_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:297
ck_tile::index_t batch_stride_dk
Definition: fmha_bwd_kernel.hpp:298
ck_tile::index_t batch_stride_dv
Definition: fmha_bwd_kernel.hpp:299
ck_tile::index_t batch_stride_lsed
Definition: fmha_bwd_kernel.hpp:296
Definition: fmha_bwd_kernel.hpp:199
ck_tile::index_t nhead_stride_dbias
Definition: fmha_bwd_kernel.hpp:202
void * dbias_ptr
Definition: fmha_bwd_kernel.hpp:200
ck_tile::index_t stride_dbias
Definition: fmha_bwd_kernel.hpp:201
Definition: fmha_bwd_kernel.hpp:180
ck_tile::index_t stride_bias
Definition: fmha_bwd_kernel.hpp:182
ck_tile::index_t nhead_stride_bias
Definition: fmha_bwd_kernel.hpp:183
const void * bias_ptr
Definition: fmha_bwd_kernel.hpp:181
Definition: fmha_bwd_kernel.hpp:231
uint8_t p_undrop_in_uint8_t
Definition: fmha_bwd_kernel.hpp:263
float rp_undrop
Definition: fmha_bwd_kernel.hpp:261
ck_tile::index_t nhead_stride_randval
Definition: fmha_bwd_kernel.hpp:267
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr, float raw_scale)
Definition: fmha_bwd_kernel.hpp:245
float scale_rp_undrop
Definition: fmha_bwd_kernel.hpp:262
void * rand_val_ptr
Definition: fmha_bwd_kernel.hpp:264
ck_tile::index_t stride_randval
Definition: fmha_bwd_kernel.hpp:266
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
Definition: fmha_bwd_kernel.hpp:232
Definition: fmha_bwd_kernel.hpp:138
ck_tile::index_t nhead_stride_dk
Definition: fmha_bwd_kernel.hpp:175
ck_tile::index_t stride_do
Definition: fmha_bwd_kernel.hpp:164
ck_tile::index_t seqlen_k
Definition: fmha_bwd_kernel.hpp:150
const void * q_ptr
Definition: fmha_bwd_kernel.hpp:139
ck_tile::index_t hdim_q
Definition: fmha_bwd_kernel.hpp:151
ck_tile::index_t nhead_stride_do
Definition: fmha_bwd_kernel.hpp:172
ck_tile::index_t num_head_q
Definition: fmha_bwd_kernel.hpp:156
const void * lse_ptr
Definition: fmha_bwd_kernel.hpp:142
float raw_scale
Definition: fmha_bwd_kernel.hpp:158
ck_tile::index_t nhead_stride_k
Definition: fmha_bwd_kernel.hpp:170
ck_tile::index_t nhead_stride_q
Definition: fmha_bwd_kernel.hpp:169
ck_tile::index_t stride_dv
Definition: fmha_bwd_kernel.hpp:167
ck_tile::index_t nhead_stride_lsed
Definition: fmha_bwd_kernel.hpp:173
void * dq_acc_ptr
Definition: fmha_bwd_kernel.hpp:145
ck_tile::index_t stride_q
Definition: fmha_bwd_kernel.hpp:161
ck_tile::index_t seqlen_q
Definition: fmha_bwd_kernel.hpp:149
ck_tile::index_t stride_dk
Definition: fmha_bwd_kernel.hpp:166
const void * do_ptr
Definition: fmha_bwd_kernel.hpp:143
float scale
Definition: fmha_bwd_kernel.hpp:159
void * dk_ptr
Definition: fmha_bwd_kernel.hpp:146
ck_tile::index_t nhead_stride_v
Definition: fmha_bwd_kernel.hpp:171
ck_tile::index_t stride_v
Definition: fmha_bwd_kernel.hpp:163
const void * d_ptr
Definition: fmha_bwd_kernel.hpp:144
const void * k_ptr
Definition: fmha_bwd_kernel.hpp:140
ck_tile::index_t nhead_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:174
ck_tile::index_t nhead_ratio_qk
Definition: fmha_bwd_kernel.hpp:157
void * dv_ptr
Definition: fmha_bwd_kernel.hpp:147
const void * v_ptr
Definition: fmha_bwd_kernel.hpp:141
ck_tile::index_t hdim_v
Definition: fmha_bwd_kernel.hpp:152
ck_tile::index_t stride_k
Definition: fmha_bwd_kernel.hpp:162
ck_tile::index_t nhead_stride_dv
Definition: fmha_bwd_kernel.hpp:176
ck_tile::index_t stride_dq_acc
Definition: fmha_bwd_kernel.hpp:165
Definition: fmha_bwd_kernel.hpp:276
ck_tile::index_t split_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:277
Definition: fmha_bwd_kernel.hpp:217
bool is_drop_seed_offset_from_host
Definition: fmha_bwd_kernel.hpp:227
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_bwd_kernel.hpp:225
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_bwd_kernel.hpp:226
Definition: fmha_bwd_kernel.hpp:131
Definition: fmha_bwd_kernel.hpp:313
const int32_t * seqstart_k_ptr
Definition: fmha_bwd_kernel.hpp:315
const int32_t * seqstart_q_ptr
Definition: fmha_bwd_kernel.hpp:314
const int32_t * seqlen_k_ptr
Definition: fmha_bwd_kernel.hpp:316
Definition: fmha_bwd_kernel.hpp:211
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_bwd_kernel.hpp:213
ck_tile::index_t window_size_right
Definition: fmha_bwd_kernel.hpp:212
ck_tile::index_t window_size_left
Definition: fmha_bwd_kernel.hpp:212
Definition: fmha_bwd_kernel.hpp:84
Definition: fmha_bwd_kernel.hpp:35
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:509
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_bwd_kernel.hpp:671
ck_tile::remove_cvref_t< KGradEpiloguePipeline_ > KGradEpiloguePipeline
Definition: fmha_bwd_kernel.hpp:37
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:341
static constexpr auto BiasEnum
Definition: fmha_bwd_kernel.hpp:66
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:90
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaDropout > FmhaDropout
Definition: fmha_bwd_kernel.hpp:69
ck_tile::remove_cvref_t< typename FmhaPipeline::OGradDataType > OGradDataType
Definition: fmha_bwd_kernel.hpp:57
static constexpr bool kUseQrQtrDorPipeline
Definition: fmha_bwd_kernel.hpp:42
ck_tile::remove_cvref_t< VGradEpiloguePipeline_ > VGradEpiloguePipeline
Definition: fmha_bwd_kernel.hpp:38
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasGradDataType > BiasGradDataType
Definition: fmha_bwd_kernel.hpp:61
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_bwd_kernel.hpp:673
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_bwd_kernel.hpp:40
static constexpr bool kHasMask
Definition: fmha_bwd_kernel.hpp:70
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_bwd_kernel.hpp:48
static constexpr CK_TILE_HOST Kargs MakeKargs(Ts... args, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:324
static constexpr bool kIsGroupMode
Definition: fmha_bwd_kernel.hpp:63
ck_tile::remove_cvref_t< typename FmhaPipeline::QGradDataType > QGradDataType
Definition: fmha_bwd_kernel.hpp:58
static constexpr CK_TILE_DEVICE auto GetTileIndex()
Definition: fmha_bwd_kernel.hpp:662
ck_tile::remove_cvref_t< typename FmhaPipeline::DDataType > DDataType
Definition: fmha_bwd_kernel.hpp:54
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_bwd_kernel.hpp:47
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_bwd_kernel.hpp:52
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
Definition: fmha_bwd_kernel.hpp:654
ck_tile::remove_cvref_t< typename FmhaPipeline::VGradDataType > VGradDataType
Definition: fmha_bwd_kernel.hpp:60
ck_tile::remove_cvref_t< typename FmhaPipeline::GemmDataType > GemmDataType
Definition: fmha_bwd_kernel.hpp:51
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:680
static constexpr bool kIsAvailable
Definition: fmha_bwd_kernel.hpp:80
ck_tile::remove_cvref_t< QGradEpiloguePipeline_ > QGradEpiloguePipeline
Definition: fmha_bwd_kernel.hpp:39
static constexpr bool kHasDropout
Definition: fmha_bwd_kernel.hpp:71
ck_tile::remove_cvref_t< typename FmhaPipeline::AccDataType > AccDataType
Definition: fmha_bwd_kernel.hpp:53
CK_TILE_DEVICE void run_(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:686
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_bwd_kernel.hpp:49
static constexpr bool kHasBiasGrad
Definition: fmha_bwd_kernel.hpp:67
static constexpr bool kIsDeterministic
Definition: fmha_bwd_kernel.hpp:73
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_bwd_kernel.hpp:50
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_bwd_kernel.hpp:56
static constexpr index_t kPadHeadDimQ
Definition: fmha_bwd_kernel.hpp:64
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_bwd_kernel.hpp:36
static constexpr CK_TILE_HOST Kargs MakeKargs(Ts... args, const std::tuple< const void *, const void * > &drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:333
ck_tile::remove_cvref_t< typename FmhaPipeline::KGradDataType > KGradDataType
Definition: fmha_bwd_kernel.hpp:59
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_bwd_kernel.hpp:68
static constexpr bool kUseTrLoad
Definition: fmha_bwd_kernel.hpp:74
static constexpr index_t kPadHeadDimV
Definition: fmha_bwd_kernel.hpp:65
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_bwd_kernel.hpp:41
std::conditional_t< kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:319
static constexpr bool kIsStoreRandval
Definition: fmha_bwd_kernel.hpp:72
static constexpr index_t kMaxSeqLenQ
Definition: fmha_bwd_kernel.hpp:75
Definition: fmha_bwd_kernel.hpp:1240
ck_tile::index_t batch_stride_o
Definition: fmha_bwd_kernel.hpp:1242
ck_tile::index_t batch_stride_do
Definition: fmha_bwd_kernel.hpp:1241
ck_tile::index_t batch_stride_d
Definition: fmha_bwd_kernel.hpp:1243
Definition: fmha_bwd_kernel.hpp:1221
void * d_ptr
Definition: fmha_bwd_kernel.hpp:1224
const void * o_ptr
Definition: fmha_bwd_kernel.hpp:1222
ck_tile::index_t hdim_v
Definition: fmha_bwd_kernel.hpp:1229
ck_tile::index_t nhead_stride_do
Definition: fmha_bwd_kernel.hpp:1234
ck_tile::index_t stride_o
Definition: fmha_bwd_kernel.hpp:1232
ck_tile::index_t nhead_stride_o
Definition: fmha_bwd_kernel.hpp:1235
const void * do_ptr
Definition: fmha_bwd_kernel.hpp:1223
ck_tile::index_t stride_do
Definition: fmha_bwd_kernel.hpp:1231
ck_tile::index_t seqlen_q
Definition: fmha_bwd_kernel.hpp:1228
float p_undrop
Definition: fmha_bwd_kernel.hpp:1226
ck_tile::index_t nhead_stride_d
Definition: fmha_bwd_kernel.hpp:1236
Definition: fmha_bwd_kernel.hpp:1247
const int32_t * seqstart_q_ptr
Definition: fmha_bwd_kernel.hpp:1248
Definition: fmha_bwd_kernel.hpp:1190
Definition: fmha_bwd_kernel.hpp:1174
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::ODataType > ODataType
Definition: fmha_bwd_kernel.hpp:1182
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_bwd_kernel.hpp:1336
ck_tile::remove_cvref_t< FmhaBwdOGradDotO_ > FmhaBwdOGradDotO
Definition: fmha_bwd_kernel.hpp:1175
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_o, ck_tile::index_t batch_stride_d)
Definition: fmha_bwd_kernel.hpp:1256
static constexpr ck_tile::index_t kVHeaddim
Definition: fmha_bwd_kernel.hpp:1179
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:1338
static constexpr bool kIsGroupMode
Definition: fmha_bwd_kernel.hpp:1185
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::OGradDataType > OGradDataType
Definition: fmha_bwd_kernel.hpp:1183
static constexpr ck_tile::index_t kM0
Definition: fmha_bwd_kernel.hpp:1178
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, const void *seqstart_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d)
Definition: fmha_bwd_kernel.hpp:1291
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_bwd_kernel.hpp:1176
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_bwd_kernel.hpp:1177
static constexpr CK_TILE_DEVICE auto GetTileIndex()
Definition: fmha_bwd_kernel.hpp:1325
static constexpr bool kPadSeqLenQ
Definition: fmha_bwd_kernel.hpp:1186
std::conditional_t< kIsGroupMode, FmhaBwdOGradDotOGroupModeKargs, FmhaBwdOGradDotOBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:1252
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:1196
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::DDataType > DDataType
Definition: fmha_bwd_kernel.hpp:1181
static constexpr bool kPadHeadDimV
Definition: fmha_bwd_kernel.hpp:1187
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_bwd_kernel.hpp:1334
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
Definition: fmha_bwd_kernel.hpp:1320
Definition: integral_constant.hpp:13
Definition: block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:777
Definition: coordinate_transform.hpp:1392
Definition: sequence.hpp:49
T val
Definition: fmha_batch_prefill_kernel.hpp:223
const T * ptr
Definition: fmha_batch_prefill_kernel.hpp:224
Definition: fmha_bwd_kernel.hpp:220
T val
Definition: fmha_bwd_kernel.hpp:221
const T * ptr
Definition: fmha_bwd_kernel.hpp:222