/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/bfloat16.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/bfloat16.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/bfloat16.hpp Source File
bfloat16.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
9 #if CK_TILE_USE_LLVM_BUILTIN_BF16
10 #include <hip/hip_bfloat16.h>
11 #endif
12 #include <stdint.h>
13 
14 #pragma once
15 
16 namespace ck_tile {
17 
19 {
20  standard = 0, // rtn
22  truncate,
24  rta_asm, // round to nearest away
25 };
26 
27 template <bf16_rounding_mode rounding =
29 CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
30 
31 template <bf16_rounding_mode rounding =
33 CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
34 
36 constexpr float bf16_to_float_raw(uint16_t x);
37 
39 constexpr double bf16_to_double_raw(uint16_t x);
40 
41 #if CK_TILE_USE_CUSTOM_DATA_TYPE
42 // HIP use __hip_bfloat16 as struct
43 struct alignas(2) bfloat16_t
44 {
45  using raw_type = uint16_t;
46  raw_type data;
47 
49  static constexpr bfloat16_t bit_cast(raw_type x)
50  {
51  bfloat16_t y;
52  y.data = x;
53  return y;
54  }
55 
56  // constructor
57  constexpr bfloat16_t() : data() {}
58 
59  // construct from float
61  explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
62 
63  // construct from double
65  explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
66 
67  // construct from int
69  explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
70 
71  // construct from unsigned int
73  explicit constexpr bfloat16_t(const unsigned int& x)
74  : data(float_to_bf16_raw(static_cast<float>(x)))
75  {
76  }
77 
78  // cast to float
80  explicit constexpr operator float() const { return bf16_to_float_raw(data); }
81 
82  // cast to double
84  explicit constexpr operator double() const { return bf16_to_double_raw(data); }
85 
86  // cast to int
88  explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
89 
90  // internal access
92  constexpr raw_type& get() { return data; }
93 
95  constexpr raw_type get() const { return data; }
96 };
97 template <typename>
98 struct native_t;
99 
100 template <>
101 struct native_t<bfloat16_t>
102 {
103  using type = ushort;
104 };
105 using bf16_t = bfloat16_t;
106 using bf16_raw_t = typename bf16_t::raw_type;
107 #else
108 #if CK_TILE_USE_LLVM_BUILTIN_BF16
109 using bfloat16_t = __bf16;
110 #else
111 using bfloat16_t = ushort;
112 #endif
115 #endif
116 // round to nearest
118 constexpr uint16_t float_to_bf16_rtn_raw(float f)
119 {
120  uint32_t bits = bit_cast<uint32_t>(f);
121  if(~bits & 0x7f800000)
122  {
123  // When the exponent bits are not all 1s, then the value is zero, normal,
124  // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
125  // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
126  // This causes the bfloat16's mantissa to be incremented by 1 if the 16
127  // least significant bits of the float mantissa are greater than 0x8000,
128  // or if they are equal to 0x8000 and the least significant bit of the
129  // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
130  // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
131  // has the value 0x7f, then incrementing it causes it to become 0x00 and
132  // the exponent is incremented by one, which is the next higher FP value
133  // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
134  // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
135  // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
136  // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
137  // incrementing it causes it to become an exponent of 0xFF and a mantissa
138  // of 0x00, which is Inf, the next higher value to the unrounded value.
139  bits += 0x7fff + ((bits >> 16) & 1); // Round to nearest, round to even
140  }
141  else if(bits & 0xffff)
142  {
143  // When all of the exponent bits are 1, the value is Inf or NaN.
144  // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
145  // mantissa bit. Quiet NaN is indicated by the most significant mantissa
146  // bit being 1. Signaling NaN is indicated by the most significant
147  // mantissa bit being 0 but some other bit(s) being 1. If any of the
148  // lower 16 bits of the mantissa are 1, we set the least significant bit
149  // of the bfloat16 mantissa, in order to preserve signaling NaN in case
150  // the bloat16's mantissa bits are all 0.
151  bits |= 0x10000; // Preserve signaling NaN
152  }
153  return uint16_t(bits >> 16);
154 }
155 
157 constexpr uint16_t float_to_bf16_rtn_asm(float f) { return float_to_bf16_rtn_raw(f); }
158 
161 {
162  union
163  {
164  float fp32;
165  uint32_t int32;
166  } u = {f};
167 
168  static constexpr uint32_t FP32_NAN = 0x7fff0000;
169  static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;
170 
171 #if defined(__GFX9__)
172  using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
173  uint32x2_t check_nan;
174 #else
175  uint32_t check_nan;
176 #endif
177  uint32_t tmp;
178  asm volatile("\n \
179  v_cmp_u_f32 %0, %2, %2 \n \
180  v_bfe_u32 %1, %2, 16, 1 \n \
181  v_add3_u32 %1, %2, %1, %3 \n \
182  v_cndmask_b32 %2, %1, %4, %0 \n \
183  v_lshrrev_b32 %2, 16, %2 \n \
184  "
185  : "=s"(check_nan), "+v"(tmp), "+v"(u.fp32)
186  : "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN));
187 
188  return uint16_t(u.int32);
189 }
190 
191 // TODO: do we need this on host?
194 
197 {
198  union
199  {
200  float fp32;
201  struct
202  {
203  uint16_t lo;
204  uint16_t hi;
205  };
206  } u = {f};
207 
208  const uint32_t low_nan = 0x7fff;
209  const uint32_t hi_nan = 0x7fff0000;
210 
211 #if defined(__GFX9__)
212  using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
213  uint32x2_t check_nan;
214 #else
215  uint32_t check_nan;
216 #endif
217 
218  asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
219  "v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
220  "v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
221  : [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
222  : [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
223 
224  // Note: in the above code snippet, we use the high 16 bits
225  return u.hi;
226 }
227 
228 // Truncate instead of rounding, preserving SNaN
231 {
232  uint32_t bits = bit_cast<uint32_t>(f);
233  return static_cast<uint16_t>(bits >> 16) | (!(~bits & 0x7f800000) && (bits & 0xffff));
234 }
235 
236 // Fast truncate instead of rounding, RTZ
239 {
240  uint32_t bits = bit_cast<uint32_t>(f);
241  return static_cast<uint16_t>(bits >> 16);
242 }
243 
244 template <bf16_rounding_mode rounding>
246 {
247  if constexpr(rounding == bf16_rounding_mode::standard)
248  return float_to_bf16_rtn_raw(f);
249  else if constexpr(rounding == bf16_rounding_mode::standard_asm)
250  return float_to_bf16_rtn_asm(f);
251  else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
252  return float_to_bf16_truc_nan_raw(f);
253  else if constexpr(rounding == bf16_rounding_mode::rta_asm)
254  return float_to_bf16_rta_asm(f);
255  else
256  return float_to_bf16_truc_raw(f);
257 }
258 
259 template <bf16_rounding_mode rounding>
261 {
262  return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
263 }
264 
266 constexpr float bf16_to_float_raw(uint16_t x)
267 {
268  union
269  {
270  uint32_t int32;
271  float fp32;
272  } u = {uint32_t(x) << 16};
273  return u.fp32;
274 }
275 
277 constexpr double bf16_to_double_raw(uint16_t x)
278 {
279  return static_cast<double>(bf16_to_float_raw(x));
280 }
281 
282 template <bf16_rounding_mode rounding =
285 {
286 #if CK_TILE_USE_LLVM_BUILTIN_BF16
287  return static_cast<bfloat16_t>(f);
288 #else
289  return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
290 #endif
291 }
292 
293 template <bf16_rounding_mode rounding =
296 {
297  return bit_cast<bfloat16_t>(double_to_bf16_raw(f, constant<rounding>{}));
298 }
299 
301 constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast<uint16_t>(x)); }
302 
304 constexpr double bf16_to_double(bfloat16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
305 
306 template <bf16_rounding_mode rounding =
309 {
310  return bit_cast<bfloat16_t>(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
311 }
312 
314 constexpr half_t bf16_to_fp16(bfloat16_t x) { return static_cast<fp16_t>(static_cast<float>(x)); }
315 
316 template <class T>
317 struct numeric;
318 
319 template <>
321 {
322  // minimum finite value, or minimum positive normalized value for float
323  CK_TILE_HOST_DEVICE static constexpr bfloat16_t min()
324  {
325  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080));
326  }
327 
328  // minimum (most negative) finite value
330  {
331  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f));
332  }
333 
334  // maximum finite value
335  CK_TILE_HOST_DEVICE static constexpr bfloat16_t max()
336  {
337  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f7f));
338  }
339 
340  // difference between 1.0 and next value representable by float
342  {
343  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x1000));
344  }
345 
346  // maximum rounding error
347  // maximum rounding error
348  // bin : f edcba 9876543210
349  // bits: s eeeeeeee mmmmmmm
350  // 0 01111110 0000000 (0.5)
351  //
353  {
354  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3f00));
355  }
356 
357  // positive infinity value
359  {
360  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f80));
361  }
362 
363  // quiet NaN
365  {
366  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
367  }
368 
369  // signaling NaN
371  {
372  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
373  }
374 
375  // smallest positive subnormal value
377  {
378  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0001));
379  }
381  {
382  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0));
383  }
384 };
385 
386 template <>
388 {
389  static constexpr int exp = 8;
390  static constexpr int mant = 7;
391  static constexpr int PackedSize = 1;
392 };
393 
394 #if CK_TILE_USE_CUSTOM_DATA_TYPE
396 #endif
397 
398 // math
401 {
402  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(bit_cast<bf16_raw_t>(x) & 0x7fff));
403 }
404 
406 bool isnan(const bfloat16_t& x)
407 {
408  uint16_t xx = bit_cast<bf16_raw_t>(x);
409  return (xx & 0x7FFF) > 0x7C00;
410 }
411 
414 {
415  return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
416 };
417 
420 {
421  return static_cast<bfloat16_t>(__ocml_exp_f32(static_cast<float>(x)));
422 };
423 
425 bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
426 
428 bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };
429 
430 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
Definition: config.hpp:72
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_truc_nan_raw(float f)
Definition: bfloat16.hpp:230
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:428
ushort bfloat16_t
Definition: bfloat16.hpp:111
uint32_t uint32x2_t
Definition: vector_type.hpp:163
constexpr CK_TILE_HOST_DEVICE Y bit_cast(const X &x)
Definition: bit_cast.hpp:11
_Float16 fp16_t
Definition: half.hpp:110
constexpr CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant< rounding >={})
Definition: bfloat16.hpp:260
constexpr CK_TILE_HOST_DEVICE float bf16_to_float_raw(uint16_t x)
Definition: bfloat16.hpp:266
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_truc_raw(float f)
Definition: bfloat16.hpp:238
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
CK_TILE_HOST uint16_t float_to_bf16_rta_asm(float f)
Definition: bfloat16.hpp:193
constexpr CK_TILE_HOST_DEVICE double bf16_to_double_raw(uint16_t x)
Definition: bfloat16.hpp:277
constexpr CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant< rounding >={})
Definition: bfloat16.hpp:284
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:413
constexpr CK_TILE_HOST_DEVICE float bf16_to_float(bfloat16_t x)
Definition: bfloat16.hpp:301
uint16_t bf16_raw_t
Definition: bfloat16.hpp:114
constexpr CK_TILE_HOST_DEVICE half_t bf16_to_fp16(bfloat16_t x)
Definition: bfloat16.hpp:314
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:419
bf16_rounding_mode
Definition: bfloat16.hpp:19
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:400
constexpr CK_TILE_HOST_DEVICE double bf16_to_double(bfloat16_t x)
Definition: bfloat16.hpp:304
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_rtn_raw(float f)
Definition: bfloat16.hpp:118
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant< rounding >={})
Definition: bfloat16.hpp:245
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition: bfloat16.hpp:406
constexpr CK_TILE_HOST uint16_t float_to_bf16_rtn_asm(float f)
Definition: bfloat16.hpp:157
CK_TILE_HOST_DEVICE constexpr bfloat16_t fp16_to_bf16(half_t f, constant< rounding >={})
Definition: bfloat16.hpp:308
constexpr CK_TILE_HOST_DEVICE bfloat16_t double_to_bf16(double f, constant< rounding >={})
Definition: bfloat16.hpp:295
_Float16 half_t
Definition: half.hpp:111
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:425
unsigned short uint16_t
Definition: stdint.h:125
unsigned int uint32_t
Definition: stdint.h:126
Definition: integral_constant.hpp:13
Definition: bfloat16.hpp:321
static constexpr CK_TILE_HOST_DEVICE bfloat16_t round_error()
Definition: bfloat16.hpp:352
static constexpr CK_TILE_HOST_DEVICE bfloat16_t infinity()
Definition: bfloat16.hpp:358
static constexpr CK_TILE_HOST_DEVICE bfloat16_t max()
Definition: bfloat16.hpp:335
static constexpr CK_TILE_HOST_DEVICE bfloat16_t denorm_min()
Definition: bfloat16.hpp:376
static constexpr CK_TILE_HOST_DEVICE bfloat16_t min()
Definition: bfloat16.hpp:323
static constexpr CK_TILE_HOST_DEVICE bfloat16_t lowest()
Definition: bfloat16.hpp:329
static constexpr CK_TILE_HOST_DEVICE bfloat16_t epsilon()
Definition: bfloat16.hpp:341
static constexpr CK_TILE_HOST_DEVICE bfloat16_t quiet_NaN()
Definition: bfloat16.hpp:364
static constexpr CK_TILE_HOST_DEVICE bfloat16_t zero()
Definition: bfloat16.hpp:380
static constexpr CK_TILE_HOST_DEVICE bfloat16_t signaling_NaN()
Definition: bfloat16.hpp:370
Definition: bfloat16.hpp:388
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition: numeric.hpp:106