/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hip/checkouts/docs-5.0.0/include/hip/hip_bfloat16.h Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hip/checkouts/docs-5.0.0/include/hip/hip_bfloat16.h Source File

HIP Runtime API Reference: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hip/checkouts/docs-5.0.0/include/hip/hip_bfloat16.h Source File
hip_bfloat16.h
Go to the documentation of this file.
1 
29 #ifndef _HIP_BFLOAT16_H_
30 #define _HIP_BFLOAT16_H_
31 
32 #if __cplusplus < 201103L || !defined(__HIPCC__)
33 
34 // If this is a C compiler, C++ compiler below C++11, or a host-only compiler, we only
35 // include a minimal definition of hip_bfloat16
36 
37 #include <stdint.h>
// Minimal C-compatible bfloat16: only the raw 16-bit pattern is exposed, with no
// conversions or operators (those exist only in the C++/HIP definition).
typedef struct
{
    uint16_t data; // raw bfloat16 bits: the upper 16 bits of an IEEE float32
} hip_bfloat16;
43 
44 #else // __cplusplus < 201103L || !defined(__HIPCC__)
45 
46 #include <cmath>
47 #include <cstddef>
48 #include <cstdint>
49 #include <hip/hip_runtime.h>
50 #include <ostream>
51 #include <type_traits>
52 
// Brain floating point (bfloat16) value: an IEEE float32 with the low 16 mantissa
// bits dropped. Only the raw bit pattern is stored, so the type stays trivial and
// standard-layout (checked by the static_asserts that follow the definition).
struct hip_bfloat16
{
    // Raw bit pattern: the upper 16 bits of the corresponding IEEE float32.
    uint16_t data;

    // Tag type; pass hip_bfloat16::truncate to request truncation instead of
    // round-to-nearest-even when converting from float.
    enum truncate_t
    {
        truncate
    };

    // Trivial default constructor: `data` is left uninitialized.
    __host__ __device__ hip_bfloat16() = default;

    // Convert from float with round-to-nearest, ties-to-even.
    explicit __host__ __device__ hip_bfloat16(float f)
        : data(float_to_bfloat16(f))
    {
    }

    // Convert from float by discarding the low 16 bits (NaNs stay NaN).
    explicit __host__ __device__ hip_bfloat16(float f, truncate_t)
        : data(truncate_float_to_bfloat16(f))
    {
    }

    // Widen to float: the stored bits become the high half of a float32 whose
    // low 16 bits are zero, so the conversion is exact.
    __host__ __device__ operator float() const
    {
        union
        {
            uint32_t bits;
            float    value;
        } widened = {static_cast<uint32_t>(data) << 16};
        return widened.value;
    }

    // Factory form of hip_bfloat16(f): round-to-nearest-even.
    static __host__ __device__ hip_bfloat16 round_to_bfloat16(float f)
    {
        hip_bfloat16 result;
        result.data = float_to_bfloat16(f);
        return result;
    }

    // Factory form of hip_bfloat16(f, truncate): truncation.
    static __host__ __device__ hip_bfloat16 round_to_bfloat16(float f, truncate_t)
    {
        hip_bfloat16 result;
        result.data = truncate_float_to_bfloat16(f);
        return result;
    }

private:
    // Round a float32 to the nearest bfloat16 bit pattern, ties to even.
    static __host__ __device__ uint16_t float_to_bfloat16(float f)
    {
        union
        {
            float    value;
            uint32_t bits;
        } source = {f};

        const bool exponent_all_ones = (source.bits & 0x7f800000) == 0x7f800000;
        if(!exponent_all_ones)
        {
            // Zero, subnormal or normal. Adding 0x7FFF, plus 1 when the kept
            // bfloat16 mantissa is odd, carries into bit 16 exactly when the
            // discarded low half exceeds 0x8000, or equals 0x8000 with an odd
            // mantissa. That is round-to-nearest-even.
            // A carry out of the 7-bit mantissa increments the exponent, which
            // is the correct next value. This can turn a subnormal into the
            // smallest normal, or the largest finite value into Inf.
            const uint32_t odd = (source.bits >> 16) & 1;
            source.bits += 0x7fff + odd;
        }
        else if((source.bits & 0xffff) != 0)
        {
            // Inf or NaN with a non-empty low half, so it is a NaN. If its
            // payload lived only in the discarded bits, truncation would yield
            // Inf. Forcing the lowest kept mantissa bit on keeps it a NaN and
            // preserves signaling NaNs.
            source.bits |= 0x10000;
        }
        return static_cast<uint16_t>(source.bits >> 16);
    }

