ComposedAttention< VARIANT_CODE, UseExp2 > Struct Template Reference

ComposedAttention< VARIANT_CODE, UseExp2 > Struct Template Reference

Composable Kernel: ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 > Struct Template Reference
ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 > Struct Template Reference

#include <variants.hpp>

Public Member Functions

__device__ __host__ ComposedAttention ()=default
 
template<typename Params , typename T >
__device__ __forceinline__ T QueryTransform (const Params &params, T q) const
 
template<typename Params , typename T >
__device__ __forceinline__ T LogitsTransform (const Params &params, T logits, [[maybe_unused]] uint32_t batch_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
 
template<typename Params >
__device__ __forceinline__ bool LogitsMask (const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
 

Static Public Attributes

static constexpr bool use_exp2 = UseExp2
 
static constexpr bool use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0
 

Constructor & Destructor Documentation

◆ ComposedAttention()

template<uint32_t VARIANT_CODE, bool UseExp2 = false>
__device__ __host__ ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 >::ComposedAttention ( )
default

Member Function Documentation

◆ LogitsMask()

template<uint32_t VARIANT_CODE, bool UseExp2 = false>
template<typename Params >
__device__ __forceinline__ bool ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 >::LogitsMask ( const Params &  params,
[[maybe_unused]] uint32_t  batch_idx,
uint32_t  qo_idx,
uint32_t  kv_idx,
[[maybe_unused]] uint32_t  qo_head_idx,
[[maybe_unused]] uint32_t  kv_head_idx 
) const
inline

◆ LogitsTransform()

template<uint32_t VARIANT_CODE, bool UseExp2 = false>
template<typename Params , typename T >
__device__ __forceinline__ T ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 >::LogitsTransform ( const Params &  params,
T  logits,
[[maybe_unused]] uint32_t  batch_idx,
[[maybe_unused]] uint32_t  qo_head_idx,
[[maybe_unused]] uint32_t  kv_head_idx 
) const
inline

NOTICE: For better performance, we simply transform the thread buffer without calculating qo_idx/kv_idx.

◆ QueryTransform()

template<uint32_t VARIANT_CODE, bool UseExp2 = false>
template<typename Params , typename T >
__device__ __forceinline__ T ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 >::QueryTransform ( const Params &  params,
T  q 
) const
inline

Member Data Documentation

◆ use_exp2

template<uint32_t VARIANT_CODE, bool UseExp2 = false>
constexpr bool ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 >::use_exp2 = UseExp2
static constexpr

◆ use_logits_soft_cap

template<uint32_t VARIANT_CODE, bool UseExp2 = false>
constexpr bool ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 >::use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0
static constexpr

The documentation for this struct was generated from the following file:
  • /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/variants.hpp