/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/variants.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/variants.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/variants.hpp Source File
variants.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <type_traits>
7 
10 
11 #define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 0
12 #define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 1
13 
14 #ifndef CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT
15 #define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
16 #endif
17 
18 #ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM
19 #define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0
20 #endif
21 
namespace ck_tile {
namespace internal {

// Softsign logits transform: softmax_scale * logits / (1 + |logits * logits_soft_cap_rcp|).
// On gfx90a/gfx94x, when the softsign cap variant and the asm opt-in are both enabled at
// compile time, a hand-scheduled VALU sequence is emitted instead of relying on the
// compiler; otherwise the plain expression below is used.
__device__ inline float
exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp)
{
#if(defined(__gfx90a__) || defined(__gfx94__)) && \
    (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
     CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
    float result, numerator, denominator;
    // denominator = rcp(1 + |logits * logits_soft_cap_rcp|); numerator = scale * logits.
    // |...| uses the VOP3 input absolute-value modifier; "=&v" early-clobbers keep the
    // scratch VGPRs distinct from the inputs. softmax_scale is constrained to an SGPR.
    asm volatile(
        "v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n"
        "v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n"
        "v_rcp_f32_e32 %[denominator], %[denominator]\n"
        "v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n"
        "v_mul_f32_e32 %[result], %[numerator], %[denominator]"
        : [numerator] "=&v"(numerator), [denominator] "=&v"(denominator), [result] "=v"(result)
        : [softmax_scale] "s"(softmax_scale),
          [logits] "v"(logits),
          [logits_soft_cap_rcp] "v"(logits_soft_cap_rcp));
    return result;
#else
    return softmax_scale * logits * rcp<float>(1.f + abs(logits * logits_soft_cap_rcp));
#endif
}

} // namespace internal
48 
// Parameter bundle for the StandardAttention variant: a masking policy plus the
// softmax scale applied to the query. Holds the mask by reference, so the referenced
// ImplMask must outlive this object.
template <typename ImplMask>
struct StandardAttentionParams
{
    __device__ __host__ StandardAttentionParams(const ImplMask& impl_mask_, float sm_scale_)
        : impl_mask(impl_mask_), sm_scale(sm_scale_)
    {
    }

    const ImplMask& impl_mask; // masking policy queried by LogitsMask()/LogitsSinkMask()
    float sm_scale;            // softmax scaling factor applied in QueryTransform()
};
60 
// Parameter bundle for the LogitsSoftCap variant: mask, softmax scale, the soft-cap
// value, and its precomputed reciprocal. When UseExp2 is set, the cap and its
// reciprocal are pre-scaled by log2(e) factors so the kernel can use exp2-based
// softmax without extra per-element multiplies.
template <typename ImplMask, bool UseExp2 = false>
struct LogitsSoftCapParams
{
    // Device-side constructor: derives the cap reciprocal with the fast GCN
    // reciprocal builtin (device-only); a non-positive cap disables capping
    // by forcing the reciprocal to 0.
    __device__
    LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_)
        : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_)
    {
        if(0.f < logits_soft_cap)
        {
            logits_soft_cap_rcp = __builtin_amdgcn_rcpf(logits_soft_cap);
        }
        else
        {
            logits_soft_cap_rcp = 0.f;
        }

        // move computation here to prevent compiler from generating inefficient instruction
        // sequence
        if constexpr(UseExp2)
        {
            logits_soft_cap     = log2e_v<float> * logits_soft_cap;
            logits_soft_cap_rcp = sm_scale * log2e_rcp_v<float> * logits_soft_cap_rcp;
        }
    }

    // Host-side constructor: same contract as the device overload, but the GCN
    // reciprocal builtin is unavailable on the host.
    __host__
    LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_)
        : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_)
    {
        if(0.f < logits_soft_cap)
        {
            // NOTE(review): reconstructed — the extracted source dropped this line;
            // plain division is the host equivalent of the device rcp builtin.
            logits_soft_cap_rcp = 1.f / logits_soft_cap;
        }
        else
        {
            logits_soft_cap_rcp = 0.f;
        }

        // move computation here to prevent compiler from generating inefficient instruction
        // sequence
        if constexpr(UseExp2)
        {
            logits_soft_cap     = log2e_v<float> * logits_soft_cap;
            logits_soft_cap_rcp = sm_scale * log2e_rcp_v<float> * logits_soft_cap_rcp;
        }
    }

    // Constructor taking a caller-precomputed reciprocal, avoiding the division/rcp here.
    __device__ __host__ LogitsSoftCapParams(const ImplMask& impl_mask_,
                                            float sm_scale_,
                                            float logits_soft_cap_,
                                            float logits_soft_cap_rcp_)
        : impl_mask(impl_mask_),
          sm_scale(sm_scale_),
          logits_soft_cap(logits_soft_cap_),
          logits_soft_cap_rcp(logits_soft_cap_rcp_)
    {
        // move computation here to prevent compiler from generating inefficient instruction
        // sequence
        if constexpr(UseExp2)
        {
            logits_soft_cap     = log2e_v<float> * logits_soft_cap;
            logits_soft_cap_rcp = sm_scale * log2e_rcp_v<float> * logits_soft_cap_rcp;
        }
    }

    const ImplMask& impl_mask; // masking policy queried by LogitsMask()/LogitsSinkMask()
    float sm_scale;            // softmax scaling factor
    float logits_soft_cap;     // cap value (pre-multiplied by log2(e) when UseExp2)
    float logits_soft_cap_rcp; // 1/cap (pre-multiplied by sm_scale/log2(e) when UseExp2)
};
131 
// Default attention variant: scale the query by sm_scale, pass logits through
// unchanged, and mask purely via the ImplMask policy carried in Params
// (expected to be StandardAttentionParams or compatible).
struct StandardAttention
{
    __device__ __host__ StandardAttention() = default;