    // Truncate a float32 to bfloat16 (drop the low 16 bits), keeping NaNs as NaN.
    static __host__ __device__ uint16_t truncate_float_to_bfloat16(float f)
    {
        union
        {
            float    value;
            uint32_t bits;
        } source = {f};

        const bool nan_with_low_payload
            = (source.bits & 0x7f800000) == 0x7f800000 && (source.bits & 0xffff) != 0;
        const uint32_t high_half = source.bits >> 16;
        return static_cast<uint16_t>(nan_with_low_payload ? (high_half | 1) : high_half);
    }
};
154 
// Mirror of the C definition of hip_bfloat16, used only to check at compile time
// that the C++ struct has the same size and member offset.
typedef struct
{
    uint16_t data; // raw bfloat16 bit pattern
} hip_bfloat16_public;
159 
// Compile-time guarantees that the C++ hip_bfloat16 can be passed to and from C
// code that uses the minimal C definition: standard layout, trivial, and the same
// size and member offset as the C-style mirror struct.
static_assert(std::is_standard_layout<hip_bfloat16>::value,
              "hip_bfloat16 is not a standard layout type, and thus is "
              "incompatible with C.");

static_assert(std::is_trivial<hip_bfloat16>::value,
              "hip_bfloat16 is not a trivial type, and thus is "
              "incompatible with C.");

static_assert(offsetof(hip_bfloat16, data) == offsetof(hip_bfloat16_public, data)
                  && sizeof(hip_bfloat16) == sizeof(hip_bfloat16_public),
              "internal hip_bfloat16 does not match public hip_bfloat16");
