/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp Source File
fmha_fwd_splitkv_combine_kernel.hpp
Go to the documentation of this file.
53 _SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" + _SS_(t2s<ODataType>::name) +
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:511
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1622
constexpr CK_TILE_HOST_DEVICE auto integer_divide_floor(X x, Y y)
Definition: math.hpp:143
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
__host__ __device__ composes(Ts &&...) -> composes< remove_cvref_t< Ts >... >
FIXME: create macro to replace '__host__ __device__' and nothing more.
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: fmha_fwd_splitkv_combine_kernel.hpp:113
ck_tile::index_t batch_stride_lse_acc
Definition: fmha_fwd_splitkv_combine_kernel.hpp:114
ck_tile::index_t batch_stride_o
Definition: fmha_fwd_splitkv_combine_kernel.hpp:116
ck_tile::index_t batch_stride_o_acc
Definition: fmha_fwd_splitkv_combine_kernel.hpp:115
Definition: fmha_fwd_splitkv_combine_kernel.hpp:76
ck_tile::index_t row_stride_o
Definition: fmha_fwd_splitkv_combine_kernel.hpp:87
ck_tile::index_t nhead_stride_lse_acc
Definition: fmha_fwd_splitkv_combine_kernel.hpp:89
ck_tile::index_t nhead_stride_o
Definition: fmha_fwd_splitkv_combine_kernel.hpp:91
ck_tile::index_t nhead_stride_o_acc
Definition: fmha_fwd_splitkv_combine_kernel.hpp:90
ck_tile::index_t hdim_v
Definition: fmha_fwd_splitkv_combine_kernel.hpp:83
ck_tile::index_t row_stride_o_acc
Definition: fmha_fwd_splitkv_combine_kernel.hpp:86
void * o_ptr
Definition: fmha_fwd_splitkv_combine_kernel.hpp:79
ck_tile::index_t num_splits
Definition: fmha_fwd_splitkv_combine_kernel.hpp:84
const void * lse_acc_ptr
Definition: fmha_fwd_splitkv_combine_kernel.hpp:77
ck_tile::index_t seqlen_q
Definition: fmha_fwd_splitkv_combine_kernel.hpp:82
ck_tile::index_t split_stride_lse_acc
Definition: fmha_fwd_splitkv_combine_kernel.hpp:93
ck_tile::index_t batch
Definition: fmha_fwd_splitkv_combine_kernel.hpp:81
const void * o_acc_ptr
Definition: fmha_fwd_splitkv_combine_kernel.hpp:78
ck_tile::index_t split_stride_o_acc
Definition: fmha_fwd_splitkv_combine_kernel.hpp:94
Definition: fmha_fwd_splitkv_combine_kernel.hpp:98
ck_tile::index_t batch_stride_lse
Definition: fmha_fwd_splitkv_combine_kernel.hpp:101
ck_tile::index_t nhead_stride_lse
Definition: fmha_fwd_splitkv_combine_kernel.hpp:100
void * lse_ptr
Definition: fmha_fwd_splitkv_combine_kernel.hpp:99
Definition: fmha_fwd_splitkv_combine_kernel.hpp:69
Definition: fmha_fwd_splitkv_combine_kernel.hpp:105
float scale_o
Definition: fmha_fwd_splitkv_combine_kernel.hpp:106
Definition: fmha_fwd_splitkv_combine_kernel.hpp:123
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_splitkv_combine_kernel.hpp:124
Definition: fmha_fwd_splitkv_combine_kernel.hpp:32
Definition: fmha_fwd_splitkv_combine_kernel.hpp:10
static constexpr bool kStoreLSE
Definition: fmha_fwd_splitkv_combine_kernel.hpp:28
remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_splitkv_combine_kernel.hpp:23
static constexpr __host__ std::enable_if_t< Cond, Kargs > MakeKargs(const void *lse_acc_ptr, const void *o_acc_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t batch, const void *seqstart_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t num_splits, float scale_o, ck_tile::index_t row_stride_o_acc, ck_tile::index_t row_stride_o, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc)
Definition: fmha_fwd_splitkv_combine_kernel.hpp:189
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_splitkv_combine_kernel.hpp:276
static constexpr __host__ std::enable_if_t< Cond, Kargs > MakeKargs(const void *lse_acc_ptr, const void *o_acc_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t batch, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits, float scale_o, ck_tile::index_t row_stride_o_acc, ck_tile::index_t row_stride_o, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc)
Definition: fmha_fwd_splitkv_combine_kernel.hpp:131
static constexpr index_t kBlockPerCuInput
Definition: fmha_fwd_splitkv_combine_kernel.hpp:19
static constexpr bool kIsGroupMode
Definition: fmha_fwd_splitkv_combine_kernel.hpp:25
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v)
Definition: fmha_fwd_splitkv_combine_kernel.hpp:238
static constexpr __host__ auto BlockSize()
Definition: fmha_fwd_splitkv_combine_kernel.hpp:269
static constexpr index_t kNumWarps
Definition: fmha_fwd_splitkv_combine_kernel.hpp:14
remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_splitkv_combine_kernel.hpp:21
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_splitkv_combine_kernel.hpp:26
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_splitkv_combine_kernel.hpp:250
remove_cvref_t< typename FmhaPipeline::OaccDataType > OaccDataType
Definition: fmha_fwd_splitkv_combine_kernel.hpp:22
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_splitkv_combine_kernel.hpp:27
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_splitkv_combine_kernel.hpp:12
static constexpr bool kDoFp8StaticQuant
Definition: fmha_fwd_splitkv_combine_kernel.hpp:29
static __host__ std::string GetName()
Definition: fmha_fwd_splitkv_combine_kernel.hpp:40
static constexpr index_t kBlockSize
Definition: fmha_fwd_splitkv_combine_kernel.hpp:15
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_splitkv_combine_kernel.hpp:271
static constexpr index_t kBlockPerCu
Definition: fmha_fwd_splitkv_combine_kernel.hpp:16
remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_splitkv_combine_kernel.hpp:11
std::conditional_t< kIsGroupMode, GroupModeKargs, BatchModeKargs > Kargs
Definition: fmha_fwd_splitkv_combine_kernel.hpp:127
Definition: integral_constant.hpp:13
Definition: functional.hpp:86
Definition: unary_element_function.hpp:56
Definition: math.hpp:28
Definition: sequence.hpp:49