19 __host__
inline int clz(
uint32_t x) {
return __builtin_clz(x); }
20 __device__
inline int clz(
uint32_t x) {
return __clz(x); }
28 template <
typename X,
typename Y,
bool negative_zero_nan,
bool clip,
bool stoch>
33 constexpr
int out_mant = NumericUtils<Y>::mant;
37 constexpr
int in_mant = NumericUtils<X>::mant;
42 constexpr Y nan_code = 0x80;
43 constexpr
uint32_t nan_mask = NumericUtils<X>::nan_mask;
46 using T_bitwise =
typename NumericUtils<X>::bitwise_type;
47 T_bitwise x_bitwise = bit_cast<T_bitwise>(x);
50 head = x_bitwise & NumericUtils<X>::head_mask;
51 mantissa = x_bitwise & NumericUtils<X>::mant_mask;
52 exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
53 sign = head >> (in_exp + in_mant);
54 bias = NumericUtils<X>::bias;
56 uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
57 uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
58 constexpr
int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
60 if constexpr(negative_zero_nan)
62 if((x_bitwise & nan_mask) == nan_mask)
67 if((x_bitwise & nan_mask) == nan_mask)
68 return signed_inf + (mantissa != 0 ? 1 : 0);
82 const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
83 const int out_denormal_act_exponent = 1 - out_bias;
88 int act_exponent, out_exponent, exponent_diff;
98 act_exponent = exponent - bias + 1;
99 exponent_diff = out_denormal_act_exponent -
104 act_exponent = exponent - bias;
105 if(act_exponent <= out_denormal_act_exponent)
112 exponent_diff = out_denormal_act_exponent - act_exponent;
120 mantissa += (1 << in_mant);
123 bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
124 (1 << (in_mant - out_mant + exponent_diff - 1));
130 if(exponent_diff > 0)
131 mantissa >>= exponent_diff;
132 else if(exponent_diff == -1)
133 mantissa <<= -exponent_diff;
134 bool implicit_one = mantissa & (1 << in_mant);
137 (act_exponent + exponent_diff) + out_bias - (implicit_one ? 0 : 1);
142 (1 << (in_mant - out_mant));
143 mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
146 if(out_exponent == 0)
148 if((1 << in_mant) & mantissa)
156 if((1 << (in_mant + 1)) & mantissa)
164 mantissa >>= (in_mant - out_mant);
166 if(out_exponent > max_exp)
170 mantissa = (1 << out_mant) - 1;
171 out_exponent = max_exp;
180 if(out_exponent == 0 && mantissa == 0)
181 return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
182 mantissa &= (1 << out_mant) - 1;
183 return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
186 template <
typename X,
typename Y,
bool negative_zero_nan>
191 constexpr
int in_mant = NumericUtils<X>::mant;
195 constexpr
int out_mant = NumericUtils<Y>::mant;
198 constexpr X nan_code = 0x80;
199 using T_bitwise =
typename NumericUtils<Y>::bitwise_type;
201 constexpr T_bitwise Inf_bitwise = NumericUtils<Y>::Inf;
202 constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
203 constexpr T_bitwise NaN_bitwise = NumericUtils<Y>::NaN;
204 constexpr T_bitwise Neg0_bitwise = NumericUtils<Y>::Neg0;
206 constexpr Y Inf = bit_cast<Y>(Inf_bitwise);
207 constexpr Y NegInf = bit_cast<Y>(NegInf_bitwise);
208 constexpr Y NaN = bit_cast<Y>(NaN_bitwise);
209 constexpr Y Neg0 = bit_cast<Y>(Neg0_bitwise);
213 return static_cast<Y
>(0);
216 uint32_t sign = x >> (in_exp + in_mant);
217 uint32_t mantissa = x & ((1 << in_mant) - 1);
218 int exponent = (x & 0x7F) >> in_mant;
220 constexpr
int exp_low_cutoff =
221 (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
224 if constexpr(negative_zero_nan)
233 if(exponent == ((1 << in_exp) - 1))
234 return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
237 if constexpr((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) &&
242 return bit_cast<Y>(retval);
249 int sh = 1 +
clz(mantissa) - (32 - in_mant);
252 mantissa &= ((1 << in_mant) - 1);
254 exponent += exp_low_cutoff - 1;
255 mantissa <<= out_mant - in_mant;
260 mantissa |= 1 << out_mant;
261 mantissa >>= 1 - exponent;
265 retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
266 return bit_cast<Y>(retval);
271 template <
typename X,
typename Y,
bool negative_zero_nan,
bool clip,
bool stoch>
277 static_assert(is_half || is_float,
"Only half and float can be casted.");
279 return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
282 template <
typename X,
typename Y,
bool negative_zero_nan>
288 static_assert(is_half || is_float,
"only half and float are supported.");
290 return run_cast_from_f8<X, Y, negative_zero_nan>(x);
__host__ T exp(T x)
Definition: math_v2.hpp:391
Definition: check_err.hpp:24
__host__ __device__ Y cast_from_f8(X x)
Definition: f8_utils.hpp:283
__host__ __device__ Y cast_to_f8(X x, uint32_t rng)
Definition: f8_utils.hpp:272
CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng=0)
Definition: float8.hpp:250
CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
Definition: float8.hpp:476
f8_rounding_mode
Definition: f8_utils.hpp:14
__host__ int clz(uint32_t x)
Definition: f8_utils.hpp:19
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
unsigned int uint32_t
Definition: stdint.h:126