/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp Source File#
fmha_fwd_appendkv_kernel.hpp
Go to the documentation of this file.
61 "b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" +
62 _TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
63 "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
64 + (!kApplyRoPE ? _SS_("") : (_SS_("_") + RotaryEmbeddingEnumToStr<FmhaPipeline::RotaryEnum>::name))
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition: page_block_navigator.hpp:333
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:511
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: fmha_fwd_appendkv_kernel.hpp:81
ck_tile::index_t stride_q
Definition: fmha_fwd_appendkv_kernel.hpp:101
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:88
const void * knew_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:84
ck_tile::index_t stride_k
Definition: fmha_fwd_appendkv_kernel.hpp:102
ck_tile::index_t batch_stride_knew
Definition: fmha_fwd_appendkv_kernel.hpp:115
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_appendkv_kernel.hpp:116
ck_tile::index_t nhead_stride_knew
Definition: fmha_fwd_appendkv_kernel.hpp:109
ck_tile::index_t nhead_stride_vnew
Definition: fmha_fwd_appendkv_kernel.hpp:111
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_appendkv_kernel.hpp:114
ck_tile::index_t stride_knew
Definition: fmha_fwd_appendkv_kernel.hpp:103
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_appendkv_kernel.hpp:107
ck_tile::index_t hdim_q
Definition: fmha_fwd_appendkv_kernel.hpp:93
ck_tile::index_t stride_v
Definition: fmha_fwd_appendkv_kernel.hpp:104
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_appendkv_kernel.hpp:113
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_appendkv_kernel.hpp:108
ck_tile::index_t hdim_v
Definition: fmha_fwd_appendkv_kernel.hpp:94
ck_tile::index_t batch_stride_vnew
Definition: fmha_fwd_appendkv_kernel.hpp:117
void * q_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:82
ck_tile::index_t stride_vnew
Definition: fmha_fwd_appendkv_kernel.hpp:105
void * v_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:85
const void * vnew_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:86
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_appendkv_kernel.hpp:110
void * k_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:83
ck_tile::index_t seqlen_k
Definition: fmha_fwd_appendkv_kernel.hpp:91
ck_tile::index_t seqlen_q
Definition: fmha_fwd_appendkv_kernel.hpp:90
ck_tile::index_t seqlen_knew
Definition: fmha_fwd_appendkv_kernel.hpp:92
ck_tile::index_t num_head_q
Definition: fmha_fwd_appendkv_kernel.hpp:96
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_appendkv_kernel.hpp:99
Definition: fmha_fwd_appendkv_kernel.hpp:136
const int32_t * cache_batch_idx
Definition: fmha_fwd_appendkv_kernel.hpp:137
Definition: fmha_fwd_appendkv_kernel.hpp:74
Definition: fmha_fwd_appendkv_kernel.hpp:143
Definition: fmha_fwd_appendkv_kernel.hpp:129
ck_tile::index_t batch_stride_block_table
Definition: fmha_fwd_appendkv_kernel.hpp:131
const int32_t * block_table_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:130
ck_tile::index_t page_block_size
Definition: fmha_fwd_appendkv_kernel.hpp:132
Definition: fmha_fwd_appendkv_kernel.hpp:121
ck_tile::index_t rotary_dim
Definition: fmha_fwd_appendkv_kernel.hpp:124
const void * rotary_sin_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:123
bool has_mask
Definition: fmha_fwd_appendkv_kernel.hpp:125
const void * rotary_cos_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:122
Definition: fmha_fwd_appendkv_kernel.hpp:37
Definition: fmha_fwd_appendkv_kernel.hpp:15
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_appendkv_kernel.hpp:34
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_appendkv_kernel.hpp:21
static constexpr __host__ auto BlockSize()
Definition: fmha_fwd_appendkv_kernel.hpp:258
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_appendkv_kernel.hpp:16
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_appendkv_kernel.hpp:23
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_appendkv_kernel.hpp:24
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_knew)
Definition: fmha_fwd_appendkv_kernel.hpp:237
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_appendkv_kernel.hpp:31
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_appendkv_kernel.hpp:260
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_appendkv_kernel.hpp:27
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_appendkv_kernel.hpp:18
static constexpr __host__ Kargs MakeKargs(void *q_ptr, void *k_ptr, const void *knew_ptr, void *v_ptr, const void *vnew_ptr, ck_tile::index_t seqlen_q, const void *seqlen_k_ptr, ck_tile::index_t seqlen_knew, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *rotary_cos_ptr, const void *rotary_sin_ptr, ck_tile::index_t rotary_dim, bool has_mask, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_knew, ck_tile::index_t stride_v, ck_tile::index_t stride_vnew, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_knew, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_vnew, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_knew, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_vnew)
Definition: fmha_fwd_appendkv_kernel.hpp:146
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_appendkv_kernel.hpp:25
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_appendkv_kernel.hpp:33
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_appendkv_kernel.hpp:32
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &)
Definition: fmha_fwd_appendkv_kernel.hpp:249
static constexpr bool kIsPagedKV
Definition: fmha_fwd_appendkv_kernel.hpp:29
static constexpr bool kApplyRoPE
Definition: fmha_fwd_appendkv_kernel.hpp:28
static __host__ std::string GetName()
Definition: fmha_fwd_appendkv_kernel.hpp:45
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_appendkv_kernel.hpp:17
Definition: block_rotary_embedding.hpp:19
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49