4 #if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
5 #define CK_MX_FP8_CVT_FAST_PATH 1
7 #define CK_MX_FP8_CVT_FAST_PATH 0
13 #if CK_MX_FP8_CVT_FAST_PATH
14 template <ck_fp8_
interpretation_t
interpret>
15 static __device__
float cast_to_f32_from_f8_scaled(
float scale,
fp8_storage_t v)
20 unsigned char i8val[4];
26 "Only OCP interpretations are supported");
30 return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
34 return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
38 template <ck_fp8_
interpretation_t
interpret>
41 const auto i16val = bit_cast<uint16_t>(v);
45 "Only OCP interpretations are supported");
49 return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(i16val, scale, 0);
53 return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(i16val, scale, 0);
57 template <ck_fp8_
interpretation_t
interpret,
bool stochastic_rounding = false>
58 static __device__
fp8_storage_t cast_to_f8_from_f32_scaled(
float v,
72 vector_type<int16_t, 2>::type v2i16;
79 if constexpr(stochastic_rounding)
83 ? __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, val.fval, rng, scale, 0)
84 : __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, val.fval, rng, scale, 0);
96 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( ret.v2i16,
105 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( ret.v2i16,
112 i8data = ret.v4i8[0];
117 template <ck_fp8_
interpretation_t
interpret,
bool stochastic_rounding = false>
119 unsigned int rng = 0,
126 vector_type<int16_t, 2>::type v2i16;
127 StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
130 if constexpr(stochastic_rounding)
135 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
136 f8x2[0] = ret.v2f8x2(Number<0>{})[0];
137 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
138 f8x2[1] = ret.v2f8x2(Number<0>{})[0];
142 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
143 f8x2[0] = ret.v2f8x2(Number<0>{})[0];
144 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
145 f8x2[1] = ret.v2f8x2(Number<0>{})[0];
157 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( ret.v2i16,
166 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( ret.v2i16,
173 return ret.v2f8x2(Number<0>{});
179 #if CK_MX_FP8_CVT_FAST_PATH
190 template <ck_fp8_
interpretation_t
interp,
bool stochastic_rounding = false>
191 __host__ __device__
static inline fp8_storage_t cvt_float_to_fp8_scaled(
const float f,
float scale)
193 __is_interpret_supported(interp);
195 if constexpr(stochastic_rounding)
197 constexpr
int seed = 1254739;
198 rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t
>(&f), f);
200 return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
213 template <ck_fp8_
interpretation_t
interp,
bool stochastic_rounding = false>
217 __is_interpret_supported(interp);
219 if constexpr(stochastic_rounding)
221 constexpr
int seed = 1254739;
222 rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t
>(&f), f[0]);
224 return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
239 template <ck_fp8_
interpretation_t
interp,
bool stochastic_rounding = false>
240 __host__ __device__
static inline fp8_storage_t cvt_float_to_fp8_scaled(
const float f,
float scale)
245 "Only OCP interpretations are supported");
248 if constexpr(stochastic_rounding)
250 constexpr
int seed = 1254739;
251 rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t
>(&f), f);
256 return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f / scale, rng);
260 return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f / scale, rng);
264 __hip_assert(
false &&
"FP8 type is not supported by current target device");
279 template <ck_fp8_
interpretation_t
interp,
bool stochastic_rounding = false>
286 "Only OCP interpretations are supported");
289 if constexpr(stochastic_rounding)
291 constexpr
int seed = 1254739;
292 rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t
>(&f), f[0]);
297 return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[0] / scale, rng),
298 cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[1] / scale, rng)};
302 return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[0] / scale, rng),
303 cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[1] / scale, rng)};
307 __hip_assert(
false &&
"FP8 type is not supported by current target device");
317 template <
typename Y,
typename X>
321 template <
typename Y,
typename X>
328 return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
335 return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
343 return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
351 return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
372 [&](
auto i) { out.fp8_2x8[i] = mxf8_convert_rne<f8x2_ocp_t>(in.float_2x8[i], scale); });
395 [&](
auto i) { out.bf8_2x8[i] = mxf8_convert_rne<bf8x2_ocp_t>(in.float_2x8[i], scale); });
418 [&](
auto i) { out.fp8_16x2[i] = mxf8_convert_rne<f8x16_ocp_t>(in.float_16x2[i], scale); });
441 [&](
auto i) { out.bf8_16x2[i] = mxf8_convert_rne<bf8x16_ocp_t>(in.float_16x2[i], scale); });
450 return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
458 fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
466 fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
475 fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
496 [&](
auto i) { out.fp8_2x8[i] = mxf8_convert_sr<f8x2_ocp_t>(in.float_2x8[i], scale); });
519 [&](
auto i) { out.bf8_2x8[i] = mxf8_convert_sr<bf8x2_ocp_t>(in.float_2x8[i], scale); });
542 [&](
auto i) { out.fp8_16x2[i] = mxf8_convert_sr<f8x16_ocp_t>(in.float_16x2[i], scale); });
565 [&](
auto i) { out.bf8_16x2[i] = mxf8_convert_sr<bf8x16_ocp_t>(in.float_16x2[i], scale); });
float float2_t
Definition: amd_ck_fp8.hpp:67
fp8_storage_t fp8x2_storage_t
Definition: amd_ck_fp8.hpp:66
__host__ __device__ f8x16_ocp_t mxf8_convert_sr< f8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:480
__host__ __device__ f8x2_ocp_t mxf8_convert_rne< f8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:340
typename vector_type< float, 16 >::type float16_t
Definition: data_type.hpp:2484
__host__ constexpr __device__ Y mxf8_convert_rne(X x, float scale)
__host__ __device__ f8_ocp_t mxf8_convert_rne< f8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:326
typename vector_type< bf8_ocp_t, 32 >::type bf8x32_ocp_t
Definition: data_type.hpp:2549
__host__ __device__ bf8_ocp_t mxf8_convert_sr< bf8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:455
__host__ __device__ f8_ocp_t mxf8_convert_sr< f8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:448
__host__ __device__ bf8x32_ocp_t mxf8_convert_rne< bf8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:425
typename vector_type< float, 2 >::type float2_t
Definition: data_type.hpp:2481
typename vector_type< f8_ocp_t, 2 >::type f8x2_ocp_t
Definition: data_type.hpp:2537
__host__ __device__ bf8_ocp_t mxf8_convert_rne< bf8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:333
typename vector_type< bf8_ocp_t, 2 >::type bf8x2_ocp_t
Definition: data_type.hpp:2545
__host__ constexpr __device__ Y mxf8_convert_sr(X x, float scale)
__host__ __device__ bf8x16_ocp_t mxf8_convert_sr< bf8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:503
typename vector_type< f8_ocp_t, 32 >::type f8x32_ocp_t
Definition: data_type.hpp:2541
__host__ __device__ bf8x16_ocp_t mxf8_convert_rne< bf8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:379
typename vector_type< f8_ocp_t, 16 >::type f8x16_ocp_t
Definition: data_type.hpp:2540
__host__ __device__ f8x16_ocp_t mxf8_convert_rne< f8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:356
__host__ __device__ bf8x32_ocp_t mxf8_convert_sr< bf8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:549
__host__ __device__ f8x32_ocp_t mxf8_convert_sr< f8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:526
__host__ __device__ bf8x2_ocp_t mxf8_convert_sr< bf8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:471
typename vector_type< bf8_ocp_t, 16 >::type bf8x16_ocp_t
Definition: data_type.hpp:2548
__host__ __device__ f8x32_ocp_t mxf8_convert_rne< f8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:402
typename vector_type< float, 32 >::type float32_t
Definition: data_type.hpp:2485
__host__ __device__ bf8x2_ocp_t mxf8_convert_rne< bf8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:348
__host__ __device__ f8x2_ocp_t mxf8_convert_sr< f8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:463
unsigned char fp8_storage_t
Definition: amd_ck_fp8.hpp:42
Definition: amd_ck_fp8.hpp:344
Definition: amd_ck_fp8.hpp:298
Definition: functional2.hpp:31