29 #ifndef _HIP_BFLOAT16_H_
30 #define _HIP_BFLOAT16_H_
32 #if __cplusplus < 201103L || !defined(__HIPCC__)
51 #include <type_traits>
66 :
data(float_to_bfloat16(f))
70 explicit __host__ __device__
hip_bfloat16(
float f, truncate_t)
71 :
data(truncate_float_to_bfloat16(f))
76 __host__ __device__
operator float()
const
82 } u = {uint32_t(
data) << 16};
86 static __host__ __device__
hip_bfloat16 round_to_bfloat16(
float f)
89 output.
data = float_to_bfloat16(f);
93 static __host__ __device__
hip_bfloat16 round_to_bfloat16(
float f, truncate_t)
96 output.
data = truncate_float_to_bfloat16(f);
101 static __host__ __device__ uint16_t float_to_bfloat16(
float f)
108 if(~u.int32 & 0x7f800000)
126 u.int32 += 0x7fff + ((u.int32 >> 16) & 1);
128 else if(u.int32 & 0xffff)
140 return uint16_t(u.int32 >> 16);
144 static __host__ __device__ uint16_t truncate_float_to_bfloat16(
float f)
151 return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
158 } hip_bfloat16_public;
// Compile-time layout checks for the C++ hip_bfloat16 class. As the
// diagnostic messages state, the class must remain a standard-layout and
// trivial type so it stays compatible with the plain C definition.
160 static_assert(std::is_standard_layout<hip_bfloat16>{},
161 "hip_bfloat16 is not a standard layout type, and thus is "
162 "incompatible with C.");
// std::is_trivial requires a trivial default constructor plus trivial
// copy/move operations and destructor, so objects can be created and copied
// like a plain C struct (converting constructors from float are still allowed).
164 static_assert(std::is_trivial<hip_bfloat16>{},
165 "hip_bfloat16 is not a trivial type, and thus is "
166 "incompatible with C.");
// The internal C++ hip_bfloat16 and hip_bfloat16_public must have the same
// total size and the same offset for their `data` member, so both definitions
// describe the same bytes. offsetof is only guaranteed for standard-layout
// types, which the first assertion above enforces for hip_bfloat16.
168 static_assert(
sizeof(
hip_bfloat16) ==
sizeof(hip_bfloat16_public)
169 && offsetof(
hip_bfloat16, data) == offsetof(hip_bfloat16_public, data),
170 "internal hip_bfloat16 does not match public hip_bfloat16");
// Stream insertion: writes bf16 to os as its float value (via the float
// conversion operator) and returns os so insertions can be chained.
172 inline std::ostream& operator<<(std::ostream& os,
const hip_bfloat16& bf16)
174 return os << float(bf16);
203 return float(a) < float(b);
207 return float(a) == float(b);
// Returns true when `a` encodes +/-infinity. In that case every exponent bit of
// a.data (mask 0x7f80) is set and every mantissa bit (mask 0x7f) is clear.
// The sign bit (0x8000) is not examined, so both signs of infinity match.
264 constexpr __host__ __device__
bool isinf(
hip_bfloat16 a)
266 return !(~a.
data & 0x7f80) && !(a.
data & 0x7f);
// Returns true when `a` is a NaN. In that case every exponent bit of a.data
// (mask 0x7f80) is set and at least one mantissa bit (mask 0x7f) is set.
// The sign bit is ignored. The unary `+` only promotes the masked value to
// int; `&&` converts it to bool regardless.
268 constexpr __host__ __device__
bool isnan(
hip_bfloat16 a)
270 return !(~a.
data & 0x7f80) && +(a.
data & 0x7f);
// Returns true for both +0 and -0, i.e. when every bit of a.data except the
// sign bit (0x8000) is clear. Subnormals (zero exponent, nonzero mantissa)
// are not reported as zero, because the mask also covers the mantissa.
272 constexpr __host__ __device__
bool iszero(
hip_bfloat16 a)
274 return !(a.
data & 0x7fff);
Struct to represent a 16-bit brain floating-point (bfloat16) number.
Definition: hip_bfloat16.h:40
uint16_t data
Definition: hip_bfloat16.h:41