/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/mxf4_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/mxf4_utils.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/mxf4_utils.hpp Source File
mxf4_utils.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #ifndef CK_CODE_GEN_RTC
5 #pragma once
6 
9 
10 namespace ck::utils {
11 
12 template <>
13 __host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
14  f4_t const dataBytes [[maybe_unused]])
15 {
16  // no need to check for data as it does not have NaN representation
17  return scale.is_nan();
18 }
19 
20 // no infinity representation in ocp_e2m1_mxfp4 will always return false
21 template <>
22 __host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
23  f4_t const data [[maybe_unused]])
24 {
25  // no inf representation for ocp_e2m1_mxfp4
26  return false;
27 }
28 
29 template <>
30 __host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
31  f4_t const data)
32 {
33  // no need to check for scale as it does not have a 0 representation
34  f4_t result = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;
35 
36  return result == 0b0;
37 }
38 
// to_float specialization for f4_t: turns one (scale, data) pair into a float.
// Visible flow: a NaN check on the pair, a zero check that yields 0.0f, and
// otherwise a call to convert_to_float<f4_t> with the low four bits of `data`
// and the exponent value obtained from `scale`.
39 template <>
40 __host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
41 {
42  if(is_nan<f4_t>(scale, data))
// NOTE(review): the listing jumps from line 42 to line 44, so the statement
// guarded by the NaN check is not visible here -- presumably it returns a quiet
// NaN (the page index lists a QuietNaN() helper); confirm against the header.
44 
45  if(is_zero<f4_t>(scale, data))
46  return 0.0f;
47 
// Only the low four bits of the element are forwarded to the conversion.
48  f4_t prepared_data = data & 0b00001111;
49 
// Exponent value carried by the shared scale.
50  int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
51 
52  return convert_to_float<f4_t>(prepared_data, scale_exp);
53 }
54 
// sat_convert_to_type specialization for f4_t: converts a float to an f4_t
// element ("sat" presumably means saturating -- the clamping statements are not
// visible in this listing, see the notes below).
// Visible flow: read the sign from the float's bit pattern, special-case NaN
// input and inputs whose magnitude exceeds NumericLimits<f4_t>::DataMaxNorm(),
// then fall through to convert_to_type<f4_t>.
55 template <>
56 __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
57 {
// cvt exposes value_float and value_bitwise; it is used here to read the bits
// of `value` (presumably a union -- confirm in mxfp_utils.hpp).
58  cvt t;
59  t.value_float = value;
// Bit 31 of the 32-bit pattern, i.e. the sign of `value`.
60  uint32_t sign = t.value_bitwise >> 31;
61 
62  if(std::isnan(value))
63  {
64 
// NOTE(review): lines 65-66 are missing from this listing; the value returned
// for NaN input is not visible -- presumably chosen by `sign`; confirm.
67  }
68 
69  if(std::abs(value) > NumericLimits<f4_t>::DataMaxNorm()) // covers inf case as well
// NOTE(review): lines 70-71 (the statement guarded by the range check) are
// missing from this listing -- presumably a sign-dependent maximum; confirm.
72 
73  f4_t res = convert_to_type<f4_t>(value);
74 
// NOTE(review): lines 75-78 are missing from this listing; any post-processing
// of `res` before the return is not visible here.
79 
80  return res;
81 }
82 
// sat_convert_to_type_sr specialization for f4_t: same visible shape as
// sat_convert_to_type<f4_t>, but the final conversion is convert_to_type_sr<f4_t>
// and receives `seed` ("sr" presumably stands for stochastic rounding -- confirm).
// Several statements are absent from this listing; see the notes below.
83 template <>
84 __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32_t seed)
85 {
// cvt exposes value_float and value_bitwise; it is used here to read the bits
// of `value` (presumably a union -- confirm in mxfp_utils.hpp).
86  cvt t;
87  t.value_float = value;
// Bit 31 of the 32-bit pattern, i.e. the sign of `value`.
88  uint32_t sign = t.value_bitwise >> 31;
89 
90  if(std::isnan(value))
// NOTE(review): lines 91-92 (the statement guarded by the NaN check) are
// missing from this listing -- presumably a sign-dependent return; confirm.
93 
94  if(std::abs(value) > NumericLimits<f4_t>::DataMaxNorm()) // covers inf case as well
// NOTE(review): lines 95-96 (the statement guarded by the range check) are
// missing from this listing -- presumably a sign-dependent maximum; confirm.
97 
98  f4_t res = convert_to_type_sr<f4_t>(value, seed);
99 
// NOTE(review): lines 100-103 are missing from this listing; any post-processing
// of `res` before the return is not visible here.
104 
105  return res;
106 }
107 } // namespace ck::utils
108 #endif
Definition: check_err.hpp:24
__host__ __device__ bool is_nan< f4_t >(e8m0_bexp_t const scale, f4_t const dataBytes[[maybe_unused]])
Definition: mxf4_utils.hpp:13
__host__ __device__ f4_t sat_convert_to_type_sr< f4_t >(float value, uint32_t seed)
Definition: mxf4_utils.hpp:84
__host__ constexpr __device__ int32_t get_exponent_value< e8m0_bexp_t >(e8m0_bexp_t x)
Definition: e8m0.hpp:73
__host__ __device__ float to_float< f4_t >(e8m0_bexp_t const scale, f4_t const data)
Definition: mxf4_utils.hpp:40
__host__ __device__ bool is_inf< f4_t >(e8m0_bexp_t const scale[[maybe_unused]], f4_t const data[[maybe_unused]])
Definition: mxf4_utils.hpp:22
__host__ __device__ f4_t sat_convert_to_type< f4_t >(float value)
Definition: mxf4_utils.hpp:56
__host__ __device__ bool is_zero< f4_t >(e8m0_bexp_t const scale[[maybe_unused]], f4_t const data)
Definition: mxf4_utils.hpp:30
unsigned _BitInt(4) f4_t
Definition: data_type.hpp:32
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
unsigned int uint32_t
Definition: stdint.h:126
Definition: numeric_limits.hpp:309
__host__ static constexpr __device__ T QuietNaN()
Definition: numeric_limits.hpp:313
Definition: numeric_utils.hpp:10
Unsigned representation of a conventional biased Float32 exponent.
Definition: e8m0.hpp:25
__host__ constexpr __device__ bool is_nan() const
Definition: e8m0.hpp:64
Definition: mxfp_utils.hpp:14
float value_float
Definition: mxfp_utils.hpp:15
uint32_t value_bitwise
Definition: mxfp_utils.hpp:16