13 #ifndef CK_USE_FNUZ_FP8
14 #define CK_USE_FNUZ_FP8 0
17 #ifndef CK_USE_OCP_FP8
18 #define CK_USE_OCP_FP8 0
21 #if(defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
22 #define CK_FP8_CVT_FAST_PATH 1
24 #define CK_FP8_CVT_FAST_PATH 0
27 #if(defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
28 #define CK_OCP_FP8_CVT_FAST_PATH 1
30 #define CK_OCP_FP8_CVT_FAST_PATH 0
40 __host__ __device__
explicit constexpr
f8_fnuz_t() =
default;
45 __host__ __device__
explicit constexpr
operator data_type()
const {
return m_data; }
53 __host__ __device__
explicit constexpr
bf8_fnuz_t() =
default;
58 __host__ __device__
explicit constexpr
operator data_type()
const {
return m_data; }
61 static_assert(1 ==
sizeof(f8_fnuz_t));
62 static_assert(1 ==
sizeof(bf8_fnuz_t));
89 typedef _Float16
half2_t __attribute__((ext_vector_type(2)));
90 typedef ushort
ushortx2_t __attribute__((ext_vector_type(2)));
91 typedef short shortx2_t __attribute__((ext_vector_type(2)));
92 typedef float float2_t __attribute__((ext_vector_type(2)));
94 __host__ __device__
static inline constexpr
bool fnuz_f8_is_nan(
f8_fnuz_t a)
96 return static_cast<unsigned char>(
a) == 0x80;
98 __host__ __device__
static inline constexpr
bool fnuz_bf8_is_nan(
bf8_fnuz_t a)
100 return static_cast<unsigned char>(
a) == 0x80;
103 __host__ __device__
static inline constexpr
bool ocp_f8_is_nan(
fp8_storage_t a)
105 return (
a & 0x7f) == 0x7f;
107 __host__ __device__
static inline constexpr
bool ocp_bf8_is_nan(
fp8_storage_t a)
109 return (
a & 0x7f) > 0x7c;
115 template <
typename T,
int wm,
int we,
bool is_fnuz,
bool clip = false>
116 __host__ __device__
static inline T cast_from_f8(
fp8_storage_t x)
121 static_assert(is_half || is_float || is_double,
"only half, float and double are supported");
123 constexpr
int weo = is_half ? 5 : (is_float ? 8 : 11);
124 constexpr
int wmo = is_half ? 10 : (is_float ? 23 : 52);
126 T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
127 if constexpr(is_half)
129 const unsigned short int ihInf = 0x7C00;
130 const unsigned short int ihNegInf = 0xFC00;
131 const unsigned short int ihNaN = 0x7C01;
132 const unsigned short int ihNeg0 = 0x8000;
134 const unsigned short int ifmax = 0x7B00;
135 const unsigned short int ifmin = 0xFB00;
137 fInf = bit_cast<_Float16>(ihInf);
138 fNegInf = bit_cast<_Float16>(ihNegInf);
139 fNaN = bit_cast<_Float16>(ihNaN);
140 fNeg0 = bit_cast<_Float16>(ihNeg0);
141 fmax = bit_cast<_Float16>(ifmax);
142 fmin = bit_cast<_Float16>(ifmin);
144 else if constexpr(is_float)
146 const unsigned int ifInf = 0x7F800000;
147 const unsigned int ifNegInf = 0xFF800000;
148 const unsigned int ifNaN = 0x7F800001;
149 const unsigned int ifNeg0 = 0x80000000;
151 const unsigned int ifmax = 0x47600000;
152 const unsigned int ifmin = 0xC7600000;
154 fInf = bit_cast<float>(ifInf);
155 fNegInf = bit_cast<float>(ifNegInf);
156 fNaN = bit_cast<float>(ifNaN);
157 fNeg0 = bit_cast<float>(ifNeg0);
158 fmax = bit_cast<float>(ifmax);
159 fmin = bit_cast<float>(ifmin);
161 else if constexpr(is_double)
163 const unsigned long long ifInf = 0x7FF0000000000000ull;
164 const unsigned long long ifNegInf = 0xFFF0000000000000ull;
165 const unsigned long long ifNaN = 0x7FF0000000000001ull;
166 const unsigned long long ifNeg0 = 0x8000000000000000ull;
168 const unsigned long long ifmax = 0x40EC000000000000ull;
169 const unsigned long long ifmin = 0xC0EC000000000000ull;
171 fInf = bit_cast<double>(ifInf);
172 fNegInf = bit_cast<double>(ifNegInf);
173 fNaN = bit_cast<double>(ifNaN);
174 fNeg0 = bit_cast<double>(ifNeg0);
175 fmax = bit_cast<double>(ifmax);
176 fmin = bit_cast<double>(ifmin);
184 unsigned long long sign = x >> 7;
185 unsigned long long mantissa = x & ((1 << wm) - 1);
186 int exponent = (x & 0x7F) >> wm;
187 if constexpr(is_fnuz)
200 if constexpr(we == 4)
202 if((x & 0x7F) == 0x7F)
207 else if((x & 0x7C) == 0x7C)
213 return sign ? fmin : fmax;
215 return sign ? fNegInf : fInf;
227 if constexpr(we == 5 && is_half && !is_fnuz)
230 return bit_cast<T>(retval);
233 const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);
238 #if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
240 int sh = 1 + __clz(mantissa) - (32 - wm);
242 int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
246 mantissa &= ((1ull << wm) - 1);
248 exponent += exp_low_cutoff - 1;
249 mantissa <<= wmo - wm;
254 mantissa |= 1 << wmo;
255 mantissa >>= 1 - exponent;
259 if constexpr(
sizeof(T) == 2)
260 retval = (sign << 15) | (exponent << 10) | mantissa;
261 else if constexpr(sizeof(T) == 4)
262 retval = (sign << 31) | (exponent << 23) | mantissa;
264 retval = (sign << 63) | (static_cast<
unsigned long long>(exponent) << 52) | mantissa;
269 #if CK_FP8_CVT_FAST_PATH
270 template <ck_fp8_
interpretation_t
interpret>
271 static __host__ __device__
float cast_to_f32_from_f8(
fp8_storage_t v)
276 unsigned char i8val[4];
284 "Only FNUZ and OCP interpretations are supported");
289 return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
293 return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
297 template <ck_fp8_
interpretation_t
interpret>
300 const auto i16val = bit_cast<uint16_t>(v);
306 "Only FNUZ and OCP interpretations are supported");
311 return __builtin_amdgcn_cvt_pk_f32_fp8(i16val,
false);
315 return __builtin_amdgcn_cvt_pk_f32_bf8(i16val,
false);
331 static constexpr
unsigned int we = 4;
332 static constexpr
unsigned int wm = 3;
336 return (data == other.
data) && (fp8_impl::ocp_f8_is_nan(data) ==
false);
340 __host__ __device__
explicit operator float() const
342 __host__
explicit operator float() const
345 #if CK_OCP_FP8_CVT_FAST_PATH
346 return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
348 return fp8_impl::cast_from_f8<float, wm, we, false>(
354 __host__ __device__
explicit operator _Float16() const
356 __host__
explicit operator _Float16() const
359 #if CK_OCP_FP8_CVT_FAST_PATH
360 return static_cast<_Float16
>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
362 return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
377 static constexpr
unsigned int we = 5;
378 static constexpr
unsigned int wm = 2;
382 return (data == other.
data) && (fp8_impl::ocp_bf8_is_nan(data) ==
false);
386 __host__ __device__
explicit operator float() const
389 __host__
explicit operator float() const
392 #if defined(__gfx950__) || defined(__gfx12__)
393 return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
395 return fp8_impl::cast_from_f8<float, wm, we, false>(
401 __host__ __device__
explicit operator _Float16() const
403 __host__
explicit operator _Float16() const
406 #if defined(__gfx950__) || defined(__gfx12__)
407 return static_cast<_Float16
>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
409 return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
415 template <
typename T>
416 __host__ __device__
static inline constexpr
bool fp8_is_nan(T);
419 __host__ __device__
inline constexpr
bool fp8_is_nan(
f8_ocp_t a)
421 return fp8_impl::ocp_f8_is_nan(
a.data);
424 __host__ __device__
inline constexpr
bool fp8_is_nan(
bf8_ocp_t a)
426 return fp8_impl::ocp_bf8_is_nan(
a.data);
429 __host__ __device__
inline constexpr
bool fp8_is_nan(
f8_fnuz_t a)
431 return fp8_impl::fnuz_f8_is_nan(
a);
434 __host__ __device__
inline constexpr
bool fp8_is_nan(
bf8_fnuz_t a)
436 return fp8_impl::fnuz_bf8_is_nan(
a);
439 template <
typename T,
441 is_same_v<T, bf8_fnuz_t> || is_same_v<T, f8_fnuz_t>,
443 __host__ __device__
static inline constexpr
bool fp8_is_inf(T)
448 __host__ __device__
inline constexpr
bool fp8_is_inf(
bf8_ocp_t a)
450 return (
a.data & 0x7f) == 0x7c;
456 #define __fp8_impl_assert_ocp_support(interp) \
458 if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \
459 interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \
461 __hip_assert(false && "type is unsupported by current target device"); \
464 #define __fp8_impl_assert_fnuz_support(interp) \
466 if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \
467 interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \
469 __hip_assert(false && "type is unsupported by current target device"); \
473 __host__ __device__
static inline void
476 #if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
486 #if defined(__gfx950__)
489 bool stochastic_rounding =
false,
492 static __device__
fp8_storage_t cast_to_f8_from_f16(_Float16 v,
unsigned int rng = 0)
501 constexpr
unsigned int i32val = 0;
504 if constexpr(saturate)
506 if((val.i32val & 0x7FFF) != 0x7FFF)
508 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
513 __builtin_amdgcn_cvt_scalef32_sr_fp8_f16(i32val, val.half_vec[0], rng, 1.f, 0);
520 bool stochastic_rounding =
false,
527 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
528 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
533 bool stochastic_rounding =
false,
536 static __device__
fp8_storage_t cast_to_f8_from_f16(_Float16 v,
unsigned int rng = 0)
545 constexpr
unsigned int i32val = 0;
548 if constexpr(saturate)
550 if((val.i32val & 0x7FFF) != 0x7FFF)
552 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
557 __builtin_amdgcn_cvt_scalef32_sr_bf8_f16(i32val, val.half_vec[0], rng, 1.f, 0);
564 bool stochastic_rounding =
false,
571 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
572 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
577 bool stochastic_rounding =
false,
580 static __device__
fp8_storage_t cast_to_f8_from_f16(_Float16 v,
unsigned int rng = 0)
595 if constexpr(saturate)
597 if((val.i32val & 0x7FFF) != 0x7FFF)
599 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
604 __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, 1.f, 0);
611 bool stochastic_rounding =
false,
616 #if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
618 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
619 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
633 if constexpr(saturate)
635 if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
637 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
639 if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
641 val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 448.0, -448.0);
646 __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, 1.f, 0);
654 bool stochastic_rounding =
false,
657 static __device__
fp8_storage_t cast_to_f8_from_f16(_Float16 v,
unsigned int rng = 0)
672 if constexpr(saturate)
674 if((val.i32val & 0x7FFF) != 0x7FFF)
676 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
681 __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, 1.f, 0);
688 bool stochastic_rounding =
false,
693 #if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
695 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
696 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
710 if constexpr(saturate)
712 if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
714 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
716 if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
718 val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 57344.0, -57344.0);
723 __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, 1.f, 0);
731 bool stochastic_rounding =
false,
734 static __device__
fp8_storage_t cast_to_f8_from_bf16(ushort v,
unsigned int rng = 0)
743 constexpr
unsigned int i32val = 0;
744 val.bhalf_vec[0] = v;
746 if constexpr(saturate)
748 if((val.i32val & 0x7FFF) != 0x7FFF)
751 ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
752 bit_cast<float>(
uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
757 val.i32val = __builtin_amdgcn_cvt_scalef32_sr_fp8_bf16(
758 i32val, val.bhalf_vec[0], rng, 1.f, 0);
765 bool stochastic_rounding =
false,
772 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
773 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
778 bool stochastic_rounding =
false,
781 static __device__
fp8_storage_t cast_to_f8_from_bf16(ushort v,
unsigned int rng = 0)
790 constexpr
unsigned int i32val = 0;
791 val.bhalf_vec[0] = v;
793 if constexpr(saturate)
795 if((val.i32val & 0x7FFF) != 0x7FFF)
797 val.bhalf_vec[0] = ushort(
798 (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
799 bit_cast<float>(
uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
804 val.i32val = __builtin_amdgcn_cvt_scalef32_sr_bf8_bf16(
805 i32val, val.bhalf_vec[0], rng, 1.f, 0);
812 bool stochastic_rounding =
false,
819 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
820 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
825 bool stochastic_rounding =
false,
828 static __device__
fp8_storage_t cast_to_f8_from_bf16(ushort v,
unsigned int rng = 0)
841 val.bhalf_vec[0] = v;
843 if constexpr(saturate)
845 if((val.i32val & 0x7FFF) != 0x7FFF)
848 ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
849 bit_cast<float>(
uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
855 __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, 1.f, 0);
862 bool stochastic_rounding =
false,
867 #if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
869 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
870 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
884 if constexpr(saturate)
886 if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
889 ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
890 bit_cast<float>(
uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
893 if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
896 ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
897 bit_cast<float>(
uint32_t{val.bhalf_vec[1]} << 16), 448.0, -448.0)) >>
903 __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, 1.f, 0);
911 bool stochastic_rounding =
false,
914 static __device__
fp8_storage_t cast_to_f8_from_bf16(ushort v,
unsigned int rng = 0)
927 val.bhalf_vec[0] = v;
929 if constexpr(saturate)
931 if((val.i32val & 0x7FFF) != 0x7FFF)
933 val.bhalf_vec[0] = ushort(
934 (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
935 bit_cast<float>(
uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
941 __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, 1.f, 0);
948 bool stochastic_rounding =
false,
965 if constexpr(saturate)
967 if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
969 val.bhalf_vec[0] = ushort(
970 (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
971 bit_cast<float>(
uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
974 if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
976 val.bhalf_vec[1] = ushort(
977 (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
978 bit_cast<float>(
uint32_t{val.bhalf_vec[1]} << 16), 57344.0, -57344.0)) >>
984 __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, 1.f, 0);
990 #if CK_FP8_CVT_FAST_PATH
993 template <ck_fp8_
interpretation_t
interpret,
bool saturate,
bool stochastic_rounding = false>
994 static __device__
fp8_storage_t cast_to_f8_from_f32(
float v,
unsigned int rng = 0)
1000 unsigned int i32val;
1001 unsigned char i8val[4];
1004 unsigned int ival = 0;
1007 if constexpr(saturate)
1011 if((val.i32val & 0x7F800000) != 0x7F800000)
1013 val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
1018 if((val.i32val & 0x7F800000) != 0x7F800000)
1020 val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
1025 if((val.i32val & 0x7F800000) != 0x7F800000)
1027 val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
1032 if constexpr(stochastic_rounding)
1036 ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
1037 : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0);
1039 i8data = val.i8val[0];
1045 ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
false)
1046 : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
1051 i8data = val.i8val[0];
1056 template <ck_fp8_
interpretation_t
interpret,
bool saturate,
bool stochastic_rounding = false>
1059 if constexpr(stochastic_rounding)
1063 cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[0], rng),
1064 cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[1], rng)};
1071 unsigned int i32val;
1072 unsigned char i8val[4];
1078 unsigned int ival = 0;
1080 if constexpr(saturate)
1084 if((val0.i32val & 0x7F800000) != 0x7F800000)
1086 val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 240.0, -240.0);
1088 if((val1.i32val & 0x7F800000) != 0x7F800000)
1090 val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 240.0, -240.0);
1095 if((val0.i32val & 0x7F800000) != 0x7F800000)
1097 val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 448.0, -448.0);
1099 if((val1.i32val & 0x7F800000) != 0x7F800000)
1101 val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 448.0, -448.0);
1106 if((val0.i32val & 0x7F800000) != 0x7F800000)
1108 val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 57344.0, -57344.0);
1110 if((val1.i32val & 0x7F800000) != 0x7F800000)
1112 val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 57344.0, -57344.0);
1121 ival = __builtin_amdgcn_cvt_pk_fp8_f32(val0.fval, val1.fval, ival,
false);
1125 ival = __builtin_amdgcn_cvt_pk_bf8_f32(val0.fval, val1.fval, ival,
false);
1138 template <
typename T,
int wm,
int we,
bool is_fnuz,
bool clip = false,
bool stoch = false>
1139 __host__ __device__
static inline fp8_storage_t cast_to_f8(T _x,
unsigned int rng = 0)
1144 static_assert(is_half || is_float || is_double,
1145 "Only half, float and double can be cast to f8");
1147 constexpr
int mfmt = (
sizeof(T) == 8) ? 52 : ((
sizeof(T) == 4) ? 23 : 10);
1153 T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
1155 unsigned long long x{x_bitwise};
1157 unsigned long long head, mantissa;
1160 unsigned long long fInf, mask;
1162 if constexpr(
sizeof(T) == 8)
1164 head = x & 0xFFF0000000000000ull;
1165 mantissa = x & 0xFFFFFFFFFFFFFull;
1166 exponent = (head >> 52) & 0x7FF;
1169 fInf = 0x7FF0000000000000ull;
1170 mask = 0x7FFFFFFFFFFFFFFFull;
1172 else if constexpr(
sizeof(T) == 4)
1174 head = x & 0xFF800000;
1175 mantissa = x & 0x7FFFFF;
1176 exponent = (head >> 23) & 0xFF;
1185 mantissa = x & 0x3FF;
1186 exponent = (head >> 10) & 0x1F;
1192 unsigned int signed_inf = 0;
1193 unsigned int nan = 0;
1194 if constexpr(is_fnuz)
1196 signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
1201 if constexpr(we == 4)
1203 signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
1207 signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
1209 nan = (sign << 7) + 0x7f;
1212 unsigned long long ifmax = 0;
1213 if constexpr(
sizeof(T) == 8)
1215 if constexpr(we == 5)
1217 ifmax = 0x40EC000000000000ull;
1221 if constexpr(is_fnuz)
1223 ifmax = 0x406E000000000000ull;
1227 ifmax = 0x407C000000000000ull;
1231 else if(
sizeof(T) == 4)
1233 if constexpr(we == 5)
1239 if constexpr(is_fnuz)
1251 if constexpr(we == 5)
1257 if constexpr(is_fnuz)
1268 if((x & fInf) == fInf)
1270 if constexpr(is_fnuz)
1273 return mantissa != 0 ? nan : signed_inf;
1276 if((x & mask) > ifmax)
1294 const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
1295 const int f8_denormal_act_exponent = 1 - f8_bias;
1300 int act_exponent, f8_exponent, exponent_diff;
1311 act_exponent = exponent - bias + 1;
1312 exponent_diff = f8_denormal_act_exponent -
1317 act_exponent = exponent - bias;
1318 if(act_exponent <= f8_denormal_act_exponent)
1325 exponent_diff = f8_denormal_act_exponent - act_exponent;
1333 mantissa += (1ull << mfmt);
1336 bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
1337 (1ull << (mfmt - wm + exponent_diff - 1));
1345 if(exponent_diff > 0)
1346 mantissa >>= exponent_diff;
1347 else if(exponent_diff == -1)
1348 mantissa <<= -exponent_diff;
1349 bool implicit_one = mantissa & (1ull << mfmt);
1353 (act_exponent + exponent_diff) + f8_bias - (implicit_one ? 0 : 1);
1356 unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
1358 mantissa & (1ull << (mfmt - wm));
1360 (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
1363 if(f8_exponent == 0)
1365 if((1ull << mfmt) & mantissa)
1372 if((1ull << (mfmt + 1)) & mantissa)
1379 mantissa >>= (mfmt - wm);
1382 const int max_exp = (1 << we) - 1;
1383 if(f8_exponent > max_exp)
1387 mantissa = (1 << wm) - 1;
1388 f8_exponent = max_exp;
1396 if(f8_exponent == 0 && mantissa == 0)
1397 return is_fnuz ? 0 : (sign << 7);
1398 mantissa &= (1 << wm) - 1;
1399 return (sign << 7) | (f8_exponent << wm) | mantissa;
1413 bool stochastic_rounding =
false>
1414 #if CK_FP8_CVT_FAST_PATH
1415 __host__ __device__
static inline fp8_storage_t cvt_float_to_fp8(
const float f)
1417 __is_interpret_supported(interp);
1419 if constexpr(stochastic_rounding)
1421 #if defined(__gfx950__)
1423 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1426 constexpr
int seed = 1254739;
1427 #ifndef CK_CODE_GEN_RTC
1428 rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t>(&f), f);
1430 rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&f), f);
1434 return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1438 __host__ __device__
static inline fp8_storage_t cvt_float_to_fp8(
const float f)
1441 __host__
static inline fp8_storage_t cvt_float_to_fp8(
const float f)
1445 if constexpr(stochastic_rounding)
1447 #if defined(__gfx950__)
1449 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1452 constexpr
int seed = 1254739;
1453 #ifndef CK_CODE_GEN_RTC
1454 rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t>(&f), f);
1456 rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&f), f);
1463 return cast_to_f8<float,
1468 stochastic_rounding>(f, rng);
1472 return cast_to_f8<float,
1477 stochastic_rounding>(f, rng);
1481 return cast_to_f8<float,
1486 stochastic_rounding>(f, rng);
1490 return cast_to_f8<float,
1495 stochastic_rounding>(f, rng);
1499 __hip_assert(
false &&
"FP8 type is not supported by current target device");
1516 bool stochastic_rounding =
false>
1517 #if CK_FP8_CVT_FAST_PATH
1520 __is_interpret_supported(interp);
1522 if constexpr(stochastic_rounding)
1524 #if defined(__gfx950__)
1526 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1529 constexpr
int seed = 1254739;
1530 #ifndef CK_CODE_GEN_RTC
1531 rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t>(&f), f[0]);
1533 rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&f), f[0]);
1537 return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1547 return fp8x2_storage_t{cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[0]),
1548 cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[1])};
1563 bool stochastic_rounding =
false>
1564 #if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1565 __host__ __device__
static inline fp8_storage_t cvt_half_t_to_fp8(
const _Float16 x)
1567 __host__
static inline fp8_storage_t cvt_half_t_to_fp8(
const _Float16 x)
1571 __is_interpret_supported(interp);
1573 if constexpr(stochastic_rounding)
1575 #if defined(__gfx950__)
1577 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1580 constexpr
int seed = 1254739;
1581 #ifndef CK_CODE_GEN_RTC
1582 rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t>(&x), x);
1584 rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&x), x);
1588 #if defined(__gfx950__)
1589 return cast_to_f8_from_f16<interp,
1591 stochastic_rounding>(x, rng);
1594 return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1595 static_cast<float>(x));
1611 bool stochastic_rounding =
false>
1612 #if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1619 __is_interpret_supported(interp);
1621 if constexpr(stochastic_rounding)
1623 #if defined(__gfx950__)
1625 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1628 constexpr
int seed = 1254739;
1629 #ifndef CK_CODE_GEN_RTC
1630 rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t>(&x), x[0]);
1632 rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&x), x[0]);
1636 #if defined(__gfx950__)
1637 return cast_to_f8_from_f16<interp,
1639 stochastic_rounding>(x, rng);
1642 return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1643 float2_t{
static_cast<float>(x[0]),
static_cast<float>(x[1])});
1659 bool stochastic_rounding =
false>
1660 #if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1661 __host__ __device__
static inline fp8_storage_t cvt_bhalf_t_to_fp8(
const ushort x)
1663 __host__
static inline fp8_storage_t cvt_bhalf_t_to_fp8(
const ushort x)
1667 __is_interpret_supported(interp);
1669 if constexpr(stochastic_rounding)
1671 #if defined(__gfx950__)
1673 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1676 constexpr
int seed = 1254739;
1677 #ifndef CK_CODE_GEN_RTC
1678 rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t>(&x),
1679 static_cast<float>(x));
1681 rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&x),
static_cast<float>(x));
1685 #if defined(__gfx950__)
1686 return cast_to_f8_from_bf16<interp,
1688 stochastic_rounding>(x, rng);
1691 return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1692 bit_cast<float>(
uint32_t{x} << 16));
1708 bool stochastic_rounding =
false>
1709 #if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1715 #if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
1716 return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1718 bit_cast<float>(
uint32_t{x[1]} << 16)});
1721 __is_interpret_supported(interp);
1723 if constexpr(stochastic_rounding)
1725 #if defined(__gfx950__)
1727 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1730 constexpr
int seed = 1254739;
1731 #ifndef CK_CODE_GEN_RTC
1732 rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t>(&x),
1733 static_cast<float>(x[0]));
1735 rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&x),
1736 static_cast<float>(x[0]));
1740 #if defined(__gfx950__)
1741 return cast_to_f8_from_bf16<interp,
1743 stochastic_rounding>(x, rng);
1746 return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1748 bit_cast<float>(
uint32_t{x[1]} << 16)});
1757 using f8_t = f8_ocp_t;
1758 using bf8_t = bf8_ocp_t;
1759 #define CK_FP8_TYPE_FNUZ 0
1760 #define CK_FP8_TYPE_OCP 1
1764 #define CK_FP8_TYPE_FNUZ 1
1765 #define CK_FP8_TYPE_OCP 0
#define __fp8_impl_assert_fnuz_support(interp)
Definition: amd_ck_fp8.hpp:464
#define __fp8_impl_assert_ocp_support(interp)
Definition: amd_ck_fp8.hpp:456
ushort ushortx2_t
Definition: amd_ck_fp8.hpp:90
short shortx2_t
Definition: amd_ck_fp8.hpp:91
float float2_t
Definition: amd_ck_fp8.hpp:92
fp8_storage_t fp8x2_storage_t
Definition: amd_ck_fp8.hpp:88
_Float16 half2_t
Definition: amd_ck_fp8.hpp:89
__host__ constexpr __device__ Y bit_cast(const X &x)
Definition: type.hpp:306
bf8_fnuz_t bf8_t
Definition: amd_ck_fp8.hpp:1763
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1762
ck_fp8_interpretation_t
Describes FP8 interpretation.
Definition: amd_ck_fp8.hpp:70
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
__device__ index_t get_thread_global_1d_id()
Definition: get_id.hpp:43
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
ck_saturation_t
Describes saturation behavior.
Definition: amd_ck_fp8.hpp:81
unsigned char fp8_storage_t
Definition: amd_ck_fp8.hpp:61
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
_W64 unsigned int uintptr_t
Definition: stdint.h:165
unsigned int uint32_t
Definition: stdint.h:126
Definition: amd_ck_fp8.hpp:49
__host__ constexpr __device__ bf8_fnuz_t(data_type in_data)
Definition: amd_ck_fp8.hpp:52
data_type m_data
Definition: amd_ck_fp8.hpp:51
unsigned char data_type
Definition: amd_ck_fp8.hpp:50
__host__ constexpr __device__ bf8_fnuz_t()=default
__host__ __device__ constexpr bool operator==(bf8_fnuz_t other) const
Definition: amd_ck_fp8.hpp:54
Definition: amd_ck_fp8.hpp:369
__host__ constexpr __device__ bool operator==(const bf8_ocp_t &other) const
Definition: amd_ck_fp8.hpp:380
fp8_storage_t data_type
Definition: amd_ck_fp8.hpp:370
data_type data
Definition: amd_ck_fp8.hpp:371
Definition: amd_ck_fp8.hpp:36
data_type m_data
Definition: amd_ck_fp8.hpp:38
__host__ constexpr __device__ f8_fnuz_t()=default
__host__ __device__ constexpr bool operator==(f8_fnuz_t other) const
Definition: amd_ck_fp8.hpp:41
unsigned char data_type
Definition: amd_ck_fp8.hpp:37
__host__ constexpr __device__ f8_fnuz_t(data_type in_data)
Definition: amd_ck_fp8.hpp:39
Definition: amd_ck_fp8.hpp:323
fp8_storage_t data_type
Definition: amd_ck_fp8.hpp:324
data_type data
Definition: amd_ck_fp8.hpp:325
__host__ constexpr __device__ bool operator==(const f8_ocp_t &other) const
Definition: amd_ck_fp8.hpp:334