    // Applies the softmax scale to a query value (converted to float first).
    template <typename Params, typename T>
    __device__ __forceinline__ T QueryTransform(const Params& params, T q) const
    {
        return type_convert<float>(q) * params.sm_scale;
    }

    // Identity transform: standard attention does not modify logits.
    template <typename Params, typename T>
    __device__ __forceinline__ T LogitsTransform([[maybe_unused]] const Params& params,
                                                 T logits,
                                                 [[maybe_unused]] uint32_t batch_idx,
                                                 /*uint32_t qo_idx, uint32_t kv_idx,*/
                                                 [[maybe_unused]] uint32_t qo_head_idx,
                                                 [[maybe_unused]] uint32_t kv_head_idx) const
    {
        return logits;
    }

    // True when the (qo_idx, kv_idx) element participates in attention.
    template <typename Params>
    __device__ __forceinline__ bool LogitsMask(const Params& params,
                                               [[maybe_unused]] uint32_t batch_idx,
                                               uint32_t qo_idx,
                                               uint32_t kv_idx,
                                               [[maybe_unused]] uint32_t qo_head_idx,
                                               [[maybe_unused]] uint32_t kv_head_idx) const
    {
        return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
    }

    // True when the element participates in the attention-sink path.
    template <typename Params>
    __device__ __forceinline__ bool LogitsSinkMask(const Params& params,
                                                   [[maybe_unused]] uint32_t batch_idx,
                                                   uint32_t qo_idx,
                                                   uint32_t kv_idx,
                                                   [[maybe_unused]] uint32_t qo_head_idx,
                                                   [[maybe_unused]] uint32_t kv_head_idx) const
    {
        return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
    }
};
177 
// Soft-capped attention variant: logits are squashed through tanh (default) or
// softsign, selected at compile time by CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT.
// With UseExp2, the softmax scale is folded into the precomputed params (see
// LogitsSoftCapParams), so QueryTransform becomes the identity.
template <bool UseExp2 = false>
struct LogitsSoftCap
{
    __device__ __host__ LogitsSoftCap() = default;

