13 #include <type_traits>
17 #if(defined(__gfx94__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
18 #define CK_TILE_FP8_CVT_DEVICE 1
20 #define CK_TILE_FP8_CVT_DEVICE 0
70 #if CK_TILE_USE_CUSTOM_DATA_TYPE
71 struct alignas(1) float8_e4m3_t
73 static constexpr
int exponent = 4;
74 static constexpr
int mantissa = 3;
75 #if CK_TILE_USE_OCP_FP8
76 static constexpr
int bias = 7;
78 static constexpr
int bias = 8;
84 static constexpr float8_e4m3_t
bit_cast(raw_type x)
92 constexpr float8_e4m3_t() : data() {}
96 explicit constexpr float8_e4m3_t(
const float& x) : data(
float_to_fp8_raw(x)) {}
100 explicit constexpr float8_e4m3_t(
const int& x) : data(
float_to_fp8_raw(static_cast<float>(x)))
106 explicit constexpr float8_e4m3_t(
const unsigned int& x)
113 explicit constexpr
operator float()
const {
return fp8_to_float_raw(data); }
117 explicit constexpr
operator int()
const {
return static_cast<int>(
fp8_to_float_raw(data)); }
121 constexpr raw_type& get() {
return data; }
124 constexpr raw_type get()
const {
return data; }
126 using fp8_t = float8_e4m3_t;
127 using fp8_raw_t =
typename fp8_t::raw_type;
129 struct alignas(1) float8_e5m2_t
131 static constexpr
int exponent = 5;
132 static constexpr
int mantissa = 2;
133 #if CK_TILE_USE_OCP_FP8
134 static constexpr
int bias = 15;
136 static constexpr
int bias = 16;
142 static constexpr float8_e5m2_t
bit_cast(raw_type x)
150 constexpr float8_e5m2_t() : data() {}
154 explicit constexpr float8_e5m2_t(
const float& x) : data(
float_to_bf8_raw(x)) {}
158 explicit constexpr float8_e5m2_t(
const int& x) : data(
float_to_bf8_raw(static_cast<float>(x)))
164 explicit constexpr float8_e5m2_t(
const unsigned int& x)
171 explicit constexpr
operator float()
const {
return bf8_to_float_raw(data); }
175 explicit constexpr
operator int()
const {
return static_cast<int>(
bf8_to_float_raw(data)); }
179 constexpr raw_type& get() {
return data; }
182 constexpr raw_type get()
const {
return data; }
184 using bf8_t = float8_e5m2_t;
185 using bf8_raw_t =
typename bf8_t::raw_type;
191 struct native_t<
fp8_t>
193 using type = _BitInt(8);
197 struct native_t<
bf8_t>
199 using type =
unsigned _BitInt(8);
215 static constexpr
int exp = 4;
216 static constexpr
int mant = 3;
217 #if CK_TILE_USE_OCP_FP8
218 static constexpr
int bias = 7;
221 static constexpr
int bias = 8;
233 static constexpr
int exp = 5;
234 static constexpr
int mant = 2;
235 #if CK_TILE_USE_OCP_FP8
236 static constexpr
int bias = 15;
239 static constexpr
int bias = 16;
249 template <
typename SrcT,
typename DstT,
bool clip = true,
bool stoch = false>
253 "DstT type must be fp8 or bf8.");
257 static_assert(is_half || is_float,
"Only half and float can be cast to f8");
263 constexpr
bool is_fnuz =
274 SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
276 unsigned int head, mantissa;
283 sign = head >> (SrcT_exp + SrcT_mant);
285 unsigned int signed_inf = 0;
286 unsigned int nan = 0;
287 if constexpr(is_fnuz)
289 signed_inf = clip ? ((sign << (DstT_exp + DstT_mant)) + 0x7f) : 0x80;
294 if constexpr(DstT_exp == 4)
296 signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7e : 0x7f);
300 signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7b : 0x7c);
302 nan = (sign << (DstT_exp + DstT_mant)) + 0x7f;
305 unsigned int ifmax = 0;
306 if constexpr(is_float)
308 if constexpr(DstT_exp == 5)
314 if constexpr(is_fnuz)
324 else if constexpr(is_half)
326 if constexpr(DstT_exp == 5)
332 if constexpr(is_fnuz)
344 if((src_bitwise & fInf) == fInf)
346 return mantissa != 0 ? nan : signed_inf;
349 if((src_bitwise & abs_mask) > ifmax)
362 constexpr
int f8_denormal_act_exponent = 1 - DstT_bias;
367 int act_exponent, f8_exponent, exponent_diff;
378 act_exponent = exponent - bias + 1;
379 exponent_diff = f8_denormal_act_exponent -
384 act_exponent = exponent - bias;
385 if(act_exponent <= f8_denormal_act_exponent)
392 exponent_diff = f8_denormal_act_exponent - act_exponent;
400 mantissa += (1u << SrcT_mant);
404 if(exponent_diff > DstT_mant)
406 return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
408 bool midpoint = (mantissa & ((1u << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
409 (1u << (SrcT_mant - DstT_mant + exponent_diff - 1));
417 if(exponent_diff > 0)
418 mantissa >>= exponent_diff;
419 else if(exponent_diff == -1)
420 mantissa <<= -exponent_diff;
421 bool implicit_one = mantissa & (1u << SrcT_mant);
425 (act_exponent + exponent_diff) + DstT_bias - (implicit_one ? 0 : 1);
428 unsigned int drop_mask = (1u << (SrcT_mant - DstT_mant)) - 1;
431 (1u << (SrcT_mant - DstT_mant));
433 (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1u) : mantissa)) & drop_mask;
438 if((1u << SrcT_mant) & mantissa)
445 if((1u << (SrcT_mant + 1)) & mantissa)
452 mantissa >>= (SrcT_mant - DstT_mant);
455 const int max_exp = (1 << DstT_exp) - 1;
456 if(f8_exponent > max_exp)
460 mantissa = (1 << DstT_mant) - 1;
461 f8_exponent = max_exp;
469 if(f8_exponent == 0 && mantissa == 0)
470 return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
471 mantissa &= (1 << DstT_mant) - 1;
472 return (sign << (DstT_exp + DstT_mant)) | (f8_exponent << DstT_mant) | mantissa;
475 template <
typename SrcT,
typename DstT,
bool clip = true>
479 "SrcT type must be fp8 or bf8.");
483 constexpr
bool is_fnuz =
489 static_assert(is_half || is_float,
"DstT type must be half_t or float.");
500 DstT fmax{0}, fmin{0};
502 if constexpr(is_half)
507 else if constexpr(is_float)
518 unsigned int sign = x >> (SrcT_exp + SrcT_mant);
519 unsigned int mantissa = x & ((1 << SrcT_mant) - 1);
520 int exponent = (x & SrcT_abs_mask) >> SrcT_mant;
521 if constexpr(is_fnuz)
523 if((x & 0xff) == 0x80)
534 if constexpr(SrcT_exp == 4)
536 if((x & 0x7F) == 0x7F)
541 else if((x & 0x7C) == 0x7C)
547 return sign ? fmin : fmax;
549 return sign ? fNegInf : fInf;
557 if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
560 return bit_cast<DstT>(retval);
563 const int exp_low_cutoff =
564 (1 << (DstT_exp - 1)) - (1 << (SrcT_exp - 1)) + 1 - (is_fnuz ? 1 : 0);
569 int sh = 1 +
clz(mantissa) - (32 - SrcT_mant);
572 mantissa &= ((1ull << SrcT_mant) - 1);
574 exponent += exp_low_cutoff - 1;
575 mantissa <<= DstT_mant - SrcT_mant;
580 mantissa |= 1 << DstT_mant;
581 mantissa >>= 1 - exponent;
585 retval = (sign << (DstT_exp + DstT_mant)) | (exponent << DstT_mant) | mantissa;
587 return bit_cast<DstT>(retval);
590 template <
typename X,
typename Y,
bool clip,
bool stoch>
593 return bit_cast<Y>(run_cast_to_f8<X, Y, clip, stoch>(x, rng));
596 #if CK_TILE_FP8_CVT_DEVICE
600 template <fp8_
interpretation
interpret,
bool saturate,
bool stochastic_rounding = false>
608 unsigned char i8val[4];
611 unsigned int ival = 0;
614 if constexpr(saturate)
618 if((val.i32val & 0x7F800000) != 0x7F800000)
620 val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
625 if((val.i32val & 0x7F800000) != 0x7F800000)
627 val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
632 if((val.i32val & 0x7F800000) != 0x7F800000)
634 val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
639 if constexpr(stochastic_rounding)
643 ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
644 : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0);
646 i8data = val.i8val[0];
652 ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
false)
653 : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
658 i8data = val.i8val[0];
679 template <
typename SrcT,
typename DstT>
682 constexpr
bool clip =
true;
683 constexpr
int seed = 42;
685 #if CK_TILE_FP8_CVT_DEVICE
686 return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip,
true>(x, rng);
688 return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
689 impl::cast_to_f8<SrcT, DstT, clip, true>(x, rng));
705 template <
typename SrcT,
typename DstT>
708 constexpr
bool clip =
true;
709 #if CK_TILE_FP8_CVT_DEVICE
710 return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip,
false>(x, 0);
712 return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
713 impl::cast_to_f8<SrcT, DstT, clip, false>(x, 0));
717 template <fp8_rounding_mode rounding>
722 return float_to_fp8_rtn_raw<float, fp8_t>(x);
726 return float_to_fp8_sr_raw<float, fp8_t>(x);
734 template <fp8_rounding_mode rounding>
739 return float_to_fp8_rtn_raw<float, bf8_t>(x);
743 return float_to_fp8_sr_raw<float, bf8_t>(x);
753 #if CK_TILE_FP8_CVT_DEVICE
756 fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
760 return impl::run_cast_from_f8<fp8_t, float>(bit_cast<fp8_t>(x));
766 #if CK_TILE_FP8_CVT_DEVICE
769 fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
773 return impl::run_cast_from_f8<bf8_t, float>(bit_cast<bf8_t>(x));
796 #if CK_TILE_USE_OCP_FP8
798 struct numeric<
fp8_t>
803 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x08));
809 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0xfe));
815 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x7e));
822 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x20));
829 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x18));
835 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x7F));
841 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0xFF));
847 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x01));
852 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0));
857 struct numeric<
bf8_t>
862 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x04));
868 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0xfb));
874 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x7b));
880 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x34));
887 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x30));
893 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x7c));
899 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x7F));
905 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0xFF));
911 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x01));
916 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0));
926 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x08));
932 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0xff));
938 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x7f));
944 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x20));
954 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x30));
960 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x80));
966 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x80));
972 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x80));
978 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x01));
983 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0));
993 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x04));
999 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0xff));
1005 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x7f));
1011 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x34));
1021 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x38));
1027 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x80));
1033 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x80));
1039 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x80));
1045 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x01));
1050 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0));
1055 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1061 template <
typename T>
1064 static_assert(std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>,
1065 "Only fp8_t and bf8_t are supported");
1072 uint8_t xx = bit_cast<fp8_raw_t>(x);
1074 #if CK_TILE_USE_OCP_FP8
1075 return (xx & 0x7f) == 0x7f;
1080 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1082 fp8_t sqrt(
fp8_t x) {
return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(
static_cast<float>(x))); };
1085 fp8_t exp(
fp8_t x) {
return static_cast<fp8_t>(__ocml_exp_f32(
static_cast<float>(x))); };
1097 uint8_t xx = bit_cast<bf8_raw_t>(x);
1099 #if CK_TILE_USE_OCP_FP8
1100 return (xx & 0x7f) > 0x7c;
1106 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1108 bf8_t sqrt(
bf8_t x) {
return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(
static_cast<float>(x))); };
1111 bf8_t exp(
bf8_t x) {
return static_cast<bf8_t>(__ocml_exp_f32(
static_cast<float>(x))); };
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_FLOAT_TO_FP8_DEFAULT
Definition: config.hpp:79
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng=0)
Definition: float8.hpp:250
CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
Definition: float8.hpp:476
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
Definition: float8.hpp:591
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:432
fp8_interpretation
FP8 interpretation used in conversion algorithms.
Definition: float8.hpp:38
_BitInt(8) fp8_t
Definition: float8.hpp:204
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant< rounding >={})
Definition: float8.hpp:778
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t)
Definition: float8.hpp:751
constexpr CK_TILE_HOST_DEVICE Y bit_cast(const X &x)
Definition: bit_cast.hpp:11
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t)
Definition: float8.hpp:764
fp8_rounding_mode
Definition: float8.hpp:29
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:417
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant< rounding >={})
Definition: float8.hpp:718
uint8_t fp8_raw_t
Definition: float8.hpp:205
CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x)
Definition: float8.hpp:791
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_sr_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with stochastic rounding.
Definition: float8.hpp:680
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:423
CK_TILE_HOST int clz(uint32_t x)
Definition: math.hpp:264
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:404
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
uint8_t bf8_raw_t
Definition: float8.hpp:207
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant< rounding >={})
Definition: float8.hpp:784
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition: bfloat16.hpp:410
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_rtn_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with rounding to nearest ev...
Definition: float8.hpp:706
CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x)
Definition: float8.hpp:789
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant< rounding >={})
Definition: float8.hpp:735
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:429
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
_W64 unsigned int uintptr_t
Definition: stdint.h:165
unsigned int uint32_t
Definition: stdint.h:126
unsigned char uint8_t
Definition: stdint.h:124
Definition: integral_constant.hpp:13
remove_cvref_t< T > type
Definition: vector_type.hpp:26
static constexpr CK_TILE_HOST_DEVICE bf8_t min()
Definition: float8.hpp:991
static constexpr CK_TILE_HOST_DEVICE bf8_t quiet_NaN()
Definition: float8.hpp:1031
static constexpr CK_TILE_HOST_DEVICE bf8_t lowest()
Definition: float8.hpp:997
static constexpr CK_TILE_HOST_DEVICE bf8_t round_error()
Definition: float8.hpp:1019
static constexpr CK_TILE_HOST_DEVICE bf8_t signaling_NaN()
Definition: float8.hpp:1037
static constexpr CK_TILE_HOST_DEVICE bf8_t denorm_min()
Definition: float8.hpp:1043
static constexpr CK_TILE_HOST_DEVICE bf8_t epsilon()
Definition: float8.hpp:1009
static constexpr CK_TILE_HOST_DEVICE bf8_t infinity()
Definition: float8.hpp:1025
static constexpr CK_TILE_HOST_DEVICE bf8_t max()
Definition: float8.hpp:1003
static constexpr CK_TILE_HOST_DEVICE bf8_t zero()
Definition: float8.hpp:1048
static constexpr CK_TILE_HOST_DEVICE fp8_t signaling_NaN()
Definition: float8.hpp:970
static constexpr CK_TILE_HOST_DEVICE fp8_t zero()
Definition: float8.hpp:981
static constexpr CK_TILE_HOST_DEVICE fp8_t min()
Definition: float8.hpp:924
static constexpr CK_TILE_HOST_DEVICE fp8_t lowest()
Definition: float8.hpp:930
static constexpr CK_TILE_HOST_DEVICE fp8_t epsilon()
Definition: float8.hpp:942
static constexpr CK_TILE_HOST_DEVICE fp8_t quiet_NaN()
Definition: float8.hpp:964
static constexpr CK_TILE_HOST_DEVICE fp8_t max()
Definition: float8.hpp:936
static constexpr CK_TILE_HOST_DEVICE fp8_t denorm_min()
Definition: float8.hpp:976
static constexpr CK_TILE_HOST_DEVICE fp8_t round_error()
Definition: float8.hpp:952
static constexpr CK_TILE_HOST_DEVICE fp8_t infinity()
Definition: float8.hpp:958
bf8_raw_t bitwise_type
Definition: float8.hpp:231
fp8_raw_t bitwise_type
Definition: float8.hpp:213
Definition: numeric.hpp:81
static constexpr int PackedSize
Definition: numeric.hpp:82
Definition: numeric.hpp:18
static constexpr CK_TILE_HOST_DEVICE T lowest()
Definition: numeric.hpp:23
static constexpr CK_TILE_HOST_DEVICE T min()
Definition: numeric.hpp:20
static constexpr CK_TILE_HOST_DEVICE T quiet_NaN()
Definition: numeric.hpp:41
static constexpr CK_TILE_HOST_DEVICE T signaling_NaN()
Definition: numeric.hpp:47
static constexpr CK_TILE_HOST_DEVICE T max()
Definition: numeric.hpp:26
static constexpr CK_TILE_HOST_DEVICE T round_error()
Definition: numeric.hpp:32
static constexpr CK_TILE_HOST_DEVICE T zero()
Definition: numeric.hpp:58
static constexpr CK_TILE_HOST_DEVICE T denorm_min()
Definition: numeric.hpp:53
static constexpr CK_TILE_HOST_DEVICE T epsilon()
Definition: numeric.hpp:29
static constexpr CK_TILE_HOST_DEVICE T infinity()
Definition: numeric.hpp:38
Definition: random.hpp:17
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition: numeric.hpp:106