7 #if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
8 #define CK_MX_FP8_CVT_FAST_PATH 1
10 #define CK_MX_FP8_CVT_FAST_PATH 0
16 #if CK_MX_FP8_CVT_FAST_PATH
17 template <ck_fp8_
interpretation_t
interpret>
18 static __device__
float cast_to_f32_from_f8_scaled(
float scale,
fp8_storage_t v)
23 unsigned char i8val[4];
29 "Only OCP interpretations are supported");
33 return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
37 return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
41 template <ck_fp8_
interpretation_t
interpret>
44 const auto i16val = bit_cast<uint16_t>(v);
48 "Only OCP interpretations are supported");
52 return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(i16val, scale, 0);
56 return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(i16val, scale, 0);
60 template <ck_fp8_
interpretation_t
interpret,
bool stochastic_rounding = false>
61 static __device__
fp8_storage_t cast_to_f8_from_f32_scaled(
float v,
75 vector_type<int16_t, 2>::type v2i16;
82 if constexpr(stochastic_rounding)
86 ? __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, val.fval, rng, scale, 0)
87 : __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, val.fval, rng, scale, 0);
99 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( ret.v2i16,
108 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( ret.v2i16,
115 i8data = ret.v4i8[0];
120 template <ck_fp8_
interpretation_t
interpret,
bool stochastic_rounding = false>
122 unsigned int rng = 0,
129 vector_type<int16_t, 2>::type v2i16;
130 StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
133 if constexpr(stochastic_rounding)
138 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
139 f8x2[0] = ret.v2f8x2(Number<0>{})[0];
140 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
141 f8x2[1] = ret.v2f8x2(Number<0>{})[0];
145 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
146 f8x2[0] = ret.v2f8x2(Number<0>{})[0];
147 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
148 f8x2[1] = ret.v2f8x2(Number<0>{})[0];
160 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( ret.v2i16,
169 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( ret.v2i16,
176 return ret.v2f8x2(Number<0>{});
182 #if CK_MX_FP8_CVT_FAST_PATH
193 template <ck_fp8_
interpretation_t
interp,
bool stochastic_rounding = false>
194 __host__ __device__
static inline fp8_storage_t cvt_float_to_fp8_scaled(
const float f,
float scale)
196 __is_interpret_supported(interp);
198 if constexpr(stochastic_rounding)
201 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
204 return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
217 template <ck_fp8_
interpretation_t
interp,
bool stochastic_rounding = false>
221 __is_interpret_supported(interp);
223 if constexpr(stochastic_rounding)
226 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
229 return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
244 template <ck_fp8_
interpretation_t
interp,
bool stochastic_rounding = false>
245 __host__ __device__
static inline fp8_storage_t cvt_float_to_fp8_scaled(
const float f,
float scale)
250 "Only OCP interpretations are supported");
253 if constexpr(stochastic_rounding)
255 constexpr
int seed = 1254739;
256 rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t>(&f), f);
261 return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f / scale, rng);
265 return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f / scale, rng);
269 __hip_assert(
false &&
"FP8 type is not supported by current target device");
284 template <ck_fp8_
interpretation_t
interp,
bool stochastic_rounding = false>
291 "Only OCP interpretations are supported");
294 if constexpr(stochastic_rounding)
296 constexpr
int seed = 1254739;
297 rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t>(&f), f[0]);
302 return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[0] / scale, rng),
303 cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[1] / scale, rng)};
307 return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[0] / scale, rng),
308 cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[1] / scale, rng)};
312 __hip_assert(
false &&
"FP8 type is not supported by current target device");
322 template <
typename Y,
typename X>
326 template <
typename Y,
typename X>
333 return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
340 return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
348 return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
356 return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
377 [&](
auto i) { out.fp8_2x8[i] = mxf8_convert_rne<f8x2_ocp_t>(in.float_2x8[i], scale); });
400 [&](
auto i) { out.bf8_2x8[i] = mxf8_convert_rne<bf8x2_ocp_t>(in.float_2x8[i], scale); });
423 [&](
auto i) { out.fp8_16x2[i] = mxf8_convert_rne<f8x16_ocp_t>(in.float_16x2[i], scale); });
446 [&](
auto i) { out.bf8_16x2[i] = mxf8_convert_rne<bf8x16_ocp_t>(in.float_16x2[i], scale); });
455 return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
463 fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
471 fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
480 fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
501 [&](
auto i) { out.fp8_2x8[i] = mxf8_convert_sr<f8x2_ocp_t>(in.float_2x8[i], scale); });
524 [&](
auto i) { out.bf8_2x8[i] = mxf8_convert_sr<bf8x2_ocp_t>(in.float_2x8[i], scale); });
547 [&](
auto i) { out.fp8_16x2[i] = mxf8_convert_sr<f8x16_ocp_t>(in.float_16x2[i], scale); });
570 [&](
auto i) { out.bf8_16x2[i] = mxf8_convert_sr<bf8x16_ocp_t>(in.float_16x2[i], scale); });
float float2_t
Definition: amd_ck_fp8.hpp:67
fp8_storage_t fp8x2_storage_t
Definition: amd_ck_fp8.hpp:63
__host__ __device__ f8x16_ocp_t mxf8_convert_sr< f8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:485
__host__ __device__ f8x2_ocp_t mxf8_convert_rne< f8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:345
typename vector_type< float, 16 >::type float16_t
Definition: dtype_vector.hpp:2134
__host__ constexpr __device__ Y mxf8_convert_rne(X x, float scale)
__host__ __device__ f8_ocp_t mxf8_convert_rne< f8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:331
typename vector_type< bf8_ocp_t, 32 >::type bf8x32_ocp_t
Definition: dtype_vector.hpp:2198
__host__ __device__ bf8_ocp_t mxf8_convert_sr< bf8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:460
__host__ __device__ f8_ocp_t mxf8_convert_sr< f8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:453
__host__ __device__ bf8x32_ocp_t mxf8_convert_rne< bf8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:430
typename vector_type< float, 2 >::type float2_t
Definition: dtype_vector.hpp:2131
typename vector_type< f8_ocp_t, 2 >::type f8x2_ocp_t
Definition: dtype_vector.hpp:2186
__host__ __device__ bf8_ocp_t mxf8_convert_rne< bf8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:338
typename vector_type< bf8_ocp_t, 2 >::type bf8x2_ocp_t
Definition: dtype_vector.hpp:2194
__host__ constexpr __device__ Y mxf8_convert_sr(X x, float scale)
__host__ __device__ bf8x16_ocp_t mxf8_convert_sr< bf8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:508
__device__ index_t get_thread_global_1d_id()
Definition: get_id.hpp:54
typename vector_type< f8_ocp_t, 32 >::type f8x32_ocp_t
Definition: dtype_vector.hpp:2190
__host__ __device__ bf8x16_ocp_t mxf8_convert_rne< bf8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:384
typename vector_type< f8_ocp_t, 16 >::type f8x16_ocp_t
Definition: dtype_vector.hpp:2189
__host__ __device__ f8x16_ocp_t mxf8_convert_rne< f8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:361
__host__ __device__ bf8x32_ocp_t mxf8_convert_sr< bf8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:554
__host__ __device__ f8x32_ocp_t mxf8_convert_sr< f8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:531
__host__ __device__ bf8x2_ocp_t mxf8_convert_sr< bf8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:476
typename vector_type< bf8_ocp_t, 16 >::type bf8x16_ocp_t
Definition: dtype_vector.hpp:2197
__host__ __device__ f8x32_ocp_t mxf8_convert_rne< f8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:407
typename vector_type< float, 32 >::type float32_t
Definition: dtype_vector.hpp:2135
__host__ __device__ bf8x2_ocp_t mxf8_convert_rne< bf8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:353
__host__ __device__ f8x2_ocp_t mxf8_convert_sr< f8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:468
unsigned char fp8_storage_t
Definition: amd_ck_fp8.hpp:39
_W64 unsigned int uintptr_t
Definition: stdint.h:165
unsigned int uint32_t
Definition: stdint.h:126
Definition: amd_ck_fp8.hpp:344
Definition: amd_ck_fp8.hpp:298
Definition: functional2.hpp:33