    // Query scaling: identity under UseExp2 (scale already folded into
    // logits_soft_cap_rcp), otherwise multiply by sm_scale.
    template <typename Params, typename T>
    __device__ __forceinline__ T QueryTransform(const Params& params, T q) const
    {
        if constexpr(UseExp2)
        {
            return q;
        }
        else
        {
            return type_convert<float>(q) * params.sm_scale;
        }
    }

    // Applies the configured soft cap:
    //   tanh:     cap * tanh(logits / cap)
    //   softsign: logits / (1 + |logits / cap|)  (scaled variants under UseExp2)
    template <typename Params, typename T>
    __device__ __forceinline__ T LogitsTransform(const Params& params,
                                                 T logits,
                                                 [[maybe_unused]] uint32_t batch_idx,
                                                 /*uint32_t qo_idx, uint32_t kv_idx,*/
                                                 [[maybe_unused]] uint32_t qo_head_idx,
                                                 [[maybe_unused]] uint32_t kv_head_idx) const
    {
        if constexpr(UseExp2)
        {
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
            return params.logits_soft_cap *
                   tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
            return internal::exp2_soft_sign_impl(
                params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
#endif
        }
        else
        {
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
            return params.logits_soft_cap *
                   tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
            return type_convert<float>(logits) *
                   rcp<float>(1.f + abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
#endif
        }
    }

    // True when the (qo_idx, kv_idx) element participates in attention.
    template <typename Params>
    __device__ __forceinline__ bool LogitsMask(const Params& params,
                                               [[maybe_unused]] uint32_t batch_idx,
                                               uint32_t qo_idx,
                                               uint32_t kv_idx,
                                               [[maybe_unused]] uint32_t qo_head_idx,
                                               [[maybe_unused]] uint32_t kv_head_idx) const
    {
        return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
    }

    // True when the element participates in the attention-sink path.
    template <typename Params>
    __device__ __forceinline__ bool LogitsSinkMask(const Params& params,
                                                   [[maybe_unused]] uint32_t batch_idx,
                                                   uint32_t qo_idx,
                                                   uint32_t kv_idx,
                                                   [[maybe_unused]] uint32_t qo_head_idx,
                                                   [[maybe_unused]] uint32_t kv_head_idx) const
    {
        return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
    }
};
250 
// Bit flags OR-ed into ComposedAttention's VARIANT_CODE template parameter to
// select which attention features are compiled in.
constexpr uint32_t CUSTOM_MASK     = 1U;
constexpr uint32_t SLIDING_WINDOW  = 2U;
constexpr uint32_t LOGITS_SOFT_CAP = 4U;
constexpr uint32_t ALIBI           = 8U;
255 
// Compile-time composition of attention features selected by VARIANT_CODE
// (bitwise OR of CUSTOM_MASK / SLIDING_WINDOW / LOGITS_SOFT_CAP / ALIBI).
// Currently only the LOGITS_SOFT_CAP bit changes behavior here; the other
// bits are carried for kernel-side dispatch.
template <uint32_t VARIANT_CODE, bool UseExp2 = false>
struct ComposedAttention
{
    static constexpr bool use_exp2 = UseExp2;

    static constexpr bool use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0;

    __device__ __host__ ComposedAttention() = default;

    // Query scaling: identity when soft cap + exp2 fold the scale into params,
    // otherwise multiply by sm_scale.
    template <typename Params, typename T>
    __device__ __forceinline__ T QueryTransform(const Params& params, T q) const
    {
        if constexpr(use_logits_soft_cap && UseExp2)
        {
            return q;
        }
        return type_convert<float>(q) * params.sm_scale;
    }

    // Applies the soft cap (tanh or softsign, chosen at compile time) when the
    // LOGITS_SOFT_CAP bit is set; otherwise passes logits through unchanged.
    template <typename Params, typename T>
    __device__ __forceinline__ T LogitsTransform(const Params& params,
                                                 T logits,
                                                 [[maybe_unused]] uint32_t batch_idx,
                                                 /*uint32_t qo_idx, uint32_t kv_idx,*/
                                                 [[maybe_unused]] uint32_t qo_head_idx,
                                                 [[maybe_unused]] uint32_t kv_head_idx) const
    {
        if constexpr(use_logits_soft_cap)
        {
            if constexpr(UseExp2)
            {
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
                return params.logits_soft_cap *
                       tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
                return internal::exp2_soft_sign_impl(
                    params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
#endif
            }
            else
            {
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
                return params.logits_soft_cap *
                       tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
                return type_convert<float>(logits) *
                       rcp<float>(1.f +
                                  abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
#endif
            }
        }
        return logits;
    }

    // True when the (qo_idx, kv_idx) element participates in attention.
    template <typename Params>
    __device__ __forceinline__ bool LogitsMask(const Params& params,
                                               [[maybe_unused]] uint32_t batch_idx,
                                               uint32_t qo_idx,
                                               uint32_t kv_idx,
                                               [[maybe_unused]] uint32_t qo_head_idx,
                                               [[maybe_unused]] uint32_t kv_head_idx) const
    {
        return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
    }

    // True when the element participates in the attention-sink path.
    template <typename Params>
    __device__ __forceinline__ bool LogitsSinkMask(const Params& params,
                                                   [[maybe_unused]] uint32_t batch_idx,
                                                   uint32_t qo_idx,
                                                   uint32_t kv_idx,
                                                   [[maybe_unused]] uint32_t qo_head_idx,
                                                   [[maybe_unused]] uint32_t kv_head_idx) const
    {
        return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
    }
};

} // namespace ck_tile
__device__ float exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp)
Definition: variants.hpp:25
Definition: cluster_descriptor.hpp:13
constexpr uint32_t ALIBI
Definition: variants.hpp:254
CK_TILE_DEVICE float tanh_fast< float >(float x)
Definition: math.hpp:1387
constexpr uint32_t LOGITS_SOFT_CAP
Definition: variants.hpp:253
constexpr uint32_t CUSTOM_MASK
Definition: variants.hpp:251
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:403
constexpr uint32_t SLIDING_WINDOW
Definition: variants.hpp:252
Definition: allocators.h:459
unsigned int uint32_t
Definition: stdint.h:126
Definition: variants.hpp:258
__device__ __forceinline__ bool LogitsMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:313
__device__ __host__ ComposedAttention()=default
static constexpr bool use_exp2
Definition: variants.hpp:259
__device__ __forceinline__ bool LogitsSinkMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:324
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition: variants.hpp:266
__device__ __forceinline__ T LogitsTransform(const Params &params, T logits, [[maybe_unused]] uint32_t batch_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:278
static constexpr bool use_logits_soft_cap
Definition: variants.hpp:261
Definition: variants.hpp:180
__device__ __host__ LogitsSoftCap()=default
__device__ __forceinline__ T LogitsTransform(const Params &params, T logits, [[maybe_unused]] uint32_t batch_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:199
__device__ __forceinline__ bool LogitsSinkMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:240
__device__ __forceinline__ bool LogitsMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:229
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition: variants.hpp:184
Definition: variants.hpp:63
float logits_soft_cap_rcp
Definition: variants.hpp:129
__host__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_)
Definition: variants.hpp:87
const ImplMask & impl_mask
Definition: variants.hpp:126
__device__ __host__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_, float logits_soft_cap_rcp_)
Definition: variants.hpp:108
__device__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_)
Definition: variants.hpp:65
float sm_scale
Definition: variants.hpp:127
float logits_soft_cap
Definition: variants.hpp:128
Definition: variants.hpp:133
__device__ __forceinline__ bool LogitsMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:156
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition: variants.hpp:137
__device__ __host__ StandardAttention()=default
__device__ __forceinline__ bool LogitsSinkMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:167
__device__ __forceinline__ T LogitsTransform([[maybe_unused]] const Params &params, T logits, [[maybe_unused]] uint32_t batch_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:145
Definition: variants.hpp:51
const ImplMask & impl_mask
Definition: variants.hpp:57
__device__ __host__ StandardAttentionParams(const ImplMask &impl_mask_, float sm_scale_)
Definition: variants.hpp:52
float sm_scale
Definition: variants.hpp:58