#ifndef CK_USE_FNUZ_FP8
#define CK_USE_FNUZ_FP8 0
#endif

#ifndef CK_USE_OCP_FP8
#define CK_USE_OCP_FP8 0
#endif

#if(defined(__gfx942__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && \
    __HIP_DEVICE_COMPILE__
#define CK_FP8_CVT_FAST_PATH 1
#else
#define CK_FP8_CVT_FAST_PATH 0
#endif

#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
#define CK_OCP_FP8_CVT_FAST_PATH 1
#else
#define CK_OCP_FP8_CVT_FAST_PATH 0
#endif
// ...

typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2)));
typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
typedef ushort ushortx2_t __attribute__((ext_vector_type(2)));
typedef short shortx2_t __attribute__((ext_vector_type(2)));
typedef float float2_t __attribute__((ext_vector_type(2)));
__host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a)
{
    return static_cast<unsigned char>(a) == 0x80;
}

__host__ __device__ static inline constexpr bool fnuz_bf8_is_nan(bf8_fnuz_t a)
{
    return static_cast<unsigned char>(a) == 0x80;
}

__host__ __device__ static inline constexpr bool ocp_f8_is_nan(fp8_storage_t a)
{
    return (a & 0x7f) == 0x7f;
}

__host__ __device__ static inline constexpr bool ocp_bf8_is_nan(fp8_storage_t a)
{
    return (a & 0x7f) > 0x7c;
}
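// Illustrative examples (not part of the original header): the two FP8 flavors
// mark NaN differently. FNUZ reserves the single pattern 0x80 (the would-be
// negative zero) and has no infinities; OCP e4m3 uses 0x7f/0xff, and OCP e5m2
// treats any all-ones exponent with a nonzero mantissa as NaN. The sample
// bytes below are hand-picked assumptions:
//
//   static_assert(fnuz_f8_is_nan(f8_fnuz_t(0x80)));  // FNUZ NaN
//   static_assert(ocp_f8_is_nan(0x7f));              // e4m3 NaN: S.1111.111
//   static_assert(!ocp_f8_is_nan(0x7e));             // e4m3 max finite (448)
//   static_assert(ocp_bf8_is_nan(0x7d));             // e5m2 NaN payload
//   static_assert(!ocp_bf8_is_nan(0x7c));            // e5m2 +Inf, not NaN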
template <typename T, int wm, int we, bool is_fnuz, bool clip = false>
__host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
{
    constexpr bool is_half   = is_same_v<T, _Float16>;
    constexpr bool is_float  = is_same_v<T, float>;
    constexpr bool is_double = is_same_v<T, double>;
    static_assert(is_half || is_float || is_double, "only half, float and double are supported");

    constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
    constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);
    T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
    if constexpr(is_half)
    {
        const unsigned short int ihInf    = 0x7C00;
        const unsigned short int ihNegInf = 0xFC00;
        const unsigned short int ihNaN    = 0x7C01;
        const unsigned short int ihNeg0   = 0x8000;
        // +/-57344.0, the largest magnitude representable in e5m2
        const unsigned short int ifmax = 0x7B00;
        const unsigned short int ifmin = 0xFB00;

        fInf    = bit_cast<_Float16>(ihInf);
        fNegInf = bit_cast<_Float16>(ihNegInf);
        fNaN    = bit_cast<_Float16>(ihNaN);
        fNeg0   = bit_cast<_Float16>(ihNeg0);
        fmax    = bit_cast<_Float16>(ifmax);
        fmin    = bit_cast<_Float16>(ifmin);
    }
    else if constexpr(is_float)
    {
        const unsigned int ifInf    = 0x7F800000;
        const unsigned int ifNegInf = 0xFF800000;
        const unsigned int ifNaN    = 0x7F800001;
        const unsigned int ifNeg0   = 0x80000000;
        // +/-57344.0f
        const unsigned int ifmax = 0x47600000;
        const unsigned int ifmin = 0xC7600000;

        fInf    = bit_cast<float>(ifInf);
        fNegInf = bit_cast<float>(ifNegInf);
        fNaN    = bit_cast<float>(ifNaN);
        fNeg0   = bit_cast<float>(ifNeg0);
        fmax    = bit_cast<float>(ifmax);
        fmin    = bit_cast<float>(ifmin);
    }
    else if constexpr(is_double)
    {
        const unsigned long long ifInf    = 0x7FF0000000000000ull;
        const unsigned long long ifNegInf = 0xFFF0000000000000ull;
        const unsigned long long ifNaN    = 0x7FF0000000000001ull;
        const unsigned long long ifNeg0   = 0x8000000000000000ull;
        // +/-57344.0
        const unsigned long long ifmax = 0x40EC000000000000ull;
        const unsigned long long ifmin = 0xC0EC000000000000ull;

        fInf    = bit_cast<double>(ifInf);
        fNegInf = bit_cast<double>(ifNegInf);
        fNaN    = bit_cast<double>(ifNaN);
        fNeg0   = bit_cast<double>(ifNeg0);
        fmax    = bit_cast<double>(ifmax);
        fmin    = bit_cast<double>(ifmin);
    }
    unsigned long long sign     = x >> 7;
    unsigned long long mantissa = x & ((1 << wm) - 1);
    int exponent                = (x & 0x7F) >> wm;
    if constexpr(is_fnuz)
    {
        if(x == 0x80)
        {
            return fNaN; // FNUZ: 0x80 is the unique NaN; there are no infinities
        }
    }
    else
    {
        if(x == 0x80)
        {
            return fNeg0; // OCP: 0x80 is negative zero
        }

        if constexpr(we == 4)
        { // e4m3: all-ones exponent and mantissa encodes NaN
            if((x & 0x7F) == 0x7F)
            {
                return fNaN;
            }
        }
        else if((x & 0x7C) == 0x7C)
        { // e5m2: all-ones exponent encodes Inf (mantissa 0) or NaN
            if((x & 0x3) == 0)
            {
                if constexpr(clip)
                {
                    return sign ? fmin : fmax;
                }
                return sign ? fNegInf : fInf;
            }
            return fNaN;
        }
    }

    if(x == 0)
    {
        return 0;
    }
    // For bf8 -> half the e5m2 bit pattern is exactly the high byte of the
    // corresponding fp16 pattern, so the conversion is a plain shift.
    if constexpr(we == 5 && is_half && !is_fnuz)
    {
        const unsigned short int retval = x << 8;
        return bit_cast<T>(retval);
    }

    const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);

    // subnormal input: renormalize the mantissa into the wider format
    if(exponent == 0)
    {
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
        int sh = 1 + __clz(mantissa) - (32 - wm);
#else
        int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
#endif
        mantissa <<= sh;
        exponent += 1 - sh;
        mantissa &= ((1ull << wm) - 1);
    }
    exponent += exp_low_cutoff - 1;
    mantissa <<= wmo - wm;

    // subnormal output (can occur when converting e5m2 to half)
    if(exponent <= 0)
    {
        mantissa |= 1 << wmo;
        mantissa >>= 1 - exponent;
        exponent = 0;
    }
    // compose the output bit pattern
    using T_bitwise =
        conditional_t<sizeof(T) == 2,
                      unsigned short int,
                      conditional_t<sizeof(T) == 4, unsigned int, unsigned long long>>;
    T_bitwise retval;
    if constexpr(sizeof(T) == 2)
        retval = (sign << 15) | (exponent << 10) | mantissa;
    else if constexpr(sizeof(T) == 4)
        retval = (sign << 31) | (exponent << 23) | mantissa;
    else
        retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;
    return bit_cast<T>(retval);
}
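// Illustrative example (not part of the original header): decoding the OCP
// e4m3 byte 0x40 by hand. With wm = 3, we = 4 and bias 7 the fields are
// sign 0, biased exponent 8, mantissa 0, i.e. 2^(8-7) * 1.0 = 2.0:
//
//   float two = cast_from_f8<float, 3, 4, /*is_fnuz=*/false>(0x40);
//   // two == 2.0f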
#if CK_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interpret>
static __host__ __device__ float cast_to_f32_from_f8(fp8_storage_t v)
{
    union
    {
        unsigned int i32val;
        unsigned char i8val[4];
    } val;
    val.i8val[0] = v;

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only FNUZ and OCP interpretations are supported");

    if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                 (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
    {
        return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
    }
}
template <ck_fp8_interpretation_t interpret>
static __host__ __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
{
    const auto i16val = bit_cast<uint16_t>(v);

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only FNUZ and OCP interpretations are supported");

    if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                 (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
    {
        return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false);
    }
    else
    {
        return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
    }
}
#endif // CK_FP8_CVT_FAST_PATH
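// Illustrative example (not part of the original header): on targets where
// CK_FP8_CVT_FAST_PATH is 1 (e.g. gfx942), one packed instruction decodes a
// byte pair. The input bytes are assumptions:
//
//   fp8x2_storage_t pair{0x40, 0xC0};  // e4m3 encodings of +2.0 and -2.0
//   float2_t f =
//       cast_to_f32x2_from_f8x2<ck_fp8_interpretation_t::CK_E4M3_OCP>(pair);
//   // f[0] == 2.0f, f[1] == -2.0f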
struct f8_ocp_t
{
    using data_type = fp8_storage_t;
    data_type data;

    static constexpr ck_fp8_interpretation_t default_interpret =
        ck_fp8_interpretation_t::CK_E4M3_OCP;
    static constexpr unsigned int we = 4; // exponent width
    static constexpr unsigned int wm = 3; // mantissa width

    __host__ __device__ constexpr bool operator==(const f8_ocp_t& other) const
    {
        return (data == other.data) && (fp8_impl::ocp_f8_is_nan(data) == false); // NaN != NaN
    }

#if CK_OCP_FP8_CVT_FAST_PATH
    __host__ __device__ explicit operator float() const
#else
    __host__ explicit operator float() const
#endif
    {
#if CK_OCP_FP8_CVT_FAST_PATH
        return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
        return fp8_impl::cast_from_f8<float, wm, we, false>(this->data);
#endif
    }

#if CK_OCP_FP8_CVT_FAST_PATH
    __host__ __device__ explicit operator _Float16() const
#else
    __host__ explicit operator _Float16() const
#endif
    {
#if CK_OCP_FP8_CVT_FAST_PATH
        return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
        return fp8_impl::cast_from_f8<_Float16, wm, we, false>(this->data);
#endif
    }
};
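// Illustrative example (not part of the original header): operator== filters
// NaN payloads, so bit-identical NaNs still compare unequal, matching IEEE
// semantics. The bytes and the aggregate-initialization form are assumptions:
//
//   constexpr f8_ocp_t a{0x7f}, b{0x7f};  // both NaN
//   static_assert(!(a == b));
//   constexpr f8_ocp_t c{0x38};           // 1.0 in e4m3
//   static_assert(c == c);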
struct bf8_ocp_t
{
    using data_type = fp8_storage_t;
    data_type data;

    static constexpr ck_fp8_interpretation_t default_interpret =
        ck_fp8_interpretation_t::CK_E5M2_OCP;
    static constexpr unsigned int we = 5; // exponent width
    static constexpr unsigned int wm = 2; // mantissa width

    __host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const
    {
        return (data == other.data) && (fp8_impl::ocp_bf8_is_nan(data) == false);
    }

#if CK_OCP_FP8_CVT_FAST_PATH
    __host__ __device__ explicit operator float() const
#else
    __host__ explicit operator float() const
#endif
    {
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
        return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
        return fp8_impl::cast_from_f8<float, wm, we, false>(this->data);
#endif
    }

#if CK_OCP_FP8_CVT_FAST_PATH
    __host__ __device__ explicit operator _Float16() const
#else
    __host__ explicit operator _Float16() const
#endif
    {
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
        return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
        return fp8_impl::cast_from_f8<_Float16, wm, we, false>(this->data);
#endif
    }
};
template <typename T>
__host__ __device__ static inline constexpr bool fp8_is_nan(T);

template <>
__host__ __device__ inline constexpr bool fp8_is_nan(f8_ocp_t a)
{
    return fp8_impl::ocp_f8_is_nan(a.data);
}

template <>
__host__ __device__ inline constexpr bool fp8_is_nan(bf8_ocp_t a)
{
    return fp8_impl::ocp_bf8_is_nan(a.data);
}

template <>
__host__ __device__ inline constexpr bool fp8_is_nan(f8_fnuz_t a)
{
    return fp8_impl::fnuz_f8_is_nan(a);
}

template <>
__host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a)
{
    return fp8_impl::fnuz_bf8_is_nan(a);
}
template <typename T,
          enable_if_t<is_same_v<T, bf8_ocp_t> || is_same_v<T, f8_ocp_t> ||
                          is_same_v<T, bf8_fnuz_t> || is_same_v<T, f8_fnuz_t>,
                      bool> = true>
__host__ __device__ static inline constexpr bool fp8_is_inf(T)
{
    return false; // only OCP e5m2 has Inf encodings
}

template <>
__host__ __device__ inline constexpr bool fp8_is_inf(bf8_ocp_t a)
{
    return (a.data & 0x7f) == 0x7c;
}
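// Illustrative example (not part of the original header): OCP e4m3 spends the
// would-be Inf encodings on extra finite range, so only bf8_ocp_t (e5m2) can
// ever report infinity:
//
//   static_assert(fp8_is_inf(bf8_ocp_t{0x7c}));  // e5m2 +Inf
//   static_assert(!fp8_is_inf(f8_ocp_t{0x7f}));  // e4m3: 0x7f is NaN, not Inf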
#define __fp8_impl_assert_ocp_support(interp)                                      \
    {                                                                              \
        if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP &&                       \
           interp != ck_fp8_interpretation_t::CK_E5M2_OCP)                         \
        {                                                                          \
            __hip_assert(false && "type is unsupported by current target device"); \
        }                                                                          \
    }

#define __fp8_impl_assert_fnuz_support(interp)                                     \
    {                                                                              \
        if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ &&                      \
           interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ)                        \
        {                                                                          \
            __hip_assert(false && "type is unsupported by current target device"); \
        }                                                                          \
    }
__host__ __device__ static inline void __is_interpret_supported(ck_fp8_interpretation_t interp)
{
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
    // ... (per-architecture checks using the assert macros above)
#endif
}

// gfx950 provides scaled f16/bf16 -> fp8 conversion instructions
#if defined(__gfx950__)
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && stochastic_rounding,
                      bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
{
    union
    {
        unsigned int i32val;
        half2_t half_vec;
        unsigned char i8val[4];
    } val;

    constexpr unsigned int i32val = 0;
    val.half_vec[0]               = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF) // clamp finite values to the e4m3 range, keep NaN
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
        }
    }

    val.i32val = __builtin_amdgcn_cvt_scalef32_sr_fp8_f16(i32val, val.half_vec[0], rng, 1.f, 0);

    return val.i8val[0];
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && stochastic_rounding,
                      bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
{
    return fp8x2_storage_t{
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
}
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && stochastic_rounding,
                      bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
{
    union
    {
        unsigned int i32val;
        half2_t half_vec;
        unsigned char i8val[4];
    } val;

    constexpr unsigned int i32val = 0;
    val.half_vec[0]               = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF) // clamp finite values to the e5m2 range, keep NaN
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
        }
    }

    val.i32val = __builtin_amdgcn_cvt_scalef32_sr_bf8_f16(i32val, val.half_vec[0], rng, 1.f, 0);

    return val.i8val[0];
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && stochastic_rounding,
                      bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
{
    return fp8x2_storage_t{
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
}
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && !stochastic_rounding,
                      bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
{
    union
    {
        unsigned int i32val;
        half2_t half_vec;
        unsigned short i16_vec[2];
    } val;

    shortx2_t i16x2val = {0, 0};
    val.half_vec[0]    = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
        }
    }

    i16x2val = __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, 1.f, 0);

    return static_cast<fp8_storage_t>(i16x2val[0] & 0xFF); // low byte of lane 0
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && !stochastic_rounding,
                      bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
{
#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
    return fp8x2_storage_t{
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
#else
    union
    {
        unsigned int i32val;
        half2_t half_vec;
        unsigned short i16_vec[2];
    } val;

    shortx2_t i16x2val = {0, 0};
    val.half_vec       = v;

    if constexpr(saturate)
    {
        if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
        }
        if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 448.0, -448.0);
        }
    }

    i16x2val = __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, 1.f, 0);

    return bit_cast<fp8x2_storage_t>(static_cast<unsigned short>(i16x2val[0])); // two packed bytes
#endif
}
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && !stochastic_rounding,
                      bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
{
    union
    {
        unsigned int i32val;
        half2_t half_vec;
        unsigned short i16_vec[2];
    } val;

    shortx2_t i16x2val = {0, 0};
    val.half_vec[0]    = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
        }
    }

    i16x2val = __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, 1.f, 0);

    return static_cast<fp8_storage_t>(i16x2val[0] & 0xFF); // low byte of lane 0
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && !stochastic_rounding,
                      bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
{
#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
    return fp8x2_storage_t{
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
#else
    union
    {
        unsigned int i32val;
        half2_t half_vec;
        unsigned short i16_vec[2];
    } val;

    shortx2_t i16x2val = {0, 0};
    val.half_vec       = v;

    if constexpr(saturate)
    {
        if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
        }
        if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 57344.0, -57344.0);
        }
    }

    i16x2val = __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, 1.f, 0);

    return bit_cast<fp8x2_storage_t>(static_cast<unsigned short>(i16x2val[0])); // two packed bytes
#endif
}
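// Illustrative example (not part of the original header): with saturate=true
// every finite out-of-range half clamps to the e4m3 limit before conversion,
// so, inside a gfx950 kernel, these two calls (values are assumptions) yield
// the same byte, 0x7e:
//
//   fp8_storage_t a = cast_to_f8_from_f16<ck_fp8_interpretation_t::CK_E4M3_OCP,
//                                         /*saturate=*/true>(_Float16(1000.0f));
//   fp8_storage_t b = cast_to_f8_from_f16<ck_fp8_interpretation_t::CK_E4M3_OCP,
//                                         /*saturate=*/true>(_Float16(448.0f));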
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && stochastic_rounding,
                      bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
{
    union
    {
        unsigned int i32val;
        ushortx2_t bhalf_vec;
        unsigned char i8val[4];
    } val;

    constexpr unsigned int i32val = 0;
    val.bhalf_vec[0]              = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            // clamp in fp32: widen bf16 by a 16-bit shift, fmed3, narrow back
            val.bhalf_vec[0] =
                ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                            bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
                        16));
        }
    }

    val.i32val = __builtin_amdgcn_cvt_scalef32_sr_fp8_bf16(
        i32val, val.bhalf_vec[0], rng, 1.f, 0);

    return val.i8val[0];
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && stochastic_rounding,
                      bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
{
    return fp8x2_storage_t{
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
}
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && stochastic_rounding,
                      bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
{
    union
    {
        unsigned int i32val;
        ushortx2_t bhalf_vec;
        unsigned char i8val[4];
    } val;

    constexpr unsigned int i32val = 0;
    val.bhalf_vec[0]              = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] = ushort(
                (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                     bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
                 16));
        }
    }

    val.i32val = __builtin_amdgcn_cvt_scalef32_sr_bf8_bf16(
        i32val, val.bhalf_vec[0], rng, 1.f, 0);

    return val.i8val[0];
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && stochastic_rounding,
                      bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
{
    return fp8x2_storage_t{
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
}
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && !stochastic_rounding,
                      bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
{
    union
    {
        unsigned int i32val;
        ushortx2_t bhalf_vec;
        unsigned short i16_vec[2];
    } val;

    shortx2_t i16x2val = {0, 0};
    val.bhalf_vec[0]   = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] =
                ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                            bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
                        16));
        }
    }

    i16x2val = __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, 1.f, 0);

    return static_cast<fp8_storage_t>(i16x2val[0] & 0xFF); // low byte of lane 0
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && !stochastic_rounding,
                      bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
{
#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
    return fp8x2_storage_t{
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
#else
    union
    {
        unsigned int i32val;
        ushortx2_t bhalf_vec;
        unsigned short i16_vec[2];
    } val;

    shortx2_t i16x2val = {0, 0};
    val.bhalf_vec      = v;

    if constexpr(saturate)
    {
        if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] =
                ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                            bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
                        16));
        }
        if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[1] =
                ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                            bit_cast<float>(uint32_t{val.bhalf_vec[1]} << 16), 448.0, -448.0)) >>
                        16));
        }
    }

    i16x2val = __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, 1.f, 0);

    return bit_cast<fp8x2_storage_t>(static_cast<unsigned short>(i16x2val[0])); // two packed bytes
#endif
}
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && !stochastic_rounding,
                      bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
{
    union
    {
        unsigned int i32val;
        ushortx2_t bhalf_vec;
        unsigned short i16_vec[2];
    } val;

    shortx2_t i16x2val = {0, 0};
    val.bhalf_vec[0]   = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] = ushort(
                (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                     bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
                 16));
        }
    }

    i16x2val = __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, 1.f, 0);

    return static_cast<fp8_storage_t>(i16x2val[0] & 0xFF); // low byte of lane 0
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && !stochastic_rounding,
                      bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
{
    union
    {
        unsigned int i32val;
        ushortx2_t bhalf_vec;
        unsigned short i16_vec[2];
    } val;

    shortx2_t i16x2val = {0, 0};
    val.bhalf_vec      = v;

    if constexpr(saturate)
    {
        if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] = ushort(
                (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                     bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
                 16));
        }
        if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[1] = ushort(
                (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                     bit_cast<float>(uint32_t{val.bhalf_vec[1]} << 16), 57344.0, -57344.0)) >>
                 16));
        }
    }

    i16x2val = __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, 1.f, 0);

    return bit_cast<fp8x2_storage_t>(static_cast<unsigned short>(i16x2val[0])); // two packed bytes
}
#endif // defined(__gfx950__)
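// Illustrative example (not part of the original header): the bf16 operands
// above are raw ushort bit patterns, so the fmed3f clamps first widen them to
// float with a 16-bit shift. The same trick works standalone; the value is an
// assumption:
//
//   ushort bf = 0x4080;                               // bf16 pattern for 4.0
//   float  f  = bit_cast<float>(uint32_t{bf} << 16);  // f == 4.0f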
#if CK_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
{
    fp8_storage_t i8data;
    union
    {
        float fval;
        unsigned int i32val;
        unsigned char i8val[4];
    } val;

    unsigned int ival = 0;
    val.fval          = v;

    if constexpr(saturate)
    {
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
        {
            if((val.i32val & 0x7F800000) != 0x7F800000) // propagate NaN/Inf, otherwise clamp
            {
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
            }
        }
        else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            if((val.i32val & 0x7F800000) != 0x7F800000)
            {
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
            }
        }
        else
        {
            if((val.i32val & 0x7F800000) != 0x7F800000)
            {
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
            }
        }
    }

    if constexpr(stochastic_rounding)
    {
        ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                       (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
                   ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
                   : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0);
        val.i32val = ival;
        i8data     = val.i8val[0];
    }
    else // round to nearest even
    {
        ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                       (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
                   ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
                   : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false);
        val.i32val = ival;
        i8data     = val.i8val[0];
    }
    return i8data;
}
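// Illustrative summary (not from the original header): the three clamp bounds
// above are the max finite magnitudes of each flavor, so with saturate=true
// any large finite input encodes as the format's largest value:
//
//   // cast_to_f8_from_f32<CK_E4M3_FNUZ, true>(1e6f) -> +240   encoding
//   // cast_to_f8_from_f32<CK_E4M3_OCP,  true>(1e6f) -> +448   encoding (0x7e)
//   // cast_to_f8_from_f32<CK_E5M2_OCP,  true>(1e6f) -> +57344 encoding (0x7b)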
template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
static __device__ fp8x2_storage_t cast_to_f8_from_f32(float2_t v, unsigned int rng = 0)
{
    if constexpr(stochastic_rounding)
    {
        // stochastic rounding: convert the two lanes independently
        return fp8x2_storage_t{
            cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[0], rng),
            cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[1], rng)};
    }
    else
    {
        union
        {
            float fval;
            unsigned int i32val;
            unsigned char i8val[4];
        } val0, val1;
        val0.fval = v[0];
        val1.fval = v[1];

        unsigned int ival = 0;

        if constexpr(saturate)
        {
            if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
            {
                if((val0.i32val & 0x7F800000) != 0x7F800000)
                {
                    val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 240.0, -240.0);
                }
                if((val1.i32val & 0x7F800000) != 0x7F800000)
                {
                    val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 240.0, -240.0);
                }
            }
            else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
            {
                if((val0.i32val & 0x7F800000) != 0x7F800000)
                {
                    val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 448.0, -448.0);
                }
                if((val1.i32val & 0x7F800000) != 0x7F800000)
                {
                    val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 448.0, -448.0);
                }
            }
            else
            {
                if((val0.i32val & 0x7F800000) != 0x7F800000)
                {
                    val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 57344.0, -57344.0);
                }
                if((val1.i32val & 0x7F800000) != 0x7F800000)
                {
                    val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 57344.0, -57344.0);
                }
            }
        }

        if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                     (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
        {
            ival = __builtin_amdgcn_cvt_pk_fp8_f32(val0.fval, val1.fval, ival, false);
        }
        else
        {
            ival = __builtin_amdgcn_cvt_pk_bf8_f32(val0.fval, val1.fval, ival, false);
        }

        val0.i32val = ival;
        return fp8x2_storage_t{val0.i8val[0], val0.i8val[1]};
    }
}
#endif // CK_FP8_CVT_FAST_PATH
template <typename T, int wm, int we, bool is_fnuz, bool clip = false, bool stoch = false>
__host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rng = 0)
{
    constexpr bool is_half   = is_same_v<T, _Float16>;
    constexpr bool is_float  = is_same_v<T, float>;
    constexpr bool is_double = is_same_v<T, double>;
    static_assert(is_half || is_float || is_double,
                  "Only half, float and double can be cast to f8");

    constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);

    using T_bitwise =
        conditional_t<sizeof(T) == 2,
                      unsigned short int,
                      conditional_t<sizeof(T) == 4, unsigned int, unsigned long long>>;
    T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);

    unsigned long long x{x_bitwise};

    unsigned long long head, mantissa;
    int exponent, bias;
    unsigned int sign;
    unsigned long long fInf, mask;

    if constexpr(sizeof(T) == 8)
    {
        head     = x & 0xFFF0000000000000ull;
        mantissa = x & 0xFFFFFFFFFFFFFull;
        exponent = (head >> 52) & 0x7FF;
        sign     = head >> 63;
        bias     = 1023;
        fInf     = 0x7FF0000000000000ull;
        mask     = 0x7FFFFFFFFFFFFFFFull;
    }
    else if constexpr(sizeof(T) == 4)
    {
        head     = x & 0xFF800000;
        mantissa = x & 0x7FFFFF;
        exponent = (head >> 23) & 0xFF;
        sign     = head >> 31;
        bias     = 127;
        fInf     = 0x7F800000;
        mask     = 0x7FFFFFFF;
    }
    else
    {
        head     = x & 0xFC00;
        mantissa = x & 0x3FF;
        exponent = (head >> 10) & 0x1F;
        sign     = head >> 15;
        bias     = 15;
        fInf     = 0x7C00;
        mask     = 0x7FFF;
    }

    unsigned int signed_inf = 0;
    unsigned int nan        = 0;
    if constexpr(is_fnuz)
    {
        signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
        nan        = 0x80;
    }
    else
    {
        if constexpr(we == 4)
        { // e4m3
            signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
        }
        else
        { // e5m2
            signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
        }
        nan = (sign << 7) + 0x7f;
    }
    unsigned long long ifmax = 0;
    if constexpr(sizeof(T) == 8)
    {
        if constexpr(we == 5)
        {
            ifmax = 0x40EC000000000000ull; // 57344.0
        }
        else
        {
            if constexpr(is_fnuz)
            {
                ifmax = 0x406E000000000000ull; // 240.0
            }
            else
            {
                ifmax = 0x407C000000000000ull; // 448.0
            }
        }
    }
    else if(sizeof(T) == 4)
    {
        if constexpr(we == 5)
        {
            ifmax = 0x47600000; // 57344.0f
        }
        else
        {
            if constexpr(is_fnuz)
            {
                ifmax = 0x43700000; // 240.0f
            }
            else
            {
                ifmax = 0x43E00000; // 448.0f
            }
        }
    }
    else
    {
        if constexpr(we == 5)
        {
            ifmax = 0x7B00; // 57344.0 in half
        }
        else
        {
            if constexpr(is_fnuz)
            {
                ifmax = 0x5B80; // 240.0 in half
            }
            else
            {
                ifmax = 0x5F00; // 448.0 in half
            }
        }
    }

    // deal with Inf and NaN inputs
    if((x & fInf) == fInf)
    {
        if constexpr(is_fnuz)
            return signed_inf;

        return mantissa != 0 ? nan : signed_inf;
    }

    // input exceeds the largest representable f8 value
    if((x & mask) > ifmax)
    {
        return signed_inf;
    }

    if(x == 0)
    {
        return 0;
    }

    const int f8_bias                  = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
    const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of an f8 denormal

    int act_exponent, f8_exponent, exponent_diff;

    if(exponent == 0)
    { // input is subnormal
        act_exponent  = exponent - bias + 1;
        exponent_diff = f8_denormal_act_exponent - act_exponent;
    }
    else
    { // input is normal
        act_exponent = exponent - bias;
        if(act_exponent <= f8_denormal_act_exponent)
        {
            exponent_diff = f8_denormal_act_exponent - act_exponent;
        }
        else
        { // both input and output are in the normal range
            exponent_diff = 0;
        }
        mantissa += (1ull << mfmt); // add the implicit leading 1
    }

    bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
                    (1ull << (mfmt - wm + exponent_diff - 1));

    if(exponent_diff > 0)
        mantissa >>= exponent_diff;
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;
    bool implicit_one = mantissa & (1ull << mfmt);
    // without an implicit 1 the f8 result is denormal: adjust the exponent
    f8_exponent = (act_exponent + exponent_diff) + f8_bias - (implicit_one ? 0 : 1);

    // round to nearest even, or stochastically using rng
    unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
    bool odd                     = mantissa & (1ull << (mfmt - wm));
    mantissa +=
        (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;

    // deal with rounding overflow
    if(f8_exponent == 0)
    {
        if((1ull << mfmt) & mantissa)
        {
            f8_exponent = 1; // denormal overflowed into the normal range
        }
    }
    else
    {
        if((1ull << (mfmt + 1)) & mantissa)
        {
            mantissa >>= 1;
            f8_exponent++;
        }
    }

    mantissa >>= (mfmt - wm);

    // above range: saturate or return the Inf/NaN pattern
    const int max_exp = (1 << we) - 1;
    if(f8_exponent > max_exp)
    {
        if constexpr(clip)
        {
            mantissa    = (1 << wm) - 1;
            f8_exponent = max_exp;
        }
        else
        {
            return signed_inf;
        }
    }

    if(f8_exponent == 0 && mantissa == 0)
        return is_fnuz ? 0 : (sign << 7);
    mantissa &= (1 << wm) - 1;
    return (sign << 7) | (f8_exponent << wm) | mantissa;
}
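// Illustrative example (not part of the original header): encoding through the
// software path for OCP e4m3 (wm = 3, we = 4). 1.0f lands exactly on 0x38;
// 0.3f sits between 0.28125 (mantissa 1) and 0.3125 (mantissa 2), and
// round-to-nearest picks the latter:
//
//   fp8_storage_t one = cast_to_f8<float, 3, 4, false, /*clip=*/true>(1.0f);  // 0x38
//   fp8_storage_t p3  = cast_to_f8<float, 3, 4, false, /*clip=*/true>(0.3f);  // 0x2A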
template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
{
    __is_interpret_supported(interp);

    unsigned int rng = 0;
    if constexpr(stochastic_rounding)
    {
#if defined(__gfx950__)
        rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                        get_thread_global_1d_id());
#else
        constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif
#endif
    }

    return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
        f, rng);
}
#else // CK_FP8_CVT_FAST_PATH
#ifndef CK_CODE_GEN_RTC
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
#else
__host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
#endif
{
    unsigned int rng = 0;
    if constexpr(stochastic_rounding)
    {
#if defined(__gfx950__)
        rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                        get_thread_global_1d_id());
#else
        constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif
#endif
    }

    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
    {
        return cast_to_f8<float,
                          3,
                          4,
                          true,
                          sat == ck_saturation_t::CK_SATFINITE,
                          stochastic_rounding>(f, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_FNUZ)
    {
        return cast_to_f8<float,
                          2,
                          5,
                          true,
                          sat == ck_saturation_t::CK_SATFINITE,
                          stochastic_rounding>(f, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return cast_to_f8<float,
                          3,
                          4,
                          false,
                          sat == ck_saturation_t::CK_SATFINITE,
                          stochastic_rounding>(f, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return cast_to_f8<float,
                          2,
                          5,
                          false,
                          sat == ck_saturation_t::CK_SATFINITE,
                          stochastic_rounding>(f, rng);
    }
    else
    {
        __hip_assert(false && "FP8 type is not supported by current target device");
        return 0;
    }
}
#endif // CK_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
{
    __is_interpret_supported(interp);

    unsigned int rng = 0;
    if constexpr(stochastic_rounding)
    {
#if defined(__gfx950__)
        rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                        get_thread_global_1d_id());
#else
        constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
#else
        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f[0]);
#endif
#endif
    }

    return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
        f, rng);
}
#else // CK_FP8_CVT_FAST_PATH
#ifndef CK_CODE_GEN_RTC
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
#else
__host__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
#endif
{
    return fp8x2_storage_t{cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[0]),
                           cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[1])};
}
#endif // CK_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
#else
__host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
#endif
{
    __is_interpret_supported(interp);

    unsigned int rng = 0;
    if constexpr(stochastic_rounding)
    {
#if defined(__gfx950__)
        rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                        get_thread_global_1d_id());
#else
        constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#endif
    }

#if defined(__gfx950__)
    return cast_to_f8_from_f16<interp,
                               sat == ck_saturation_t::CK_SATFINITE,
                               stochastic_rounding>(x, rng);
#else
    return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
        static_cast<float>(x));
#endif
}
template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
#else
__host__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
#endif
{
    __is_interpret_supported(interp);

    unsigned int rng = 0;
    if constexpr(stochastic_rounding)
    {
#if defined(__gfx950__)
        rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                        get_thread_global_1d_id());
#else
        constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#else
        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
#endif
#endif
    }

#if defined(__gfx950__)
    return cast_to_f8_from_f16<interp,
                               sat == ck_saturation_t::CK_SATFINITE,
                               stochastic_rounding>(x, rng);
#else
    return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
        float2_t{static_cast<float>(x[0]), static_cast<float>(x[1])});
#endif
}
template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
#else
__host__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
#endif
{
    __is_interpret_supported(interp);

    unsigned int rng = 0;
    if constexpr(stochastic_rounding)
    {
#if defined(__gfx950__)
        rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                        get_thread_global_1d_id());
#else
        constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
                                           static_cast<float>(x));
#else
        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x),
                                           static_cast<float>(x));
#endif
#endif
    }

#if defined(__gfx950__)
    return cast_to_f8_from_bf16<interp,
                                sat == ck_saturation_t::CK_SATFINITE,
                                stochastic_rounding>(x, rng);
#else
    return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
        bit_cast<float>(uint32_t{x} << 16));
#endif
}
template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
#else
__host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
#endif
{
#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
    return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
        float2_t{bit_cast<float>(uint32_t{x[0]} << 16),
                 bit_cast<float>(uint32_t{x[1]} << 16)});
#else
    __is_interpret_supported(interp);

    unsigned int rng = 0;
    if constexpr(stochastic_rounding)
    {
#if defined(__gfx950__)
        rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                        get_thread_global_1d_id());
#else
        constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
                                           static_cast<float>(x[0]));
#else
        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x),
                                           static_cast<float>(x[0]));
#endif
#endif
    }

#if defined(__gfx950__)
    return cast_to_f8_from_bf16<interp,
                                sat == ck_saturation_t::CK_SATFINITE,
                                stochastic_rounding>(x, rng);
#else
    return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
        float2_t{bit_cast<float>(uint32_t{x[0]} << 16),
                 bit_cast<float>(uint32_t{x[1]} << 16)});
#endif
#endif // CK_WORKAROUND_BF16_TO_FP8_CONVERSION
}
#if CK_USE_OCP_FP8
using f8_t  = f8_ocp_t;
using bf8_t = bf8_ocp_t;
#define CK_FP8_TYPE_FNUZ 0
#define CK_FP8_TYPE_OCP 1
#else
using f8_t  = f8_fnuz_t;
using bf8_t = bf8_fnuz_t;
#define CK_FP8_TYPE_FNUZ 1
#define CK_FP8_TYPE_OCP 0
#endif
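// Illustrative example (not part of the original header): downstream code uses
// f8_t/bf8_t without caring which flavor was selected, and can branch on the
// CK_FP8_TYPE_* macros when flavor-specific handling is unavoidable:
//
//   #if CK_FP8_TYPE_OCP
//   static_assert(f8_t::we == 4 && f8_t::wm == 3);  // e4m3 members exist
//   #endif
//   static_assert(sizeof(f8_t) == 1 && sizeof(bf8_t) == 1);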