/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/bfloat16.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/bfloat16.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/bfloat16.hpp Source File
bfloat16.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
9 #if CK_TILE_USE_LLVM_BUILTIN_BF16
10 #include <hip/hip_bfloat16.h>
11 #endif
12 #include <stdint.h>
13 
14 #pragma once
15 
16 namespace ck_tile {
17 
19 {
20  standard = 0, // rtn
22  truncate,
24  rta_asm, // round to nearest away
25 };
26 
27 template <bf16_rounding_mode rounding =
29 CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
30 
31 template <bf16_rounding_mode rounding =
33 CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
34 
36 constexpr float bf16_to_float_raw(uint16_t x);
37 
39 constexpr double bf16_to_double_raw(uint16_t x);
40 
41 #if CK_TILE_USE_CUSTOM_DATA_TYPE
42 // HIP use __hip_bfloat16 as struct
43 struct alignas(2) bfloat16_t
44 {
45  using raw_type = uint16_t;
46  raw_type data;
47 
49  static constexpr bfloat16_t bit_cast(raw_type x)
50  {
51  bfloat16_t y;
52  y.data = x;
53  return y;
54  }
55 
56  // constructor
57  constexpr bfloat16_t() : data() {}
58 
59  // construct from float
61  explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
62 
63  // construct from double
65  explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
66 
67  // construct from int
69  explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
70 
71  // construct from unsigned int
73  explicit constexpr bfloat16_t(const unsigned int& x)
74  : data(float_to_bf16_raw(static_cast<float>(x)))
75  {
76  }
77 
78  // cast to float
80  explicit constexpr operator float() const { return bf16_to_float_raw(data); }
81 
82  // cast to float
84  explicit constexpr operator double() const { return bf16_to_double_raw(data); }
85 
86  // cast to int
88  explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
89 
90  // internal access
92  constexpr raw_type& get() { return data; }
93 
95  constexpr raw_type get() const { return data; }
96 };
97 template <typename>
98 struct native_t;
99 
100 template <>
101 struct native_t<bfloat16_t>
102 {
103  using type = ushort;
104 };
105 using bf16_t = bfloat16_t;
106 using bf16_raw_t = typename bf16_t::raw_type;
107 #else
108 #if CK_TILE_USE_LLVM_BUILTIN_BF16
109 using bfloat16_t = __bf16;
110 #else
111 using bfloat16_t = ushort;
112 #endif
115 #endif
116 // round to nearest
118 constexpr uint16_t float_to_bf16_rtn_raw(float f)
119 {
120  uint32_t bits = bit_cast<uint32_t>(f);
121  if(~bits & 0x7f800000)
122  {
123  // When the exponent bits are not all 1s, then the value is zero, normal,
124  // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
125  // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
126  // This causes the bfloat16's mantissa to be incremented by 1 if the 16
127  // least significant bits of the float mantissa are greater than 0x8000,
128  // or if they are equal to 0x8000 and the least significant bit of the
129  // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
130  // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
131  // has the value 0x7f, then incrementing it causes it to become 0x00 and
132  // the exponent is incremented by one, which is the next higher FP value
133  // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
134  // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
135  // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
136  // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
137  // incrementing it causes it to become an exponent of 0xFF and a mantissa
138  // of 0x00, which is Inf, the next higher value to the unrounded value.
139  bits += 0x7fff + ((bits >> 16) & 1); // Round to nearest, round to even
140  }
141  else if(bits & 0xffff)
142  {
143  // When all of the exponent bits are 1, the value is Inf or NaN.
144  // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
145  // mantissa bit. Quiet NaN is indicated by the most significant mantissa
146  // bit being 1. Signaling NaN is indicated by the most significant
147  // mantissa bit being 0 but some other bit(s) being 1. If any of the
148  // lower 16 bits of the mantissa are 1, we set the least significant bit
149  // of the bfloat16 mantissa, in order to preserve signaling NaN in case
150  // the bloat16's mantissa bits are all 0.
151  bits |= 0x10000; // Preserve signaling NaN
152  }
153  return uint16_t(bits >> 16);
154 }
155 
157 constexpr uint16_t float_to_bf16_rtn_asm(float f) { return float_to_bf16_rtn_raw(f); }
158 
161 {
162  union
163  {
164  float fp32;
165  uint32_t int32;
166  } u = {f};
167 
168  static constexpr uint32_t FP32_NAN = 0x7fff0000;
169  static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;
170 
171  using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
172  uint32x2_t check_nan;
173  uint32_t tmp;
174  asm volatile("\n \
175  v_cmp_u_f32 %0, %2, %2 \n \
176  v_bfe_u32 %1, %2, 16, 1 \n \
177  v_add3_u32 %1, %2, %1, %3 \n \
178  v_cndmask_b32 %2, %1, %4, %0 \n \
179  v_lshrrev_b32 %2, 16, %2 \n \
180  "
181  : "=s"(check_nan), "+v"(tmp), "+v"(u.fp32)
182  : "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN));
183 
184  return uint16_t(u.int32);
185 }
186 
187 // TODO: do we need this on host?
190 
193 {
194  union
195  {
196  float fp32;
197  struct
198  {
199  uint16_t lo;
200  uint16_t hi;
201  };
202  } u = {f};
203 
204  const uint32_t low_nan = 0x7fff;
205  const uint32_t hi_nan = 0x7fff0000;
206 
207  using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
208  uint32x2_t check_nan;
209 
210  asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
211  "v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
212  "v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
213  : [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
214  : [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
215 
216  // Note: in the above code snippet, we use the high 16 bits
217  return u.hi;
218 }
219 
220 // Truncate instead of rounding, preserving SNaN
223 {
224  uint32_t bits = bit_cast<uint32_t>(f);
225  return static_cast<uint16_t>(bits >> 16) | (!(~bits & 0x7f800000) && (bits & 0xffff));
226 }
227 
228 // Fast truncate instead of rounding, RTZ
231 {
232  uint32_t bits = bit_cast<uint32_t>(f);
233  return static_cast<uint16_t>(bits >> 16);
234 }
235 
236 template <bf16_rounding_mode rounding>
238 {
239  if constexpr(rounding == bf16_rounding_mode::standard)
240  return float_to_bf16_rtn_raw(f);
241  else if constexpr(rounding == bf16_rounding_mode::standard_asm)
242  return float_to_bf16_rtn_asm(f);
243  else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
244  return float_to_bf16_truc_nan_raw(f);
245  else if constexpr(rounding == bf16_rounding_mode::rta_asm)
246  return float_to_bf16_rta_asm(f);
247  else
248  return float_to_bf16_truc_raw(f);
249 }
250 
251 template <bf16_rounding_mode rounding>
253 {
254  return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
255 }
256 
// Widen raw bfloat16 bits to a float.
// The 16 input bits become the upper half of a 32-bit word (lower half zero),
// and that word is read back as a float through a union.
constexpr float bf16_to_float_raw(uint16_t x)
{
    const uint32_t widened = static_cast<uint32_t>(x) << 16;
    union
    {
        uint32_t int32;
        float fp32;
    } pun = {widened};
    return pun.fp32;
}
267 
269 constexpr double bf16_to_double_raw(uint16_t x)
270 {
271  return static_cast<double>(bf16_to_float_raw(x));
272 }
273 
274 template <bf16_rounding_mode rounding =
277 {
278 #if CK_TILE_USE_LLVM_BUILTIN_BF16
279  return static_cast<bfloat16_t>(f);
280 #else
281  return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
282 #endif
283 }
284 
285 template <bf16_rounding_mode rounding =
288 {
289  return bit_cast<bfloat16_t>(double_to_bf16_raw(f, constant<rounding>{}));
290 }
291 
293 constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast<uint16_t>(x)); }
294 
296 constexpr double bf16_to_double(bfloat16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
297 
298 template <bf16_rounding_mode rounding =
301 {
302  return bit_cast<bfloat16_t>(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
303 }
304 
306 constexpr half_t bf16_to_fp16(bfloat16_t x) { return static_cast<fp16_t>(static_cast<float>(x)); }
307 
308 template <class T>
309 struct numeric;
310 
311 template <>
313 {
314  // minimum finite value, or minimum positive normalized value for float
315  CK_TILE_HOST_DEVICE static constexpr bfloat16_t min()
316  {
317  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080));
318  }
319 
320  // lowest (most negative) finite value
322  {
323  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f));
324  }
325 
326  // maximum finite value
327  CK_TILE_HOST_DEVICE static constexpr bfloat16_t max()
328  {
329  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f7f));
330  }
331 
332  // difference between 1.0 and next value representable by float
334  {
335  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x1000));
336  }
337 
338  // maximum rounding error
339  // maximum rounding error
340  // bin : f edcba 9876543210
341  // bits: s eeeeeeee mmmmmmm
342  // 0 01111110 0000000 (0.5)
343  //
345  {
346  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3f00));
347  }
348 
349  // positive infinity value
351  {
352  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f80));
353  }
354 
355  // quiet NaN
357  {
358  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
359  }
360 
361  // signaling NaN
363  {
364  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
365  }
366 
367  // smallest positive subnormal value
369  {
370  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0001));
371  }
373  {
374  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0));
375  }
376 };
377 
378 template <>
380 {
381  static constexpr int exp = 8;
382  static constexpr int mant = 7;
383  static constexpr int PackedSize = 1;
384 };
385 
386 #if CK_TILE_USE_CUSTOM_DATA_TYPE
388 #endif
389 
390 // math
393 {
394  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(bit_cast<bf16_raw_t>(x) & 0x7fff));
395 }
396 
398 bool isnan(const bfloat16_t& x)
399 {
400  uint16_t xx = bit_cast<bf16_raw_t>(x);
401  return (xx & 0x7FFF) > 0x7C00;
402 }
403 
406 {
407  return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
408 };
409 
412 {
413  return static_cast<bfloat16_t>(__ocml_exp_f32(static_cast<float>(x)));
414 };
415 
417 bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
418 
420 bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };
421 
422 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
Definition: config.hpp:72
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_truc_nan_raw(float f)
Definition: bfloat16.hpp:222
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:420
ushort bfloat16_t
Definition: bfloat16.hpp:111
uint32_t uint32x2_t
Definition: vector_type.hpp:152
constexpr CK_TILE_HOST_DEVICE Y bit_cast(const X &x)
Definition: bit_cast.hpp:11
_Float16 fp16_t
Definition: half.hpp:110
constexpr CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant< rounding >={})
Definition: bfloat16.hpp:252
constexpr CK_TILE_HOST_DEVICE float bf16_to_float_raw(uint16_t x)
Definition: bfloat16.hpp:258
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_truc_raw(float f)
Definition: bfloat16.hpp:230
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
CK_TILE_HOST uint16_t float_to_bf16_rta_asm(float f)
Definition: bfloat16.hpp:189
constexpr CK_TILE_HOST_DEVICE double bf16_to_double_raw(uint16_t x)
Definition: bfloat16.hpp:269
constexpr CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant< rounding >={})
Definition: bfloat16.hpp:276
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:405
constexpr CK_TILE_HOST_DEVICE float bf16_to_float(bfloat16_t x)
Definition: bfloat16.hpp:293
uint16_t bf16_raw_t
Definition: bfloat16.hpp:114
constexpr CK_TILE_HOST_DEVICE half_t bf16_to_fp16(bfloat16_t x)
Definition: bfloat16.hpp:306
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:411
bf16_rounding_mode
Definition: bfloat16.hpp:19
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:392
constexpr CK_TILE_HOST_DEVICE double bf16_to_double(bfloat16_t x)
Definition: bfloat16.hpp:296
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_rtn_raw(float f)
Definition: bfloat16.hpp:118
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant< rounding >={})
Definition: bfloat16.hpp:237
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition: bfloat16.hpp:398
constexpr CK_TILE_HOST uint16_t float_to_bf16_rtn_asm(float f)
Definition: bfloat16.hpp:157
CK_TILE_HOST_DEVICE constexpr bfloat16_t fp16_to_bf16(half_t f, constant< rounding >={})
Definition: bfloat16.hpp:300
constexpr CK_TILE_HOST_DEVICE bfloat16_t double_to_bf16(double f, constant< rounding >={})
Definition: bfloat16.hpp:287
_Float16 half_t
Definition: half.hpp:111
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:417
unsigned short uint16_t
Definition: stdint.h:125
unsigned int uint32_t
Definition: stdint.h:126
Definition: integral_constant.hpp:13
Definition: bfloat16.hpp:313
static constexpr CK_TILE_HOST_DEVICE bfloat16_t round_error()
Definition: bfloat16.hpp:344
static constexpr CK_TILE_HOST_DEVICE bfloat16_t infinity()
Definition: bfloat16.hpp:350
static constexpr CK_TILE_HOST_DEVICE bfloat16_t max()
Definition: bfloat16.hpp:327
static constexpr CK_TILE_HOST_DEVICE bfloat16_t denorm_min()
Definition: bfloat16.hpp:368
static constexpr CK_TILE_HOST_DEVICE bfloat16_t min()
Definition: bfloat16.hpp:315
static constexpr CK_TILE_HOST_DEVICE bfloat16_t lowest()
Definition: bfloat16.hpp:321
static constexpr CK_TILE_HOST_DEVICE bfloat16_t epsilon()
Definition: bfloat16.hpp:333
static constexpr CK_TILE_HOST_DEVICE bfloat16_t quiet_NaN()
Definition: bfloat16.hpp:356
static constexpr CK_TILE_HOST_DEVICE bfloat16_t zero()
Definition: bfloat16.hpp:372
static constexpr CK_TILE_HOST_DEVICE bfloat16_t signaling_NaN()
Definition: bfloat16.hpp:362
Definition: bfloat16.hpp:380
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition: numeric.hpp:106