171 
// Stream insertion (host only): prints the value widened to float.
inline std::ostream& operator<<(std::ostream& os, const hip_bfloat16& bf16)
{
    const float widened = static_cast<float>(bf16);
    os << widened;
    return os;
}
// Unary plus: the operand unchanged.
inline __host__ __device__ hip_bfloat16 operator+(hip_bfloat16 a)
{
    return a;
}
// Unary minus: flips only the sign bit. It is exact and applies equally to
// zeros, infinities and NaNs.
inline __host__ __device__ hip_bfloat16 operator-(hip_bfloat16 a)
{
    hip_bfloat16 negated = a;
    negated.data ^= 0x8000;
    return negated;
}
// Binary arithmetic is evaluated in float. The result is rounded back to
// bfloat16 by the float converting constructor (round-to-nearest-even).
inline __host__ __device__ hip_bfloat16 operator+(hip_bfloat16 a, hip_bfloat16 b)
{
    const float sum = static_cast<float>(a) + static_cast<float>(b);
    return hip_bfloat16(sum);
}
inline __host__ __device__ hip_bfloat16 operator-(hip_bfloat16 a, hip_bfloat16 b)
{
    const float difference = static_cast<float>(a) - static_cast<float>(b);
    return hip_bfloat16(difference);
}
inline __host__ __device__ hip_bfloat16 operator*(hip_bfloat16 a, hip_bfloat16 b)
{
    const float product = static_cast<float>(a) * static_cast<float>(b);
    return hip_bfloat16(product);
}
inline __host__ __device__ hip_bfloat16 operator/(hip_bfloat16 a, hip_bfloat16 b)
{
    const float quotient = static_cast<float>(a) / static_cast<float>(b);
    return hip_bfloat16(quotient);
}
// Relational operators. Both operands are widened to float (exact) and compared
// there, so results follow IEEE-754 float semantics: -0 == +0, and every ordered
// comparison (<, >, <=, >=) and == involving a NaN is false.
inline __host__ __device__ bool operator<(hip_bfloat16 a, hip_bfloat16 b)
{
    return float(a) < float(b);
}
inline __host__ __device__ bool operator==(hip_bfloat16 a, hip_bfloat16 b)
{
    return float(a) == float(b);
}
inline __host__ __device__ bool operator>(hip_bfloat16 a, hip_bfloat16 b)
{
    return b < a;
}
// Compare directly rather than negating operator>. !(a > b) is true when
// either operand is NaN, which disagrees with float's operator<=.
inline __host__ __device__ bool operator<=(hip_bfloat16 a, hip_bfloat16 b)
{
    return float(a) <= float(b);
}
// != is the exact complement of == under IEEE-754 (NaN != x is true), so
// negation is correct here.
inline __host__ __device__ bool operator!=(hip_bfloat16 a, hip_bfloat16 b)
{
    return !(a == b);
}
// Compare directly rather than negating operator<, for the same NaN reason as
// operator<=.
inline __host__ __device__ bool operator>=(hip_bfloat16 a, hip_bfloat16 b)
{
    return float(a) >= float(b);
}
// Compound assignment: apply the binary operator (computed in float, rounded to
// bfloat16), store the result in `a`, and return `a` by reference.
inline __host__ __device__ hip_bfloat16& operator+=(hip_bfloat16& a, hip_bfloat16 b)
{
    a = a + b;
    return a;
}
inline __host__ __device__ hip_bfloat16& operator-=(hip_bfloat16& a, hip_bfloat16 b)
{
    a = a - b;
    return a;
}
inline __host__ __device__ hip_bfloat16& operator*=(hip_bfloat16& a, hip_bfloat16 b)
{
    a = a * b;
    return a;
}
inline __host__ __device__ hip_bfloat16& operator/=(hip_bfloat16& a, hip_bfloat16 b)
{
    a = a / b;
    return a;
}
// Prefix increment: adds 1.0 (with bfloat16 rounding) and returns `a`.
inline __host__ __device__ hip_bfloat16& operator++(hip_bfloat16& a)
{
    const hip_bfloat16 one(1.0f);
    a += one;
    return a;
}
// Prefix decrement: subtracts 1.0 (with bfloat16 rounding) and returns `a`.
inline __host__ __device__ hip_bfloat16& operator--(hip_bfloat16& a)
{
    const hip_bfloat16 one(1.0f);
    a -= one;
    return a;
}
// Postfix increment: returns the value held before incrementing.
inline __host__ __device__ hip_bfloat16 operator++(hip_bfloat16& a, int)
{
    const hip_bfloat16 previous = a;
    ++a;
    return previous;
}
// Postfix decrement: returns the value held before decrementing.
inline __host__ __device__ hip_bfloat16 operator--(hip_bfloat16& a, int)
{
    const hip_bfloat16 previous = a;
    --a;
    return previous;
}
261 
// Classification helpers for hip_bfloat16, operating directly on the bit pattern:
// sign 0x8000, exponent 0x7f80, mantissa 0x007f.
// NOTE(review): adding overloads to namespace std is formally not permitted by
// the C++ standard; these are kept in std so existing std::isinf/std::isnan
// call sites keep working.
namespace std
{
    // +/-Inf: exponent all ones and mantissa zero.
    constexpr __host__ __device__ bool isinf(hip_bfloat16 a)
    {
        return (a.data & 0x7f80) == 0x7f80 && (a.data & 0x7f) == 0;
    }
    // NaN: exponent all ones and mantissa nonzero.
    constexpr __host__ __device__ bool isnan(hip_bfloat16 a)
    {
        return (a.data & 0x7f80) == 0x7f80 && (a.data & 0x7f) != 0;
    }
    // +0 or -0: every bit other than the sign bit is clear.
    constexpr __host__ __device__ bool iszero(hip_bfloat16 a)
    {
        return (a.data & 0x7fff) == 0;
    }
}
277 
278 #endif // __cplusplus < 201103L || !defined(__HIPCC__)
279 
280 #endif // _HIP_BFLOAT16_H_
Struct to represent a 16-bit brain floating-point number.
Definition: hip_bfloat16.h:40
uint16_t data
Definition: hip_bfloat16.h:41