/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/bfloat16.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/bfloat16.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/bfloat16.hpp Source File
bfloat16.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
9 #if CK_TILE_USE_LLVM_BUILTIN_BF16
10 #include <hip/hip_bfloat16.h>
11 #endif
12 #include <stdint.h>
13 
14 #pragma once
15 
16 namespace ck_tile {
17 
19 {
20  standard = 0, // rtn
22  truncate,
24  rta_asm, // round to nearest away
25 };
26 
27 template <bf16_rounding_mode rounding =
29 CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
30 
31 template <bf16_rounding_mode rounding =
33 CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
34 
36 constexpr float bf16_to_float_raw(uint16_t x);
37 
39 constexpr double bf16_to_double_raw(uint16_t x);
40 
41 #if CK_TILE_USE_CUSTOM_DATA_TYPE
42 // HIP uses __hip_bfloat16 as a struct; when CK_TILE_USE_CUSTOM_DATA_TYPE is set,
// ck_tile defines its own 2-byte bfloat16_t wrapper instead. The raw 16-bit
// pattern is stored in `data`, and every value conversion goes through fp32 via
// the float_to_bf16_raw / double_to_bf16_raw / bf16_to_*_raw helpers.
43 struct alignas(2) bfloat16_t
44 {
45  using raw_type = uint16_t;
46  raw_type data;
47 
 // Builds a value directly from a raw bit pattern (no rounding, no conversion).
49  static constexpr bfloat16_t bit_cast(raw_type x)
50  {
51  bfloat16_t y;
52  y.data = x;
53  return y;
54  }
55 
56  // constructor: value-initializes the raw bits to 0, i.e. +0.0
57  constexpr bfloat16_t() : data() {}
58 
59  // construct from float, rounded with float_to_bf16_raw's default rounding mode
61  explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
62 
63  // construct from double, rounded with double_to_bf16_raw's default rounding mode
65  explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
66 
67  // construct from int; the value is converted to float first, so integers
 // larger than 2^24 in magnitude may be rounded twice (int->float, float->bf16)
69  explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
70 
71  // construct from unsigned int; same float-first conversion as the int overload
73  explicit constexpr bfloat16_t(const unsigned int& x)
74  : data(float_to_bf16_raw(static_cast<float>(x)))
75  {
76  }
77 
78  // cast to float (exact: every bf16 value is representable in fp32)
80  explicit constexpr operator float() const { return bf16_to_float_raw(data); }
81 
82  // cast to double
84  explicit constexpr operator double() const { return bf16_to_double_raw(data); }
85 
86  // cast to int: widens to float, then static_cast<int> truncates toward zero
88  explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
89 
90  // internal access: mutable reference to the raw bits
92  constexpr raw_type& get() { return data; }
93 
 // internal access: copy of the raw bits
95  constexpr raw_type get() const { return data; }
96 };
97 template <typename>
98 struct native_t;
99 
100 template <>
101 struct native_t<bfloat16_t>
102 {
103  using type = ushort;
104 };
// Short aliases: bf16_t is the value type, bf16_raw_t its 16-bit storage type.
105 using bf16_t = bfloat16_t;
106 using bf16_raw_t = typename bf16_t::raw_type;
107 #else
// Without the custom struct, bfloat16_t is either the compiler's builtin __bf16
// (CK_TILE_USE_LLVM_BUILTIN_BF16) or a plain ushort holding the raw bf16 bits.
// NOTE(review): in the ushort configuration, static_cast<float>(x) converts the
// bit pattern as an integer; bf16_to_float()/bit_cast must be used for values.
108 #if CK_TILE_USE_LLVM_BUILTIN_BF16
109 using bfloat16_t = __bf16;
110 #else
111 using bfloat16_t = ushort;
112 #endif
115 #endif
116 // round to nearest
constexpr uint16_t float_to_bf16_rtn_raw(float f)
{
    // Converts fp32 to raw bf16 bits with round-to-nearest-even. Inf passes
    // through unchanged and NaN stays NaN (see the second branch).
    //
    // __builtin_bit_cast replaces the former union type pun. Reading the
    // inactive member of a union is undefined behaviour in C++ and is rejected
    // during constant evaluation, so this `constexpr` function could not
    // actually be evaluated at compile time.
    uint32_t bits = __builtin_bit_cast(uint32_t, f);
    if(~bits & 0x7f800000)
    {
        // When the exponent bits are not all 1s, then the value is zero, normal,
        // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
        // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
        // This causes the bfloat16's mantissa to be incremented by 1 if the 16
        // least significant bits of the float mantissa are greater than 0x8000,
        // or if they are equal to 0x8000 and the least significant bit of the
        // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
        // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
        // has the value 0x7f, then incrementing it causes it to become 0x00 and
        // the exponent is incremented by one, which is the next higher FP value
        // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
        // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
        // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
        // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
        // incrementing it causes it to become an exponent of 0xFF and a mantissa
        // of 0x00, which is Inf, the next higher value to the unrounded value.
        bits += 0x7fff + ((bits >> 16) & 1); // Round to nearest, round to even
    }
    else if(bits & 0xffff)
    {
        // When all of the exponent bits are 1, the value is Inf or NaN.
        // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
        // mantissa bit. Quiet NaN is indicated by the most significant mantissa
        // bit being 1. Signaling NaN is indicated by the most significant
        // mantissa bit being 0 but some other bit(s) being 1. If any of the
        // lower 16 bits of the mantissa are 1, we set the least significant bit
        // of the bfloat16 mantissa, in order to preserve signaling NaN in case
        // the bfloat16's mantissa bits are all 0.
        bits |= 0x10000; // Preserve signaling NaN
    }
    return uint16_t(bits >> 16);
}
159 
161 constexpr uint16_t float_to_bf16_rtn_asm(float f) { return float_to_bf16_rtn_raw(f); }
162 
165 {
166  union
167  {
168  float fp32;
169  uint32_t int32;
170  } u = {f};
171 
172  static constexpr uint32_t FP32_NAN = 0x7fff0000;
173  static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;
174 
175  using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
176  uint32x2_t check_nan;
177  uint32_t tmp;
178  asm volatile("\n \
179  v_cmp_u_f32 %0, %2, %2 \n \
180  v_bfe_u32 %1, %2, 16, 1 \n \
181  v_add3_u32 %1, %2, %1, %3 \n \
182  v_cndmask_b32 %2, %1, %4, %0 \n \
183  v_lshrrev_b32 %2, 16, %2 \n \
184  "
185  : "=s"(check_nan), "+v"(tmp), "+v"(u.fp32)
186  : "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN));
187 
188  return uint16_t(u.int32);
189 }
190 
191 // TODO: do we need this on host?
194 
197 {
198  union
199  {
200  float fp32;
201  struct
202  {
203  uint16_t lo;
204  uint16_t hi;
205  };
206  } u = {f};
207 
208  const uint32_t low_nan = 0x7fff;
209  const uint32_t hi_nan = 0x7fff0000;
210 
211  using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
212  uint32x2_t check_nan;
213 
214  asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
215  "v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
216  "v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
217  : [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
218  : [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
219 
220  // Note: in above code snipet, we use hi 16 bit
221  return u.hi;
222 }
223 
224 // Truncate instead of rounding, preserving SNaN
227 {
228  union
229  {
230  float fp32;
231  uint32_t int32;
232  } u = {f};
233  return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
234 }
235 
236 // Fast truncate instead of rounding, RTZ
239 {
240  union
241  {
242  float fp32;
243  uint32_t int32;
244  } u = {f};
245  return uint16_t(u.int32 >> 16);
246 }
247 
248 template <bf16_rounding_mode rounding>
250 {
251  if constexpr(rounding == bf16_rounding_mode::standard)
252  return float_to_bf16_rtn_raw(f);
253  else if constexpr(rounding == bf16_rounding_mode::standard_asm)
254  return float_to_bf16_rtn_asm(f);
255  else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
256  return float_to_bf16_truc_nan_raw(f);
257  else if constexpr(rounding == bf16_rounding_mode::rta_asm)
258  return float_to_bf16_rta_asm(f);
259  else
260  return float_to_bf16_truc_raw(f);
261 }
262 
263 template <bf16_rounding_mode rounding>
265 {
266  return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
267 }
268 
constexpr float bf16_to_float_raw(uint16_t x)
{
    // bf16 is the upper half of an IEEE-754 binary32 value, so widening is a
    // 16-bit left shift plus a bit reinterpretation; the result is exact.
    // __builtin_bit_cast (instead of reading an inactive union member, which
    // is UB and not allowed in constant evaluation) keeps this function
    // genuinely usable in constant expressions.
    return __builtin_bit_cast(float, uint32_t(x) << 16);
}
279 
281 constexpr double bf16_to_double_raw(uint16_t x)
282 {
283  return static_cast<double>(bf16_to_float_raw(x));
284 }
285 
286 template <bf16_rounding_mode rounding =
289 {
290 #if defined(__gfx950__)
291  return static_cast<bfloat16_t>(f);
292 #else
293  return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
294 #endif
295 }
296 
297 template <bf16_rounding_mode rounding =
300 {
301  return bit_cast<bfloat16_t>(double_to_bf16_raw(f, constant<rounding>{}));
302 }
303 
305 constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast<uint16_t>(x)); }
306 
308 constexpr double bf16_to_double(bfloat16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
309 
310 template <bf16_rounding_mode rounding =
313 {
314  return bit_cast<bfloat16_t>(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
315 }
316 
318 constexpr half_t bf16_to_fp16(bfloat16_t x) { return static_cast<fp16_t>(static_cast<float>(x)); }
319 
320 template <class T>
321 struct numeric;
322 
323 template <>
325 {
326  // minimum finite value, or minimum positive normalized value for float
327  CK_TILE_HOST_DEVICE static constexpr bfloat16_t min()
328  {
329  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080));
330  }
331 
332  // minumum finite value
334  {
335  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f));
336  }
337 
338  // maximum finite value
339  CK_TILE_HOST_DEVICE static constexpr bfloat16_t max()
340  {
341  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f7f));
342  }
343 
344  // difference between 1.0 and next value representable by float
346  {
347  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x1000));
348  }
349 
350  // maximum rounding error
351  // maximum rounding error
352  // bin : f edcba 9876543210
353  // bits: s eeeeeeee mmmmmmm
354  // 0 01111110 0000000 (0.5)
355  //
357  {
358  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3f00));
359  }
360 
361  // positive infinity value
363  {
364  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f80));
365  }
366 
367  // quiet NaN
369  {
370  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
371  }
372 
373  // signaling NaN
375  {
376  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
377  }
378 
379  // smallest positive subnormal value
381  {
382  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0001));
383  }
385  {
386  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0));
387  }
388 };
389 
390 template <>
392 {
393  static constexpr int exp = 8;
394  static constexpr int mant = 7;
395  static constexpr int PackedSize = 1;
396 };
397 
398 #if CK_TILE_USE_CUSTOM_DATA_TYPE
400 #endif
401 
402 // math
405 {
406  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(bit_cast<bf16_raw_t>(x) & 0x7fff));
407 }
408 
410 bool isnan(const bfloat16_t& x)
411 {
412  uint16_t xx = bit_cast<bf16_raw_t>(x);
413  return (xx & 0x7FFF) > 0x7C00;
414 }
415 
418 {
419  return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
420 };
421 
424 {
425  return static_cast<bfloat16_t>(__ocml_exp_f32(static_cast<float>(x)));
426 };
427 
429 bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
430 
432 bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };
433 
434 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
Definition: config.hpp:72
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_truc_nan_raw(float f)
Definition: bfloat16.hpp:226
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:432
ushort bfloat16_t
Definition: bfloat16.hpp:111
uint32_t uint32x2_t
Definition: vector_type.hpp:152
constexpr CK_TILE_HOST_DEVICE Y bit_cast(const X &x)
Definition: bit_cast.hpp:11
_Float16 fp16_t
Definition: half.hpp:110
constexpr CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant< rounding >={})
Definition: bfloat16.hpp:264
constexpr CK_TILE_HOST_DEVICE float bf16_to_float_raw(uint16_t x)
Definition: bfloat16.hpp:270
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_truc_raw(float f)
Definition: bfloat16.hpp:238
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
CK_TILE_HOST uint16_t float_to_bf16_rta_asm(float f)
Definition: bfloat16.hpp:193
constexpr CK_TILE_HOST_DEVICE double bf16_to_double_raw(uint16_t x)
Definition: bfloat16.hpp:281
constexpr CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant< rounding >={})
Definition: bfloat16.hpp:288
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:417
constexpr CK_TILE_HOST_DEVICE float bf16_to_float(bfloat16_t x)
Definition: bfloat16.hpp:305
uint16_t bf16_raw_t
Definition: bfloat16.hpp:114
constexpr CK_TILE_HOST_DEVICE half_t bf16_to_fp16(bfloat16_t x)
Definition: bfloat16.hpp:318
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:423
bf16_rounding_mode
Definition: bfloat16.hpp:19
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:404
constexpr CK_TILE_HOST_DEVICE double bf16_to_double(bfloat16_t x)
Definition: bfloat16.hpp:308
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_rtn_raw(float f)
Definition: bfloat16.hpp:118
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant< rounding >={})
Definition: bfloat16.hpp:249
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition: bfloat16.hpp:410
constexpr CK_TILE_HOST uint16_t float_to_bf16_rtn_asm(float f)
Definition: bfloat16.hpp:161
CK_TILE_HOST_DEVICE constexpr bfloat16_t fp16_to_bf16(half_t f, constant< rounding >={})
Definition: bfloat16.hpp:312
constexpr CK_TILE_HOST_DEVICE bfloat16_t double_to_bf16(double f, constant< rounding >={})
Definition: bfloat16.hpp:299
_Float16 half_t
Definition: half.hpp:111
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:429
unsigned short uint16_t
Definition: stdint.h:125
unsigned int uint32_t
Definition: stdint.h:126
Definition: integral_constant.hpp:13
Definition: bfloat16.hpp:325
static constexpr CK_TILE_HOST_DEVICE bfloat16_t round_error()
Definition: bfloat16.hpp:356
static constexpr CK_TILE_HOST_DEVICE bfloat16_t infinity()
Definition: bfloat16.hpp:362
static constexpr CK_TILE_HOST_DEVICE bfloat16_t max()
Definition: bfloat16.hpp:339
static constexpr CK_TILE_HOST_DEVICE bfloat16_t denorm_min()
Definition: bfloat16.hpp:380
static constexpr CK_TILE_HOST_DEVICE bfloat16_t min()
Definition: bfloat16.hpp:327
static constexpr CK_TILE_HOST_DEVICE bfloat16_t lowest()
Definition: bfloat16.hpp:333
static constexpr CK_TILE_HOST_DEVICE bfloat16_t epsilon()
Definition: bfloat16.hpp:345
static constexpr CK_TILE_HOST_DEVICE bfloat16_t quiet_NaN()
Definition: bfloat16.hpp:368
static constexpr CK_TILE_HOST_DEVICE bfloat16_t zero()
Definition: bfloat16.hpp:384
static constexpr CK_TILE_HOST_DEVICE bfloat16_t signaling_NaN()
Definition: bfloat16.hpp:374
Definition: bfloat16.hpp:392
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition: numeric.hpp:106