15 using raw_type =
typename traits::bitwise_type;
22 return (x >> traits::mant) &
exp_mask;
30 return (x >> (
traits::exp + traits::mant)) == _numeric::binary_zero;
40 for(
raw_type i = 0; i < traits::mant; ++i)
42 mantissa += std::ldexp(
static_cast<float>(x & 0b1), -(traits::mant - i));
53 float sign = utils::is_positive(data) ? 1.0 : -1.0;
55 float mant = utils::get_mantissa(data);
57 return std::ldexp(sign * mant * scale,
exp);
72 uint32_t max_bitwise = bit_cast<uint32_t>(max_value);
80 bitwise_type mantissa =
92 float prev_val = bit_cast<float>(prev_bit);
93 float diff = max_value - prev_val;
95 float actual_max = max_value + (diff / 2);
97 if(std::abs(
value) < actual_max)
100 (exp << numeric_traits<T>::mant) | mantissa;
113 (exp << numeric_traits<T>::mant);
119 x = bit_cast<uint32_t>(
value);
137 const int mini_denormal_act_exponent = 1 - mini_bias;
139 int act_exponent, out_exponent, exponent_diff;
141 bool is_subnorm =
false;
145 act_exponent = exponent - bias + 1;
146 exponent_diff = mini_denormal_act_exponent - act_exponent;
151 act_exponent = exponent - bias;
152 if(act_exponent <= mini_denormal_act_exponent)
154 exponent_diff = mini_denormal_act_exponent - act_exponent;
161 mantissa += (1UL << mfmt);
165 shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
166 bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
170 if(is_subnorm && std::abs(
value) < std::abs(min_subnorm))
173 if(std::abs(
value) <= std::abs(min_subnorm -
value))
179 if(exponent_diff > 0)
180 mantissa >>= exponent_diff;
181 else if(exponent_diff == -1)
182 mantissa <<= -exponent_diff;
183 bool implicit_one = mantissa & (1 << mfmt);
184 out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
188 mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
190 if(out_exponent == 0)
192 if((1UL << mfmt) & mantissa)
199 if((1UL << (mfmt + 1)) & mantissa)
208 if(out_exponent == 0 && mantissa == 0)
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
__host__ T exp(T x)
Definition: math_v2.hpp:391
__host__ __device__ bool is_subnormal(T x)
Definition: mxfp_utils.hpp:45
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, float scale=1.f)
Definition: mxfp_convert.hpp:50
CK_TILE_HOST_DEVICE T::raw_type convert_to_type(float value, float scale=1.f)
Definition: mxfp_convert.hpp:61
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:423
int32_t int32_t
Definition: integer.hpp:10
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
unsigned int uint32_t
Definition: stdint.h:126
Definition: numeric.hpp:81
Definition: mxfp_convert.hpp:11
static constexpr bool is_positive(raw_type x)
Definition: mxfp_convert.hpp:28
static constexpr raw_type get_exponent(raw_type x)
Definition: mxfp_convert.hpp:19
static constexpr double get_mantissa(raw_type x)
Definition: mxfp_convert.hpp:37
static constexpr int exp_mask
Definition: mxfp_convert.hpp:17
typename traits::bitwise_type raw_type
Definition: mxfp_convert.hpp:15
static constexpr bool is_subnormal(raw_type x)
Definition: mxfp_convert.hpp:32
static constexpr raw_type get_exponent(const T &x)
Definition: mxfp_convert.hpp:24
Definition: numeric.hpp:18
static constexpr CK_TILE_HOST_DEVICE T max()
Definition: numeric.hpp:26