Composable Kernel: include/ck_tile/ops/fmha/block/variants.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <type_traits>
#include "ck_tile/core.hpp" // assumed project include: provides type_convert, rcp, abs, log2e_v, tanh_fast

#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 0
#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 1

#ifndef CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT
#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
#endif

#ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM
#define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0
#endif

namespace ck_tile {
namespace internal {
__device__ inline float
exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp)
{
#if(defined(__gfx90a__) || defined(__gfx94__)) && \
    (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
     CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
    float result, numerator, denominator;
    asm volatile(
        "v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n"
        "v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n"
        "v_rcp_f32_e32 %[denominator], %[denominator]\n"
        "v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n"
        "v_mul_f32_e32 %[result], %[numerator], %[denominator]"
        : [numerator] "=&v"(numerator), [denominator] "=&v"(denominator), [result] "=v"(result)
        : [softmax_scale] "s"(softmax_scale),
          [logits] "v"(logits),
          [logits_soft_cap_rcp] "v"(logits_soft_cap_rcp));
    return result;
#else
    return softmax_scale * logits * rcp<float>(1.f + abs(logits * logits_soft_cap_rcp));
#endif
}
} // namespace internal
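
Note: the inline asm branch above hand-schedules exactly the arithmetic of the fallback branch: softsign capping computes softmax_scale * logits / (1 + |logits * logits_soft_cap_rcp|), with the division by the cap pre-folded into logits_soft_cap_rcp. A minimal host-side sketch of the same math, using standard-library calls in place of ck_tile's rcp/abs (soft_sign_ref is a hypothetical name, not part of this header):

    #include <cmath>

    // Reference softsign soft cap: s * x / (1 + |x / cap|), with 1/cap passed in.
    inline float soft_sign_ref(float softmax_scale, float logits, float logits_soft_cap_rcp)
    {
        return softmax_scale * logits / (1.f + std::fabs(logits * logits_soft_cap_rcp));
    }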

template <typename ImplMask>
struct StandardAttentionParams
{
    __device__ __host__ StandardAttentionParams(const ImplMask& impl_mask_, float sm_scale_)
        : impl_mask(impl_mask_), sm_scale(sm_scale_)
    {
    }

    const ImplMask& impl_mask;
    float sm_scale;
};

template <typename ImplMask, bool UseExp2 = false>
struct LogitsSoftCapParams
{
    __device__
    LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_)
        : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_)
    {
        if(0.f < logits_soft_cap)
        {
            logits_soft_cap_rcp = __builtin_amdgcn_rcpf(logits_soft_cap);
        }
        else
        {
            logits_soft_cap_rcp = 0.f;
        }

        // move computation here to prevent compiler from generating inefficient instruction
        // sequence
        if constexpr(UseExp2)
        {
            logits_soft_cap     = log2e_v<float> * logits_soft_cap;
            logits_soft_cap_rcp = sm_scale * log2e_rcp_v<float> * logits_soft_cap_rcp;
        }
    }

    __host__
    LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_)
        : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_)
    {
        if(0.f < logits_soft_cap)
        {
            logits_soft_cap_rcp = 1.f / logits_soft_cap; // host path: plain reciprocal
        }
        else
        {
            logits_soft_cap_rcp = 0.f;
        }

        // move computation here to prevent compiler from generating inefficient instruction
        // sequence
        if constexpr(UseExp2)
        {
            logits_soft_cap     = log2e_v<float> * logits_soft_cap;
            logits_soft_cap_rcp = sm_scale * log2e_rcp_v<float> * logits_soft_cap_rcp;
        }
    }

    __device__ __host__ LogitsSoftCapParams(const ImplMask& impl_mask_,
                                            float sm_scale_,
                                            float logits_soft_cap_,
                                            float logits_soft_cap_rcp_)
        : impl_mask(impl_mask_),
          sm_scale(sm_scale_),
          logits_soft_cap(logits_soft_cap_),
          logits_soft_cap_rcp(logits_soft_cap_rcp_)
    {
        // move computation here to prevent compiler from generating inefficient instruction
        // sequence
        if constexpr(UseExp2)
        {
            logits_soft_cap     = log2e_v<float> * logits_soft_cap;
            logits_soft_cap_rcp = sm_scale * log2e_rcp_v<float> * logits_soft_cap_rcp;
        }
    }

    const ImplMask& impl_mask;
    float sm_scale;
    float logits_soft_cap;
    float logits_soft_cap_rcp;
};
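
Note: when UseExp2 is set, the surrounding kernel evaluates the softmax with exp2 instead of exp, so the constructors fold the conversion constants into the parameters once, rather than multiplying by log2(e) per element in the hot loop: logits_soft_cap picks up log2e_v, and logits_soft_cap_rcp absorbs sm_scale (QueryTransform becomes the identity in this mode) together with log2e_rcp_v to compensate. The exact factorization depends on how the pipeline scales the logits, but the folding rests on the identity exp(y) = exp2(y * log2(e)). A standalone numeric check of that identity (illustrative, not part of the header):

    #include <cmath>
    #include <cstdio>

    int main()
    {
        const float log2e = 1.4426950408889634f; // mirrors log2e_v<float>
        const float ys[]  = {-3.f, -0.5f, 0.f, 0.5f, 3.f};
        for(float y : ys)
        {
            // exp(y) and exp2(y * log2(e)) agree to float precision
            std::printf("exp(%g) = %.7g   exp2(%g * log2e) = %.7g\n",
                        y, std::exp(y), y, std::exp2(y * log2e));
        }
        return 0;
    }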

struct StandardAttention
{
    __device__ __host__ StandardAttention() = default;

    template <typename Params, typename T>
    __device__ __forceinline__ T QueryTransform(const Params& params, T q) const
    {
        return type_convert<float>(q) * params.sm_scale;
    }

    template <typename Params, typename T>
    __device__ __forceinline__ T LogitsTransform([[maybe_unused]] const Params& params,
                                                 T logits,
                                                 [[maybe_unused]] uint32_t batch_idx,
                                                 /*uint32_t qo_idx, uint32_t kv_idx,*/
                                                 [[maybe_unused]] uint32_t qo_head_idx,
                                                 [[maybe_unused]] uint32_t kv_head_idx) const
    {
        return logits;
    }

    template <typename Params>
    __device__ __forceinline__ bool LogitsMask(const Params& params,
                                               [[maybe_unused]] uint32_t batch_idx,
                                               uint32_t qo_idx,
                                               uint32_t kv_idx,
                                               [[maybe_unused]] uint32_t qo_head_idx,
                                               [[maybe_unused]] uint32_t kv_head_idx) const
    {
        return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
    }
};
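
Note: StandardAttention is the baseline implementation of the variant interface that every struct in this file follows: QueryTransform pre-scales q, LogitsTransform rewrites each q*k logit, and LogitsMask decides whether a (qo_idx, kv_idx) position participates. Params is duck-typed; any type exposing impl_mask and sm_scale (plus the soft-cap fields for capped variants) works. Since the hooks are __device__ functions, here is a plain-C++ mirror of the contract for illustration (NoMask, ParamsMirror, and the free functions are hypothetical stand-ins):

    #include <cstdint>

    struct NoMask // stand-in for an ImplMask type: nothing is ever out of bounds
    {
        bool IsOutOfBound(uint32_t /*qo_idx*/, uint32_t /*kv_idx*/) const { return false; }
    };

    template <typename Mask>
    struct ParamsMirror
    {
        const Mask& impl_mask;
        float sm_scale;
    };

    // The three hooks, restated on the host:
    template <typename Params>
    float query_transform(const Params& p, float q) { return q * p.sm_scale; }

    template <typename Params>
    float logits_transform(const Params&, float logits) { return logits; } // identity

    template <typename Params>
    bool logits_mask(const Params& p, uint32_t qo_idx, uint32_t kv_idx)
    {
        return !p.impl_mask.IsOutOfBound(qo_idx, kv_idx);
    }

    int main()
    {
        NoMask mask;
        ParamsMirror<NoMask> p{mask, 0.125f}; // sm_scale = 1/sqrt(64), for example
        return (query_transform(p, 2.f) == 0.25f && logits_transform(p, 3.f) == 3.f &&
                logits_mask(p, 0u, 7u))
                   ? 0
                   : 1;
    }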

template <bool UseExp2 = false>
struct LogitsSoftCap
{
    __device__ __host__ LogitsSoftCap() = default;

    template <typename Params, typename T>
    __device__ __forceinline__ T QueryTransform(const Params& params, T q) const
    {
        if constexpr(UseExp2)
        {
            return q;
        }
        else
        {
            return type_convert<float>(q) * params.sm_scale;
        }
    }

    template <typename Params, typename T>
    __device__ __forceinline__ T LogitsTransform(const Params& params,
                                                 T logits,
                                                 [[maybe_unused]] uint32_t batch_idx,
                                                 /*uint32_t qo_idx, uint32_t kv_idx,*/
                                                 [[maybe_unused]] uint32_t qo_head_idx,
                                                 [[maybe_unused]] uint32_t kv_head_idx) const
    {
        if constexpr(UseExp2)
        {
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
            return params.logits_soft_cap *
                   tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
            return internal::exp2_soft_sign_impl(
                params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
#endif
        }
        else
        {
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
            return params.logits_soft_cap *
                   tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
            return type_convert<float>(logits) *
                   rcp<float>(1.f + abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
#endif
        }
    }

    template <typename Params>
    __device__ __forceinline__ bool LogitsMask(const Params& params,
                                               [[maybe_unused]] uint32_t batch_idx,
                                               uint32_t qo_idx,
                                               uint32_t kv_idx,
                                               [[maybe_unused]] uint32_t qo_head_idx,
                                               [[maybe_unused]] uint32_t kv_head_idx) const
    {
        return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
    }
};
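
Note: in tanh mode the transform returns logits_soft_cap * tanh(logits / logits_soft_cap), which is approximately the identity for small logits and saturates toward +/-logits_soft_cap for large ones, keeping the softmax input bounded. A quick numeric illustration (cap = 30 is a hypothetical value, not a default from this header):

    #include <cmath>
    #include <cstdio>

    int main()
    {
        const float cap  = 30.f;
        const float xs[] = {0.5f, 10.f, 50.f, 500.f};
        for(float x : xs)
        {
            // ~x for |x| << cap; tends toward +/-cap as |x| grows
            std::printf("x = %6.1f -> cap*tanh(x/cap) = %8.4f\n", x, cap * std::tanh(x / cap));
        }
        return 0;
    }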

constexpr uint32_t CUSTOM_MASK = 1U;
constexpr uint32_t SLIDING_WINDOW = 2U;
constexpr uint32_t LOGITS_SOFT_CAP = 4U;
constexpr uint32_t ALIBI = 8U;

template <uint32_t VARIANT_CODE, bool UseExp2 = false>
struct ComposedAttention
{
    static constexpr bool use_exp2 = UseExp2;

    static constexpr bool use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0;

    __device__ __host__ ComposedAttention() = default;

    template <typename Params, typename T>
    __device__ __forceinline__ T QueryTransform(const Params& params, T q) const
    {
        if constexpr(use_logits_soft_cap && UseExp2)
        {
            return q;
        }
        return type_convert<float>(q) * params.sm_scale;
    }

    template <typename Params, typename T>
    __device__ __forceinline__ T LogitsTransform(const Params& params,
                                                 T logits,
                                                 [[maybe_unused]] uint32_t batch_idx,
                                                 /*uint32_t qo_idx, uint32_t kv_idx,*/
                                                 [[maybe_unused]] uint32_t qo_head_idx,
                                                 [[maybe_unused]] uint32_t kv_head_idx) const
    {
        if constexpr(use_logits_soft_cap)
        {
            if constexpr(UseExp2)
            {
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
                return params.logits_soft_cap *
                       tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
                return internal::exp2_soft_sign_impl(
                    params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
#endif
            }
            else
            {
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
                return params.logits_soft_cap *
                       tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
                return type_convert<float>(logits) *
                       rcp<float>(1.f +
                                  abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
#endif
            }
        }
        return logits;
    }

    template <typename Params>
    __device__ __forceinline__ bool LogitsMask(const Params& params,
                                               [[maybe_unused]] uint32_t batch_idx,
                                               uint32_t qo_idx,
                                               uint32_t kv_idx,
                                               [[maybe_unused]] uint32_t qo_head_idx,
                                               [[maybe_unused]] uint32_t kv_head_idx) const
    {
        return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
    }
};
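
Note: ComposedAttention selects behavior at compile time from the bit flags above; within this file only the LOGITS_SOFT_CAP bit changes QueryTransform/LogitsTransform, while CUSTOM_MASK, SLIDING_WINDOW, and ALIBI are not consumed here. A usage sketch (assumes this header is on the include path):

    #include "ck_tile/ops/fmha/block/variants.hpp"

    // Soft-cap variant on the exp2-based softmax path:
    using SoftCapVariant = ck_tile::ComposedAttention<ck_tile::LOGITS_SOFT_CAP, /*UseExp2=*/true>;
    static_assert(SoftCapVariant::use_logits_soft_cap, "soft-cap bit is set");
    static_assert(SoftCapVariant::use_exp2, "exp2 path selected");

    // Plain variant: LogitsTransform degenerates to the identity.
    using PlainVariant = ck_tile::ComposedAttention<0u>;
    static_assert(!PlainVariant::use_logits_soft_cap, "no soft-cap bit");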

} // namespace ck_tile