11 #define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 0
12 #define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 1
14 #ifndef CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT
15 #define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
18 #ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM
19 #define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0
24 __device__
inline float
27 #if(defined(__gfx90a__) || defined(__gfx94__)) && \
28 (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
29 CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
31 float result, numerator, denominator;
33 "v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n"
34 "v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n"
35 "v_rcp_f32_e32 %[denominator], %[denominator]\n"
36 "v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n"
37 "v_mul_f32_e32 %[result], %[numerator], %[denominator]"
38 : [numerator]
"=&v"(numerator), [denominator]
"=&v"(denominator), [result]
"=v"(result)
39 : [softmax_scale]
"s"(softmax_scale),
41 [logits_soft_cap_rcp]
"v"(logits_soft_cap_rcp));
44 return softmax_scale * logits * rcp<float>(1.f +
abs(logits * logits_soft_cap_rcp));
49 template <
typename ImplMask>
61 template <
typename ImplMask,
bool UseExp2 = false>
101 if constexpr(UseExp2)
110 float logits_soft_cap_,
111 float logits_soft_cap_rcp_)
119 if constexpr(UseExp2)
136 template <
typename Params,
typename T>
139 return type_convert<float>(q) * params.sm_scale;
144 template <
typename Params,
typename T>
147 [[maybe_unused]]
uint32_t batch_idx,
149 [[maybe_unused]]
uint32_t qo_head_idx,
150 [[maybe_unused]]
uint32_t kv_head_idx)
const
155 template <
typename Params>
156 __device__ __forceinline__
bool LogitsMask(
const Params& params,
157 [[maybe_unused]]
uint32_t batch_idx,
160 [[maybe_unused]]
uint32_t qo_head_idx,
161 [[maybe_unused]]
uint32_t kv_head_idx)
const
163 return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
166 template <
typename Params>
168 [[maybe_unused]]
uint32_t batch_idx,
171 [[maybe_unused]]
uint32_t qo_head_idx,
172 [[maybe_unused]]
uint32_t kv_head_idx)
const
174 return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
178 template <
bool UseExp2 = false>
183 template <
typename Params,
typename T>
186 if constexpr(UseExp2)
192 return type_convert<float>(q) * params.sm_scale;
198 template <
typename Params,
typename T>
201 [[maybe_unused]]
uint32_t batch_idx,
203 [[maybe_unused]]
uint32_t qo_head_idx,
204 [[maybe_unused]]
uint32_t kv_head_idx)
const
206 if constexpr(UseExp2)
208 #if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
209 return params.logits_soft_cap *
211 #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
213 params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
218 #if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
219 return params.logits_soft_cap *
220 tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
221 #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
222 return type_convert<float>(logits) *
223 rcp<float>(1.f +
abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
228 template <
typename Params>
229 __device__ __forceinline__
bool LogitsMask(
const Params& params,
230 [[maybe_unused]]
uint32_t batch_idx,
233 [[maybe_unused]]
uint32_t qo_head_idx,
234 [[maybe_unused]]
uint32_t kv_head_idx)
const
236 return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
239 template <
typename Params>
241 [[maybe_unused]]
uint32_t batch_idx,
244 [[maybe_unused]]
uint32_t qo_head_idx,
245 [[maybe_unused]]
uint32_t kv_head_idx)
const
247 return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
template <uint32_t VARIANT_CODE,
bool UseExp2 = false>
265 template <
typename Params,
typename T>
272 return type_convert<float>(q) * params.sm_scale;
277 template <
typename Params,
typename T>
280 [[maybe_unused]]
uint32_t batch_idx,
282 [[maybe_unused]]
uint32_t qo_head_idx,
283 [[maybe_unused]]
uint32_t kv_head_idx)
const
287 if constexpr(UseExp2)
289 #if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
290 return params.logits_soft_cap *
292 #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
294 params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
299 #if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
300 return params.logits_soft_cap *
301 tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
302 #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
303 return type_convert<float>(logits) *
305 abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
312 template <
typename Params>
313 __device__ __forceinline__
bool LogitsMask(
const Params& params,
314 [[maybe_unused]]
uint32_t batch_idx,
317 [[maybe_unused]]
uint32_t qo_head_idx,
318 [[maybe_unused]]
uint32_t kv_head_idx)
const
320 return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
323 template <
typename Params>
325 [[maybe_unused]]
uint32_t batch_idx,
328 [[maybe_unused]]
uint32_t qo_head_idx,
329 [[maybe_unused]]
uint32_t kv_head_idx)
const
331 return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
__device__ float exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp)
Definition: variants.hpp:25
Definition: cluster_descriptor.hpp:13
constexpr uint32_t ALIBI
Definition: variants.hpp:254
CK_TILE_DEVICE float tanh_fast< float >(float x)
Definition: math.hpp:1387
constexpr uint32_t LOGITS_SOFT_CAP
Definition: variants.hpp:253
constexpr uint32_t CUSTOM_MASK
Definition: variants.hpp:251
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:403
constexpr uint32_t SLIDING_WINDOW
Definition: variants.hpp:252
Definition: allocators.h:459
unsigned int uint32_t
Definition: stdint.h:126
Definition: variants.hpp:258
__device__ __forceinline__ bool LogitsMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:313
__device__ __host__ ComposedAttention()=default
static constexpr bool use_exp2
Definition: variants.hpp:259
__device__ __forceinline__ bool LogitsSinkMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:324
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition: variants.hpp:266
__device__ __forceinline__ T LogitsTransform(const Params &params, T logits, [[maybe_unused]] uint32_t batch_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:278
static constexpr bool use_logits_soft_cap
Definition: variants.hpp:261
Definition: variants.hpp:180
__device__ __host__ LogitsSoftCap()=default
__device__ __forceinline__ T LogitsTransform(const Params &params, T logits, [[maybe_unused]] uint32_t batch_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:199
__device__ __forceinline__ bool LogitsSinkMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:240
__device__ __forceinline__ bool LogitsMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:229
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition: variants.hpp:184
Definition: variants.hpp:63
float logits_soft_cap_rcp
Definition: variants.hpp:129
__host__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_)
Definition: variants.hpp:87
const ImplMask & impl_mask
Definition: variants.hpp:126
__device__ __host__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_, float logits_soft_cap_rcp_)
Definition: variants.hpp:108
__device__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_)
Definition: variants.hpp:65
float sm_scale
Definition: variants.hpp:127
float logits_soft_cap
Definition: variants.hpp:128
Definition: variants.hpp:133
__device__ __forceinline__ bool LogitsMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:156
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition: variants.hpp:137
__device__ __host__ StandardAttention()=default
__device__ __forceinline__ bool LogitsSinkMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:167
__device__ __forceinline__ T LogitsTransform([[maybe_unused]] const Params &params, T logits, [[maybe_unused]] uint32_t batch_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:145
Definition: variants.hpp:51
const ImplMask & impl_mask
Definition: variants.hpp:57
__device__ __host__ StandardAttentionParams(const ImplMask &impl_mask_, float sm_scale_)
Definition: variants.hpp:52
float sm_scale
Definition: variants.hpp:58