7 #if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ 
    8 #define CK_MX_FP8_CVT_FAST_PATH 1 
   10 #define CK_MX_FP8_CVT_FAST_PATH 0 
   16 #if CK_MX_FP8_CVT_FAST_PATH 
   17 template <ck_fp8_
interpretation_t 
interpret>
 
   18 static __device__ 
float cast_to_f32_from_f8_scaled(
float scale, 
fp8_storage_t v)
 
   23         unsigned char i8val[4];
 
   29                   "Only OCP interpretations are supported");
 
   33         return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
 
   37         return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
 
   41 template <ck_fp8_
interpretation_t 
interpret>
 
   44     const auto i16val = bit_cast<uint16_t>(v);
 
   48                   "Only OCP interpretations are supported");
 
   52         return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(i16val, scale, 0);
 
   56         return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(i16val, scale, 0);
 
   60 template <ck_fp8_
interpretation_t 
interpret, 
bool stochastic_rounding = false>
 
   61 static __device__ 
fp8_storage_t cast_to_f8_from_f32_scaled(
float v,
 
   75         vector_type<int16_t, 2>::type v2i16;
 
   82     if constexpr(stochastic_rounding)
 
   86                 ? __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, val.fval, rng, scale, 0)
 
   87                 : __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, val.fval, rng, scale, 0);
 
   99             ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( ret.v2i16,
 
  108             ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( ret.v2i16,
 
  115         i8data = ret.v4i8[0];
 
  120 template <ck_fp8_
interpretation_t 
interpret, 
bool stochastic_rounding = false>
 
  122                                                              unsigned int rng = 0,
 
  129         vector_type<int16_t, 2>::type v2i16;
 
  130         StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
 
  133     if constexpr(stochastic_rounding)
 
  138             ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
 
  139             f8x2[0]  = ret.v2f8x2(Number<0>{})[0];
 
  140             ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
 
  141             f8x2[1]  = ret.v2f8x2(Number<0>{})[0];
 
  145             ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
 
  146             f8x2[0]  = ret.v2f8x2(Number<0>{})[0];
 
  147             ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
 
  148             f8x2[1]  = ret.v2f8x2(Number<0>{})[0];
 
  160             ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( ret.v2i16,
 
  169             ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( ret.v2i16,
 
  176         return ret.v2f8x2(Number<0>{});
 
  182 #if CK_MX_FP8_CVT_FAST_PATH 
  193 template <ck_fp8_
interpretation_t 
interp, 
bool stochastic_rounding = false>
 
  194 __host__ __device__ 
static inline fp8_storage_t cvt_float_to_fp8_scaled(
const float f, 
float scale)
 
  196     __is_interpret_supported(interp);
 
  198     if constexpr(stochastic_rounding)
 
  201         rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
 
  204     return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
 
  217 template <ck_fp8_
interpretation_t 
interp, 
bool stochastic_rounding = false>
 
  221     __is_interpret_supported(interp);
 
  223     if constexpr(stochastic_rounding)
 
  226         rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
 
  229     return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
 
  244 template <ck_fp8_
interpretation_t 
interp, 
bool stochastic_rounding = false>
 
  245 __host__ __device__ 
static inline fp8_storage_t cvt_float_to_fp8_scaled(
const float f, 
float scale)
 
  250                   "Only OCP interpretations are supported");
 
  253     if constexpr(stochastic_rounding)
 
  255         constexpr 
int seed = 1254739;
 
  256         rng                = prand_generator<float, seed>(
reinterpret_cast<uintptr_t>(&f), f);
 
  261         return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f / scale, rng);
 
  265         return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f / scale, rng);
 
  269         __hip_assert(
false && 
"FP8 type is not supported by current target device");
 
  284 template <ck_fp8_
interpretation_t 
interp, 
bool stochastic_rounding = false>
 
  291                   "Only OCP interpretations are supported");
 
  294     if constexpr(stochastic_rounding)
 
  296         constexpr 
int seed = 1254739;
 
  297         rng                = prand_generator<float, seed>(
reinterpret_cast<uintptr_t>(&f), f[0]);
 
  302         return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[0] / scale, rng),
 
  303                 cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[1] / scale, rng)};
 
  307         return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[0] / scale, rng),
 
  308                 cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[1] / scale, rng)};
 
  312         __hip_assert(
false && 
"FP8 type is not supported by current target device");
 
  322 template <
typename Y, 
typename X>
 
  326 template <
