13 #include <type_traits>
17 #if(defined(__gfx94__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
18 #define CK_TILE_FP8_CVT_DEVICE 1
20 #define CK_TILE_FP8_CVT_DEVICE 0
70 #if CK_TILE_USE_CUSTOM_DATA_TYPE
71 struct alignas(1) float8_e4m3_t
73 static constexpr
int exponent = 4;
74 static constexpr
int mantissa = 3;
75 #if CK_TILE_USE_OCP_FP8
76 static constexpr
int bias = 7;
78 static constexpr
int bias = 8;
80 using raw_type = uint8_t;
84 static constexpr float8_e4m3_t
bit_cast(raw_type x)
92 constexpr float8_e4m3_t() : data() {}
96 explicit constexpr float8_e4m3_t(
const float& x) : data(
float_to_fp8_raw(x)) {}
100 explicit constexpr float8_e4m3_t(
const int& x) : data(
float_to_fp8_raw(static_cast<float>(x)))
106 explicit constexpr float8_e4m3_t(
const unsigned int& x)
113 explicit constexpr
operator float()
const {
return fp8_to_float_raw(data); }
117 explicit constexpr
operator int()
const {
return static_cast<int>(
fp8_to_float_raw(data)); }
121 constexpr raw_type& get() {
return data; }
124 constexpr raw_type get()
const {
return data; }
126 using fp8_t = float8_e4m3_t;
127 using fp8_raw_t =
typename fp8_t::raw_type;
129 struct alignas(1) float8_e5m2_t
131 static constexpr
int exponent = 5;
132 static constexpr
int mantissa = 2;
133 #if CK_TILE_USE_OCP_FP8
134 static constexpr
int bias = 15;
136 static constexpr
int bias = 16;
138 using raw_type = uint8_t;
142 static constexpr float8_e5m2_t
bit_cast(raw_type x)
150 constexpr float8_e5m2_t() : data() {}
154 explicit constexpr float8_e5m2_t(
const float& x) : data(
float_to_bf8_raw(x)) {}
158 explicit constexpr float8_e5m2_t(
const int& x) : data(
float_to_bf8_raw(static_cast<float>(x)))
164 explicit constexpr float8_e5m2_t(
const unsigned int& x)
171 explicit constexpr
operator float()
const {
return bf8_to_float_raw(data); }
175 explicit constexpr
operator int()
const {
return static_cast<int>(
bf8_to_float_raw(data)); }
179 constexpr raw_type& get() {
return data; }
182 constexpr raw_type get()
const {
return data; }
184 using bf8_t = float8_e5m2_t;
185 using bf8_raw_t =
typename bf8_t::raw_type;
191 struct native_t<
fp8_t>
193 using type = _BitInt(8);
197 struct native_t<
bf8_t>
199 using type =
unsigned _BitInt(8);
210 template <
typename T>
218 static constexpr
int exp = 4;
219 static constexpr
int mant = 3;
220 #if CK_TILE_USE_OCP_FP8
221 static constexpr
int bias = 7;
224 static constexpr
int bias = 8;
227 static constexpr uint8_t abs_mask = 0x7F;
235 static constexpr
int exp = 5;
236 static constexpr
int mant = 2;
237 #if CK_TILE_USE_OCP_FP8
238 static constexpr
int bias = 15;
241 static constexpr
int bias = 16;
244 static constexpr uint8_t abs_mask = 0x7F;
250 template <
typename SrcT,
typename DstT,
bool clip = true,
bool stoch = false>
253 static_assert(std::is_same<DstT, fp8_t>::value || std::is_same<DstT, bf8_t>::value,
254 "DstT type must be fp8 or bf8.");
256 constexpr
bool is_half = std::is_same<SrcT, half_t>::value;
257 constexpr
bool is_float = std::is_same<SrcT, float>::value;
258 static_assert(is_half || is_float,
"Only half and float can be cast to f8");
263 constexpr
bool is_fnuz =
271 SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
273 unsigned long long head, mantissa;
276 unsigned long long fInf, abs_mask;
281 sign = head >> (SrcT_exp + SrcT_mant);
286 unsigned int signed_inf = 0;
287 unsigned int nan = 0;
288 if constexpr(is_fnuz)
290 signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
295 if constexpr(DstT_exp == 4)
297 signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
301 signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
303 nan = (sign << 7) + 0x7f;
306 unsigned long long ifmax = 0;
307 if constexpr(is_float)
309 if constexpr(DstT_exp == 5)
315 if constexpr(is_fnuz)
325 else if constexpr(is_half)
327 if constexpr(DstT_exp == 5)
333 if constexpr(is_fnuz)
345 if((src_bitwise & fInf) == fInf)
347 if constexpr(is_fnuz)
350 return mantissa != 0 ? nan : signed_inf;
353 if((src_bitwise & abs_mask) > ifmax)
371 const int f8_bias = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0);
372 const int f8_denormal_act_exponent = 1 - f8_bias;
377 int act_exponent, f8_exponent, exponent_diff;
388 act_exponent = exponent - bias + 1;
389 exponent_diff = f8_denormal_act_exponent -
394 act_exponent = exponent - bias;
395 if(act_exponent <= f8_denormal_act_exponent)
402 exponent_diff = f8_denormal_act_exponent - act_exponent;
410 mantissa += (1ull << SrcT_mant);
413 bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
414 (1ull << (SrcT_mant - DstT_mant + exponent_diff - 1));
422 if(exponent_diff > 0)
423 mantissa >>= exponent_diff;
424 else if(exponent_diff == -1)
425 mantissa <<= -exponent_diff;
426 bool implicit_one = mantissa & (1ull << SrcT_mant);
430 (act_exponent + exponent_diff) + f8_bias - (implicit_one ? 0 : 1);
433 unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1;
435 mantissa & (1ull << (SrcT_mant -
438 (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
443 if((1ull << SrcT_mant) & mantissa)
450 if((1ull << (SrcT_mant + 1)) & mantissa)
457 mantissa >>= (SrcT_mant - DstT_mant);
460 const int max_exp = (1 << DstT_exp) - 1;
461 if(f8_exponent > max_exp)
465 mantissa = (1 << DstT_mant) - 1;
466 f8_exponent = max_exp;
474 if(f8_exponent == 0 && mantissa == 0)
475 return is_fnuz ? 0 : (sign << 7);
476 mantissa &= (1 << DstT_mant) - 1;
477 return (sign << 7) | (f8_exponent << DstT_mant) | mantissa;
480 template <
typename SrcT,
typename DstT,
bool clip = true>
483 static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
484 "SrcT type must be fp8 or bf8.");
487 constexpr
bool is_fnuz =
491 constexpr
bool is_half = std::is_same<DstT, half_t>::value;
492 constexpr
bool is_float = std::is_same<DstT, float>::value;
493 static_assert(is_half || is_float,
"DstT type must be half_t or float.");
504 DstT fmax{0}, fmin{0};
506 if constexpr(is_half)
511 else if constexpr(is_float)
522 unsigned long long sign = x >> 7;
523 unsigned long long mantissa = x & ((1 << SrcT_mant) - 1);
524 int exponent = (x & 0x7F) >> SrcT_mant;
525 if constexpr(is_fnuz)
538 if constexpr(SrcT_exp == 4)
540 if((x & 0x7F) == 0x7F)
545 else if((x & 0x7C) == 0x7C)
551 return sign ? fmin : fmax;
553 return sign ? fNegInf : fInf;
561 if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
564 return bit_cast<DstT>(retval);
567 const int exp_low_cutoff =
568 (1 << (DstT_exp - 1)) - (1 << (SrcT_exp - 1)) + 1 - (is_fnuz ? 1 : 0);
573 int sh = 1 +
clz(mantissa) - (32 - SrcT_mant);
576 mantissa &= ((1ull << SrcT_mant) - 1);
578 exponent += exp_low_cutoff - 1;
579 mantissa <<= DstT_mant - SrcT_mant;
584 mantissa |= 1 << DstT_mant;
585 mantissa >>= 1 - exponent;
589 retval = (sign << (DstT_exp + DstT_mant)) | (exponent << DstT_mant) | mantissa;
591 return bit_cast<DstT>(retval);
594 template <
typename X,
typename Y,
bool clip,
bool stoch>
597 return bit_cast<Y>(run_cast_to_f8<X, Y, clip, stoch>(x, rng));
600 #if CK_TILE_FP8_CVT_DEVICE
604 template <fp8_
interpretation
interpret,
bool saturate,
bool stochastic_rounding = false>
605 CK_TILE_DEVICE uint8_t cast_to_f8_from_f32(
float v,
unsigned int rng = 0)
612 unsigned char i8val[4];
615 unsigned int ival = 0;
618 if constexpr(saturate)
622 if((val.i32val & 0x7F800000) != 0x7F800000)
624 val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
629 if((val.i32val & 0x7F800000) != 0x7F800000)
631 val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
636 if((val.i32val & 0x7F800000) != 0x7F800000)
638 val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
643 if constexpr(stochastic_rounding)
647 ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
648 : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0);
650 i8data = val.i8val[0];
656 ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
false)
657 : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
662 i8data = val.i8val[0];
683 template <
typename SrcT,
typename DstT>
686 constexpr
bool clip =
true;
687 constexpr
int seed = 42;
689 #if CK_TILE_FP8_CVT_DEVICE
690 return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip,
true>(x, rng);
692 return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
693 impl::cast_to_f8<SrcT, DstT, clip, true>(x, rng));
709 template <
typename SrcT,
typename DstT>
712 constexpr
bool clip =
true;
713 #if CK_TILE_FP8_CVT_DEVICE
714 return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip,
false>(x, 0);
716 return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
717 impl::cast_to_f8<SrcT, DstT, clip, false>(x, 0));
721 template <fp8_rounding_mode rounding>
726 return float_to_fp8_rtn_raw<float, fp8_t>(x);
730 return float_to_fp8_sr_raw<float, fp8_t>(x);
738 template <fp8_rounding_mode rounding>
743 return float_to_fp8_rtn_raw<float, bf8_t>(x);
747 return float_to_fp8_sr_raw<float, bf8_t>(x);
757 #if CK_TILE_FP8_CVT_DEVICE
759 uint32_t i32val =
static_cast<uint32_t
>(x);
760 fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
764 return impl::run_cast_from_f8<fp8_t, float>(bit_cast<fp8_t>(x));
770 #if CK_TILE_FP8_CVT_DEVICE
772 uint32_t i32val =
static_cast<uint32_t
>(x);
773 fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
777 return impl::run_cast_from_f8<bf8_t, float>(bit_cast<bf8_t>(x));
800 #if CK_TILE_USE_OCP_FP8
802 struct numeric<
fp8_t>
807 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x08));
813 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0xfe));
819 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x7e));
826 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x20));
833 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x18));
839 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x7F));
845 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0xFF));
851 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x01));
856 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0));
861 struct numeric<
bf8_t>
866 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x04));
872 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0xfb));
878 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x7b));
884 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x34));
891 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x30));
897 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x7c));
903 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x7F));
909 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0xFF));
915 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x01));
920 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0));
930 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x08));
936 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0xff));
942 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x7f));
948 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x20));
958 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x30));
964 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x80));
970 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x80));
976 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x80));
982 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x01));
987 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0));
997 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x04));
1003 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0xff));
1009 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x7f));
1015 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x34));
1025 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x38));
1031 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x80));
1037 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x80));
1043 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x80));
1049 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x01));
1054 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0));
1059 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1065 template <
typename T>
1068 static_assert(std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>,
1069 "Only fp8_t and bf8_t are supported");
1076 uint8_t xx = bit_cast<fp8_raw_t>(x);
1078 #if CK_TILE_USE_OCP_FP8
1079 return (xx & 0x7f) == 0x7f;
1084 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1086 fp8_t sqrt(
fp8_t x) {
return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(
static_cast<float>(x))); };
1089 fp8_t exp(
fp8_t x) {
return static_cast<fp8_t>(__ocml_exp_f32(
static_cast<float>(x))); };
1101 uint8_t xx = bit_cast<bf8_raw_t>(x);
1103 #if CK_TILE_USE_OCP_FP8
1104 return (xx & 0x7f) > 0x7c;
1110 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1112 bf8_t sqrt(
bf8_t x) {
return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(
static_cast<float>(x))); };
1115 bf8_t exp(
bf8_t x) {
return static_cast<bf8_t>(__ocml_exp_f32(
static_cast<float>(x))); };
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_FLOAT_TO_FP8_DEFAULT
Definition: config.hpp:77
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng=0)
Definition: float8.hpp:251
CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
Definition: float8.hpp:481
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
Definition: float8.hpp:595
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:423
fp8_interpretation
FP8 interpretation used in conversion algorithms.
Definition: float8.hpp:38
_BitInt(8) fp8_t
Definition: float8.hpp:204
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant< rounding >={})
Definition: float8.hpp:782
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t)
Definition: float8.hpp:755
constexpr CK_TILE_HOST_DEVICE Y bit_cast(const X &x)
Definition: bit_cast.hpp:11
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t)
Definition: float8.hpp:768
fp8_rounding_mode
Definition: float8.hpp:29
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:408
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant< rounding >={})
Definition: float8.hpp:722
uint8_t fp8_raw_t
Definition: float8.hpp:205
CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x)
Definition: float8.hpp:795
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_sr_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with stochastic rounding.
Definition: float8.hpp:684
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:414
CK_TILE_HOST int clz(uint32_t x)
Definition: math.hpp:264
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:395
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
uint8_t bf8_raw_t
Definition: float8.hpp:207
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant< rounding >={})
Definition: float8.hpp:788
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition: bfloat16.hpp:401
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_rtn_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with rounding to nearest ev...
Definition: float8.hpp:710
CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x)
Definition: float8.hpp:793
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant< rounding >={})
Definition: float8.hpp:739
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:420
Definition: integral_constant.hpp:13
remove_cvref_t< T > type
Definition: vector_type.hpp:25
static constexpr CK_TILE_HOST_DEVICE bf8_t min()
Definition: float8.hpp:995
static constexpr CK_TILE_HOST_DEVICE bf8_t quiet_NaN()
Definition: float8.hpp:1035
static constexpr CK_TILE_HOST_DEVICE bf8_t lowest()
Definition: float8.hpp:1001
static constexpr CK_TILE_HOST_DEVICE bf8_t round_error()
Definition: float8.hpp:1023
static constexpr CK_TILE_HOST_DEVICE bf8_t signaling_NaN()
Definition: float8.hpp:1041
static constexpr CK_TILE_HOST_DEVICE bf8_t denorm_min()
Definition: float8.hpp:1047
static constexpr CK_TILE_HOST_DEVICE bf8_t epsilon()
Definition: float8.hpp:1013
static constexpr CK_TILE_HOST_DEVICE bf8_t infinity()
Definition: float8.hpp:1029
static constexpr CK_TILE_HOST_DEVICE bf8_t max()
Definition: float8.hpp:1007
static constexpr CK_TILE_HOST_DEVICE bf8_t zero()
Definition: float8.hpp:1052
static constexpr CK_TILE_HOST_DEVICE fp8_t signaling_NaN()
Definition: float8.hpp:974
static constexpr CK_TILE_HOST_DEVICE fp8_t zero()
Definition: float8.hpp:985
static constexpr CK_TILE_HOST_DEVICE fp8_t min()
Definition: float8.hpp:928
static constexpr CK_TILE_HOST_DEVICE fp8_t lowest()
Definition: float8.hpp:934
static constexpr CK_TILE_HOST_DEVICE fp8_t epsilon()
Definition: float8.hpp:946
static constexpr CK_TILE_HOST_DEVICE fp8_t quiet_NaN()
Definition: float8.hpp:968
static constexpr CK_TILE_HOST_DEVICE fp8_t max()
Definition: float8.hpp:940
static constexpr CK_TILE_HOST_DEVICE fp8_t denorm_min()
Definition: float8.hpp:980
static constexpr CK_TILE_HOST_DEVICE fp8_t round_error()
Definition: float8.hpp:956
static constexpr CK_TILE_HOST_DEVICE fp8_t infinity()
Definition: float8.hpp:962
bf8_raw_t bitwise_type
Definition: float8.hpp:233
fp8_raw_t bitwise_type
Definition: float8.hpp:216
Definition: bfloat16.hpp:380
Definition: numeric.hpp:18
static constexpr CK_TILE_HOST_DEVICE T lowest()
Definition: numeric.hpp:23
static constexpr CK_TILE_HOST_DEVICE T min()
Definition: numeric.hpp:20
static constexpr CK_TILE_HOST_DEVICE T quiet_NaN()
Definition: numeric.hpp:41
static constexpr CK_TILE_HOST_DEVICE T signaling_NaN()
Definition: numeric.hpp:47
static constexpr CK_TILE_HOST_DEVICE T max()
Definition: numeric.hpp:26
static constexpr CK_TILE_HOST_DEVICE T round_error()
Definition: numeric.hpp:32
static constexpr CK_TILE_HOST_DEVICE T zero()
Definition: numeric.hpp:58
static constexpr CK_TILE_HOST_DEVICE T denorm_min()
Definition: numeric.hpp:53
static constexpr CK_TILE_HOST_DEVICE T epsilon()
Definition: numeric.hpp:29
static constexpr CK_TILE_HOST_DEVICE T infinity()
Definition: numeric.hpp:38
Definition: random.hpp:17
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition: numeric.hpp:102