#include <variants.hpp>
|
__device__ __host__ | ComposedAttention ()=default |
|
template<typename Params , typename T > |
__device__ __forceinline__ T | QueryTransform (const Params &params, T q) const |
|
template<typename Params , typename T > |
__device__ __forceinline__ T | LogitsTransform (const Params &params, T logits, [[maybe_unused]] uint32_t batch_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const |
|
template<typename Params > |
__device__ __forceinline__ bool | LogitsMask (const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const |
|
◆ ComposedAttention()
template<uint32_t VARIANT_CODE, bool UseExp2 = false>
◆ LogitsMask()
template<uint32_t VARIANT_CODE, bool UseExp2 = false>
template<typename Params >
◆ LogitsTransform()
template<uint32_t VARIANT_CODE, bool UseExp2 = false>
template<typename Params , typename T >
__device__ __forceinline__ T ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 >::LogitsTransform |
( |
const Params & |
params, |
|
|
T |
logits, |
|
|
[[maybe_unused]] uint32_t |
batch_idx, |
|
|
[[maybe_unused]] uint32_t |
qo_head_idx, |
|
|
[[maybe_unused]] uint32_t |
kv_head_idx |
|
) |
| const |
|
inline |
NOTICE: For better performance, we simply transform the thread buffer without calculating qo_idx/kv_idx.
◆ QueryTransform()
template<uint32_t VARIANT_CODE, bool UseExp2 = false>
template<typename Params , typename T >
__device__ __forceinline__ T ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 >::QueryTransform |
( |
const Params & |
params, |
|
|
T |
q |
|
) |
| const |
|
inline |
◆ use_exp2
template<uint32_t VARIANT_CODE, bool UseExp2 = false>
◆ use_logits_soft_cap
template<uint32_t VARIANT_CODE, bool UseExp2 = false>
The documentation for this struct was generated from the following file:
- /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/variants.hpp