typename Y, 
typename X>
 
  333     return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
 
  340     return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
 
  348     return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
 
  356     return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
 
  377         [&](
auto i) { out.fp8_2x8[i] = mxf8_convert_rne<f8x2_ocp_t>(in.float_2x8[i], scale); });
 
  400         [&](
auto i) { out.bf8_2x8[i] = mxf8_convert_rne<bf8x2_ocp_t>(in.float_2x8[i], scale); });
 
  423         [&](
auto i) { out.fp8_16x2[i] = mxf8_convert_rne<f8x16_ocp_t>(in.float_16x2[i], scale); });
 
  446         [&](
auto i) { out.bf8_16x2[i] = mxf8_convert_rne<bf8x16_ocp_t>(in.float_16x2[i], scale); });
 
  455     return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
 
  463         fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
 
  471         fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
 
  480         fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
 
  501         [&](
auto i) { out.fp8_2x8[i] = mxf8_convert_sr<f8x2_ocp_t>(in.float_2x8[i], scale); });
 
  524         [&](
auto i) { out.bf8_2x8[i] = mxf8_convert_sr<bf8x2_ocp_t>(in.float_2x8[i], scale); });
 
  547         [&](
auto i) { out.fp8_16x2[i] = mxf8_convert_sr<f8x16_ocp_t>(in.float_16x2[i], scale); });
 
  570         [&](
auto i) { out.bf8_16x2[i] = mxf8_convert_sr<bf8x16_ocp_t>(in.float_16x2[i], scale); });
 
float float2_t
Definition: amd_ck_fp8.hpp:92
 
fp8_storage_t fp8x2_storage_t
Definition: amd_ck_fp8.hpp:88
 
__host__ __device__ f8x16_ocp_t mxf8_convert_sr< f8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:485
 
__host__ __device__ f8x2_ocp_t mxf8_convert_rne< f8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:345
 
typename vector_type< float, 16 >::type float16_t
Definition: dtype_vector.hpp:2148
 
__host__ constexpr __device__ Y mxf8_convert_rne(X x, float scale)
 
__host__ __device__ f8_ocp_t mxf8_convert_rne< f8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:331
 
typename vector_type< bf8_ocp_t, 32 >::type bf8x32_ocp_t
Definition: dtype_vector.hpp:2212
 
__host__ __device__ bf8_ocp_t mxf8_convert_sr< bf8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:460
 
__host__ __device__ f8_ocp_t mxf8_convert_sr< f8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:453
 
__host__ __device__ bf8x32_ocp_t mxf8_convert_rne< bf8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:430
 
typename vector_type< float, 2 >::type float2_t
Definition: dtype_vector.hpp:2145
 
typename vector_type< f8_ocp_t, 2 >::type f8x2_ocp_t
Definition: dtype_vector.hpp:2200
 
__host__ __device__ bf8_ocp_t mxf8_convert_rne< bf8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:338
 
typename vector_type< bf8_ocp_t, 2 >::type bf8x2_ocp_t
Definition: dtype_vector.hpp:2208
 
__host__ constexpr __device__ Y mxf8_convert_sr(X x, float scale)
 
__host__ __device__ bf8x16_ocp_t mxf8_convert_sr< bf8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:508
 
__device__ index_t get_thread_global_1d_id()
Definition: get_id.hpp:43
 
typename vector_type< f8_ocp_t, 32 >::type f8x32_ocp_t
Definition: dtype_vector.hpp:2204
 
__host__ __device__ bf8x16_ocp_t mxf8_convert_rne< bf8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:384
 
typename vector_type< f8_ocp_t, 16 >::type f8x16_ocp_t
Definition: dtype_vector.hpp:2203
 
__host__ __device__ f8x16_ocp_t mxf8_convert_rne< f8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:361
 
__host__ __device__ bf8x32_ocp_t mxf8_convert_sr< bf8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:554
 
__host__ __device__ f8x32_ocp_t mxf8_convert_sr< f8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:531
 
__host__ __device__ bf8x2_ocp_t mxf8_convert_sr< bf8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:476
 
typename vector_type< bf8_ocp_t, 16 >::type bf8x16_ocp_t
Definition: dtype_vector.hpp:2211
 
__host__ __device__ f8x32_ocp_t mxf8_convert_rne< f8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:407
 
typename vector_type< float, 32 >::type float32_t
Definition: dtype_vector.hpp:2149
 
__host__ __device__ bf8x2_ocp_t mxf8_convert_rne< bf8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:353
 
__host__ __device__ f8x2_ocp_t mxf8_convert_sr< f8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:468
 
unsigned char fp8_storage_t
Definition: amd_ck_fp8.hpp:61
 
_W64 unsigned int uintptr_t
Definition: stdint.h:164
 
unsigned int uint32_t
Definition: stdint.h:126
 
Definition: amd_ck_fp8.hpp:369
 
Definition: amd_ck_fp8.hpp:323
 
Definition: functional2.hpp:33