11 #ifdef CK_USE_FNUZ_FP8
12 #define CK_USE_FNUZ_FP8 1
14 #define CK_USE_FNUZ_FP8 0
18 #define CK_USE_OCP_FP8 1
20 #define CK_USE_OCP_FP8 0
23 #if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
24 defined(__gfx1201__) || defined(__gfx950__)) && \
25 __HIP_DEVICE_COMPILE__
26 #define CK_FP8_CVT_FAST_PATH 1
28 #define CK_FP8_CVT_FAST_PATH 0
31 #if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
32 #define CK_OCP_FP8_CVT_FAST_PATH 1
34 #define CK_OCP_FP8_CVT_FAST_PATH 0
67 typedef float float2_t __attribute__((ext_vector_type(2)));
69 __host__ __device__
static inline constexpr
bool fnuz_f8_is_nan(
f8_fnuz_t a)
71 return static_cast<unsigned char>(a) == 0x80;
73 __host__ __device__
static inline constexpr
bool fnuz_bf8_is_nan(
bf8_fnuz_t a)
75 return static_cast<unsigned char>(a) == 0x80;
78 __host__ __device__
static inline constexpr
bool ocp_f8_is_nan(
fp8_storage_t a)
80 return (a & 0x7f) == 0x7f;
82 __host__ __device__
static inline constexpr
bool ocp_bf8_is_nan(
fp8_storage_t a)
84 return (a & 0x7f) > 0x7c;
90 template <
typename T,
int wm,
int we,
bool is_fnuz,
bool clip = false>
91 __host__ __device__
static inline T cast_from_f8(
fp8_storage_t x)
93 constexpr
bool is_half = __hip_internal::is_same<T, _Float16>::value;
94 constexpr
bool is_float = __hip_internal::is_same<T, float>::value;
95 constexpr
bool is_double = __hip_internal::is_same<T, double>::value;
96 static_assert(is_half || is_float || is_double,
"only half, float and double are supported");
98 constexpr
int weo = is_half ? 5 : (is_float ? 8 : 11);
99 constexpr
int wmo = is_half ? 10 : (is_float ? 23 : 52);
101 T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
102 if constexpr(is_half)
104 const unsigned short int ihInf = 0x7C00;
105 const unsigned short int ihNegInf = 0xFC00;
106 const unsigned short int ihNaN = 0x7C01;
107 const unsigned short int ihNeg0 = 0x8000;
109 const unsigned short int ifmax = 0x7B00;
110 const unsigned short int ifmin = 0xFB00;
112 fInf = bit_cast<_Float16>(ihInf);
113 fNegInf = bit_cast<_Float16>(ihNegInf);
114 fNaN = bit_cast<_Float16>(ihNaN);
115 fNeg0 = bit_cast<_Float16>(ihNeg0);
116 fmax = bit_cast<_Float16>(ifmax);
117 fmin = bit_cast<_Float16>(ifmin);
119 else if constexpr(is_float)
121 const unsigned int ifInf = 0x7F800000;
122 const unsigned int ifNegInf = 0xFF800000;
123 const unsigned int ifNaN = 0x7F800001;
124 const unsigned int ifNeg0 = 0x80000000;
126 const unsigned int ifmax = 0x47600000;
127 const unsigned int ifmin = 0xC7600000;
129 fInf = bit_cast<float>(ifInf);
130 fNegInf = bit_cast<float>(ifNegInf);
131 fNaN = bit_cast<float>(ifNaN);
132 fNeg0 = bit_cast<float>(ifNeg0);
133 fmax = bit_cast<float>(ifmax);
134 fmin = bit_cast<float>(ifmin);
136 else if constexpr(is_double)
138 const unsigned long long ifInf = 0x7FF0000000000000ull;
139 const unsigned long long ifNegInf = 0xFFF0000000000000ull;
140 const unsigned long long ifNaN = 0x7FF0000000000001ull;
141 const unsigned long long ifNeg0 = 0x8000000000000000ull;
143 const unsigned long long ifmax = 0x40EC000000000000ull;
144 const unsigned long long ifmin = 0xC0EC000000000000ull;
146 fInf = bit_cast<double>(ifInf);
147 fNegInf = bit_cast<double>(ifNegInf);
148 fNaN = bit_cast<double>(ifNaN);
149 fNeg0 = bit_cast<double>(ifNeg0);
150 fmax = bit_cast<double>(ifmax);
151 fmin = bit_cast<double>(ifmin);
159 unsigned long long sign = x >> 7;
160 unsigned long long mantissa = x & ((1 << wm) - 1);
161 int exponent = (x & 0x7F) >> wm;
162 if constexpr(is_fnuz)
175 if constexpr(we == 4)
177 if((x & 0x7F) == 0x7F)
182 else if((x & 0x7C) == 0x7C)
188 return sign ? fmin : fmax;
190 return sign ? fNegInf : fInf;
196 typename std::conditional<
199 typename std::conditional<
sizeof(T) == 4,
unsigned int,
unsigned long long>::type>::type
202 if constexpr(we == 5 && is_half && !is_fnuz)
205 return bit_cast<T>(retval);
208 const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);
213 #if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
215 int sh = 1 + __clz(mantissa) - (32 - wm);
217 int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
221 mantissa &= ((1ull << wm) - 1);
223 exponent += exp_low_cutoff - 1;
224 mantissa <<= wmo - wm;
229 mantissa |= 1 << wmo;
230 mantissa >>= 1 - exponent;
234 if constexpr(
sizeof(T) == 2)
235 retval = (sign << 15) | (exponent << 10) | mantissa;
236 else if constexpr(sizeof(T) == 4)
237 retval = (sign << 31) | (exponent << 23) | mantissa;
239 retval = (sign << 63) | (static_cast<
unsigned long long>(exponent) << 52) | mantissa;
244 #if CK_FP8_CVT_FAST_PATH
245 template <ck_fp8_
interpretation_t
interpret>
251 unsigned char i8val[4];
259 "Only FNUZ and OCP interpretations are supported");
264 return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
268 return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
272 template <ck_fp8_
interpretation_t
interpret>
275 const auto i16val = bit_cast<uint16_t>(v);
281 "Only FNUZ and OCP interpretations are supported");
286 return __builtin_amdgcn_cvt_pk_f32_fp8(i16val,
false);
290 return __builtin_amdgcn_cvt_pk_f32_bf8(i16val,
false);
306 static constexpr
unsigned int we = 4;
307 static constexpr
unsigned int wm = 3;
311 return (data == other.
data) && (fp8_impl::ocp_f8_is_nan(data) ==
false);
315 __host__ __device__
explicit operator float() const
317 __host__
explicit operator float() const
320 #if CK_OCP_FP8_CVT_FAST_PATH
321 return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
323 return fp8_impl::cast_from_f8<float, wm, we, false>(
329 __host__ __device__
explicit operator _Float16() const
331 __host__
explicit operator _Float16() const
334 #if CK_OCP_FP8_CVT_FAST_PATH
335 return static_cast<_Float16
>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
337 return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
352 static constexpr
unsigned int we = 5;
353 static constexpr
unsigned int wm = 2;
357 return (data == other.
data) && (fp8_impl::ocp_bf8_is_nan(data) ==
false);
361 __host__ __device__
explicit operator float() const
364 __host__
explicit operator float() const
367 #if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
368 return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
370 return fp8_impl::cast_from_f8<float, wm, we, false>(
376 __host__ __device__
explicit operator _Float16() const
378 __host__
explicit operator _Float16() const
381 #if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
382 return static_cast<_Float16
>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
384 return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
390 template <
typename T>
391 __host__ __device__
static inline constexpr
bool fp8_is_nan(T);
394 __host__ __device__
inline constexpr
bool fp8_is_nan(
f8_ocp_t a)
396 return fp8_impl::ocp_f8_is_nan(a.
data);
399 __host__ __device__
inline constexpr
bool fp8_is_nan(
bf8_ocp_t a)
401 return fp8_impl::ocp_bf8_is_nan(a.
data);
404 __host__ __device__
inline constexpr
bool fp8_is_nan(
f8_fnuz_t a)
406 return fp8_impl::fnuz_f8_is_nan(a);
409 __host__ __device__
inline constexpr
bool fp8_is_nan(
bf8_fnuz_t a)
411 return fp8_impl::fnuz_bf8_is_nan(a);
414 template <
typename T,
416 is_same_v<T, bf8_fnuz_t> || is_same_v<T, f8_fnuz_t>,
418 __host__ __device__
static inline constexpr
bool fp8_is_inf(T)
423 __host__ __device__
inline constexpr
bool fp8_is_inf(
bf8_ocp_t a)
425 return (a.
data & 0x7f) == 0x7c;
431 #define __assert_ocp_support(interp) \
433 if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \
434 interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \
436 __hip_assert(false && "type is unsupported by current target device"); \
439 #define __assert_fnuz_support(interp) \
441 if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \
442 interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \
444 __hip_assert(false && "type is unsupported by current target device"); \
448 __host__ __device__
static inline void
451 #if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
461 #if CK_FP8_CVT_FAST_PATH
464 template <ck_fp8_
interpretation_t
interpret,
bool saturate,
bool stochastic_rounding = false>
465 static __device__
fp8_storage_t cast_to_f8_from_f32(
float v,
unsigned int rng = 0)
472 unsigned char i8val[4];
475 unsigned int ival = 0;
478 if constexpr(saturate)
482 if((val.i32val & 0x7F800000) != 0x7F800000)
484 val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
489 if((val.i32val & 0x7F800000) != 0x7F800000)
491 val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
496 if((val.i32val & 0x7F800000) != 0x7F800000)
498 val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
503 if constexpr(stochastic_rounding)
507 ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
508 : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0);
510 i8data = val.i8val[0];
516 ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
false)
517 : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
522 i8data = val.i8val[0];
531 template <
typename T,
int wm,
int we,
bool is_fnuz,
bool clip = false,
bool stoch = false>
532 __host__ __device__
static inline fp8_storage_t cast_to_f8(T _x,
unsigned int rng = 0)
534 constexpr
bool is_half = __hip_internal::is_same<T, _Float16>::value;
535 constexpr
bool is_float = __hip_internal::is_same<T, float>::value;
536 constexpr
bool is_double = __hip_internal::is_same<T, double>::value;
537 static_assert(is_half || is_float || is_double,
538 "Only half, float and double can be cast to f8");
540 constexpr
int mfmt = (
sizeof(T) == 8) ? 52 : ((
sizeof(T) == 4) ? 23 : 10);
542 using T_bitwise =
typename std::conditional<
545 typename std::conditional<
sizeof(T) == 4,
unsigned int,
unsigned long long>::type>::type;
546 T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
548 unsigned long long x{x_bitwise};
550 unsigned long long head, mantissa;
553 unsigned long long fInf, mask;
555 if constexpr(
sizeof(T) == 8)
557 head = x & 0xFFF0000000000000ull;
558 mantissa = x & 0xFFFFFFFFFFFFFull;
559 exponent = (head >> 52) & 0x7FF;
562 fInf = 0x7FF0000000000000ull;
563 mask = 0x7FFFFFFFFFFFFFFFull;
565 else if constexpr(
sizeof(T) == 4)
567 head = x & 0xFF800000;
568 mantissa = x & 0x7FFFFF;
569 exponent = (head >> 23) & 0xFF;
578 mantissa = x & 0x3FF;
579 exponent = (head >> 10) & 0x1F;
585 unsigned int signed_inf = 0;
586 unsigned int nan = 0;
587 if constexpr(is_fnuz)
589 signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
594 if constexpr(we == 4)
596 signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
600 signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
602 nan = (sign << 7) + 0x7f;
605 unsigned long long ifmax = 0;
606 if constexpr(
sizeof(T) == 8)
608 if constexpr(we == 5)
610 ifmax = 0x40EC000000000000ull;
614 if constexpr(is_fnuz)
616 ifmax = 0x406E000000000000ull;
620 ifmax = 0x407C000000000000ull;
624 else if(
sizeof(T) == 4)
626 if constexpr(we == 5)
632 if constexpr(is_fnuz)
644 if constexpr(we == 5)
650 if constexpr(is_fnuz)
661 if((x & fInf) == fInf)
663 if constexpr(is_fnuz)
666 return mantissa != 0 ? nan : signed_inf;
669 if((x & mask) > ifmax)
687 const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
688 const int f8_denormal_act_exponent = 1 - f8_bias;
693 int act_exponent, f8_exponent, exponent_diff;
704 act_exponent = exponent - bias + 1;
705 exponent_diff = f8_denormal_act_exponent -
710 act_exponent = exponent - bias;
711 if(act_exponent <= f8_denormal_act_exponent)
718 exponent_diff = f8_denormal_act_exponent - act_exponent;
726 mantissa += (1ull << mfmt);
729 bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
730 (1ull << (mfmt - wm + exponent_diff - 1));
738 if(exponent_diff > 0)
739 mantissa >>= exponent_diff;
740 else if(exponent_diff == -1)
741 mantissa <<= -exponent_diff;
742 bool implicit_one = mantissa & (1ull << mfmt);
746 (act_exponent + exponent_diff) + f8_bias - (implicit_one ? 0 : 1);
749 unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
751 mantissa & (1ull << (mfmt - wm));
753 (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
758 if((1ull << mfmt) & mantissa)
765 if((1ull << (mfmt + 1)) & mantissa)
772 mantissa >>= (mfmt - wm);
775 const int max_exp = (1 << we) - 1;
776 if(f8_exponent > max_exp)
780 mantissa = (1 << wm) - 1;
781 f8_exponent = max_exp;
789 if(f8_exponent == 0 && mantissa == 0)
790 return is_fnuz ? 0 : (sign << 7);
791 mantissa &= (1 << wm) - 1;
792 return (sign << 7) | (f8_exponent << wm) | mantissa;
805 bool stochastic_rounding =
false>
806 #if CK_FP8_CVT_FAST_PATH
807 __host__ __device__
static inline fp8_storage_t cvt_float_to_fp8(
const float f)
809 __is_interpret_supported(interp);
811 if constexpr(stochastic_rounding)
813 constexpr
int seed = 1254739;
814 #ifndef CK_CODE_GEN_RTC
815 rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t
>(&f), f);
817 rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&f), f);
820 return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
824 __host__ __device__
static inline fp8_storage_t cvt_float_to_fp8(
const float f)
827 __host__
static inline fp8_storage_t cvt_float_to_fp8(
const float f)
831 if constexpr(stochastic_rounding)
833 constexpr
int seed = 1254739;
834 #ifndef CK_CODE_GEN_RTC
835 rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t
>(&f), f);
837 rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&f), f);
843 return cast_to_f8<float,
848 stochastic_rounding>(f, rng);
852 return cast_to_f8<float,
857 stochastic_rounding>(f, rng);
861 return cast_to_f8<float,
866 stochastic_rounding>(f, rng);
870 return cast_to_f8<float,
875 stochastic_rounding>(f, rng);
879 __hip_assert(
false &&
"FP8 type is not supported by current target device");
896 bool stochastic_rounding =
false>
897 #if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
898 __host__ __device__
static inline fp8_storage_t cvt_half_t_to_fp8(
const _Float16 x)
900 __host__
static inline fp8_storage_t cvt_half_t_to_fp8(
const _Float16 x)
903 return cvt_float_to_fp8<interp, sat, stochastic_rounding>(
static_cast<float>(x));
909 template <
typename Y,
typename X>
917 fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
925 fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(x)};
933 fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
940 fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
945 template <
typename Y,
typename X>
953 fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>(
985 using f8_t = f8_ocp_t;
986 using bf8_t = bf8_ocp_t;
987 #define CK_FP8_TYPE_FNUZ 0
988 #define CK_FP8_TYPE_OCP 1
992 #define CK_FP8_TYPE_FNUZ 1
993 #define CK_FP8_TYPE_OCP 0
#define __assert_ocp_support(interp)
Definition: amd_ck_fp8.hpp:431
#define __assert_fnuz_support(interp)
Definition: amd_ck_fp8.hpp:439
float float2_t
Definition: amd_ck_fp8.hpp:67
fp8_storage_t fp8x2_storage_t
Definition: amd_ck_fp8.hpp:66
__host__ constexpr __device__ Y bit_cast(const X &x)
Definition: type.hpp:309
__host__ __device__ bf8_ocp_t f8_convert_sr< bf8_ocp_t, _Float16 >(_Float16 x)
Definition: amd_ck_fp8.hpp:977
__host__ __device__ bf8_ocp_t f8_convert_rne< bf8_ocp_t, _Float16 >(_Float16 x)
Definition: amd_ck_fp8.hpp:937
bf8_fnuz_t bf8_t
Definition: amd_ck_fp8.hpp:991
__host__ __device__ f8_ocp_t f8_convert_rne< f8_ocp_t, float >(float x)
Definition: amd_ck_fp8.hpp:914
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:990
ck_fp8_interpretation_t
Describes FP8 interpretation.
Definition: amd_ck_fp8.hpp:48
__host__ __device__ bf8_ocp_t f8_convert_sr< bf8_ocp_t, float >(float x)
Definition: amd_ck_fp8.hpp:959
__host__ constexpr __device__ Y f8_convert_rne(X x)
__host__ __device__ f8_ocp_t f8_convert_rne< f8_ocp_t, _Float16 >(_Float16 x)
Definition: amd_ck_fp8.hpp:930
unsigned _BitInt(8) bf8_fnuz_t
Definition: amd_ck_fp8.hpp:40
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:13
__host__ __device__ f8_ocp_t f8_convert_sr< f8_ocp_t, _Float16 >(_Float16 x)
Definition: amd_ck_fp8.hpp:968
_BitInt(8) f8_fnuz_t
Definition: amd_ck_fp8.hpp:39
__host__ __device__ bf8_ocp_t f8_convert_rne< bf8_ocp_t, float >(float x)
Definition: amd_ck_fp8.hpp:922
ck_saturation_t
Describes saturation behavior.
Definition: amd_ck_fp8.hpp:59
__host__ __device__ f8_ocp_t f8_convert_sr< f8_ocp_t, float >(float x)
Definition: amd_ck_fp8.hpp:950
__host__ constexpr __device__ Y f8_convert_sr(X x)
unsigned char fp8_storage_t
Definition: amd_ck_fp8.hpp:42
Definition: amd_ck_fp8.hpp:344
__host__ constexpr __device__ bool operator==(const bf8_ocp_t &other) const
Definition: amd_ck_fp8.hpp:355
fp8_storage_t data_type
Definition: amd_ck_fp8.hpp:345
data_type data
Definition: amd_ck_fp8.hpp:346
static constexpr ck_fp8_interpretation_t default_interpret
Definition: amd_ck_fp8.hpp:349
static constexpr ck_saturation_t default_saturation
Definition: amd_ck_fp8.hpp:348
Definition: amd_ck_fp8.hpp:298
fp8_storage_t data_type
Definition: amd_ck_fp8.hpp:299
data_type data
Definition: amd_ck_fp8.hpp:300
__host__ constexpr __device__ bool operator==(const f8_ocp_t &other) const
Definition: amd_ck_fp8.hpp:309
static constexpr ck_fp8_interpretation_t default_interpret
Definition: amd_ck_fp8.hpp:303
static constexpr ck_saturation_t default_saturation
Definition: amd_ck_fp8.hpp:302