11 #define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 0
12 #define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 1
14 #ifndef CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT
15 #define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
18 #ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM
19 #define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0
24 __device__
inline float
27 #if(defined(__gfx90a__) || defined(__gfx94__)) && \
28 (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
29 CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
31 float result, numerator, denominator;
33 "v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n"
34 "v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n"
35 "v_rcp_f32_e32 %[denominator], %[denominator]\n"
36 "v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n"
37 "v_mul_f32_e32 %[result], %[numerator], %[denominator]"
38 : [numerator]
"=&v"(numerator), [denominator]
"=&v"(denominator), [result]
"=v"(result)
39 : [softmax_scale]
"s"(softmax_scale),
41 [logits_soft_cap_rcp]
"v"(logits_soft_cap_rcp));
44 return softmax_scale * logits * rcp<float>(1.f +
abs(logits * logits_soft_cap_rcp));
49 template <
typename ImplMask>
61 template <
typename ImplMask,
bool UseExp2 = false>
101 if constexpr(UseExp2)
110 float logits_soft_cap_,
111 float logits_soft_cap_rcp_)
119 if constexpr(UseExp2)
136 template <
typename Params,
typename T>
139 return type_convert<float>(q) * params.sm_scale;
144 template <
typename Params,
typename T>
147 [[maybe_unused]]
uint32_t batch_idx,
149 [[maybe_unused]]
uint32_t qo_head_idx,
150 [[maybe_unused]]
uint32_t kv_head_idx)
const
155 template <
typename Params>
156 __device__ __forceinline__
bool LogitsMask(
const Params& params,
157 [[maybe_unused]]
uint32_t batch_idx,
160 [[maybe_unused]]
uint32_t qo_head_idx,
161 [[maybe_unused]]
uint32_t kv_head_idx)
const
163 return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
167 template <
bool UseExp2 = false>
172 template <
typename Params,
typename T>
175 if constexpr(UseExp2)
181 return type_convert<float>(q) * params.sm_scale;
187 template <
typename Params,
typename T>
190 [[maybe_unused]]
uint32_t batch_idx,
192 [[maybe_unused]]
uint32_t qo_head_idx,
193 [[maybe_unused]]
uint32_t kv_head_idx)
const
195 if constexpr(UseExp2)
197 #if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
198 return params.logits_soft_cap *
200 #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
202 params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
207 #if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
208 return params.logits_soft_cap *
209 tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
210 #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
211 return type_convert<float>(logits) *
212 rcp<float>(1.f +
abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
217 template <
typename Params>
218 __device__ __forceinline__
bool LogitsMask(
const Params& params,
219 [[maybe_unused]]
uint32_t batch_idx,
222 [[maybe_unused]]
uint32_t qo_head_idx,
223 [[maybe_unused]]
uint32_t kv_head_idx)
const
225 return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
234 template <uint32_t VARIANT_CODE,
bool UseExp2 = false>
243 template <
typename Params,
typename T>
250 return type_convert<float>(q) * params.sm_scale;
255 template <
typename Params,
typename T>
258 [[maybe_unused]]
uint32_t batch_idx,
260 [[maybe_unused]]
uint32_t qo_head_idx,
261 [[maybe_unused]]
uint32_t kv_head_idx)
const
265 if constexpr(UseExp2)
267 #if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
268 return params.logits_soft_cap *
270 #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
272 params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
277 #if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
278 return params.logits_soft_cap *
279 tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
280 #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
281 return type_convert<float>(logits) *
283 abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
290 template <
typename Params>
291 __device__ __forceinline__
bool LogitsMask(
const Params& params,
292 [[maybe_unused]]
uint32_t batch_idx,
295 [[maybe_unused]]
uint32_t qo_head_idx,
296 [[maybe_unused]]
uint32_t kv_head_idx)
const
298 return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
__device__ float exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp)
Definition: variants.hpp:25
Definition: cluster_descriptor.hpp:13
constexpr uint32_t ALIBI
Definition: variants.hpp:232
CK_TILE_DEVICE float tanh_fast< float >(float x)
Definition: math.hpp:1394
constexpr uint32_t LOGITS_SOFT_CAP
Definition: variants.hpp:231
constexpr uint32_t CUSTOM_MASK
Definition: variants.hpp:229
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:404
constexpr uint32_t SLIDING_WINDOW
Definition: variants.hpp:230
Definition: allocators.h:423
unsigned int uint32_t
Definition: stdint.h:126
Definition: variants.hpp:236
__device__ __forceinline__ bool LogitsMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:291
__device__ __host__ ComposedAttention()=default
static constexpr bool use_exp2
Definition: variants.hpp:237
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition: variants.hpp:244
__device__ __forceinline__ T LogitsTransform(const Params &params, T logits, [[maybe_unused]] uint32_t batch_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:256
static constexpr bool use_logits_soft_cap
Definition: variants.hpp:239
Definition: variants.hpp:169
__device__ __host__ LogitsSoftCap()=default
__device__ __forceinline__ T LogitsTransform(const Params &params, T logits, [[maybe_unused]] uint32_t batch_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:188
__device__ __forceinline__ bool LogitsMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:218
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition: variants.hpp:173
Definition: variants.hpp:63
float logits_soft_cap_rcp
Definition: variants.hpp:129
__host__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_)
Definition: variants.hpp:87
const ImplMask & impl_mask
Definition: variants.hpp:126
__device__ __host__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_, float logits_soft_cap_rcp_)
Definition: variants.hpp:108
__device__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_)
Definition: variants.hpp:65
float sm_scale
Definition: variants.hpp:127
float logits_soft_cap
Definition: variants.hpp:128
Definition: variants.hpp:133
__device__ __forceinline__ bool LogitsMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:156
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition: variants.hpp:137
__device__ __host__ StandardAttention()=default
__device__ __forceinline__ T LogitsTransform([[maybe_unused]] const Params &params, T logits, [[maybe_unused]] uint32_t batch_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:145
Definition: variants.hpp:51
const ImplMask & impl_mask
Definition: variants.hpp:57
__device__ __host__ StandardAttentionParams(const ImplMask &impl_mask_, float sm_scale_)
Definition: variants.hpp:52
float sm_scale
Definition: variants.hpp:58