include/ck/utility/mxf6_utils.hpp Source File

include/ck/utility/mxf6_utils.hpp Source File

Composable Kernel: include/ck/utility/mxf6_utils.hpp Source File
mxf6_utils.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck::utils {
10 
21 template <>
22 __host__ __device__ inline bool is_nan<f6_t>(e8m0_bexp_t const scale,
23  f6_t const dataBytes [[maybe_unused]])
24 {
25  // no need to check for data as it does not have NaN representation
26  return scale.is_nan();
27 }
28 
39 template <>
40 __host__ __device__ inline bool is_nan<bf6_t>(e8m0_bexp_t const scale,
41  bf6_t const dataBytes [[maybe_unused]])
42 {
43  // no need to check for data as it does not have NaN representation
44  return scale.is_nan();
45 }
46 
56 template <>
57 __host__ __device__ inline bool is_inf<f6_t>(e8m0_bexp_t const scale [[maybe_unused]],
58  f6_t const data [[maybe_unused]])
59 {
60  // no inf representation for fp6
61  return false;
62 }
63 
73 template <>
74 __host__ __device__ inline bool is_inf<bf6_t>(e8m0_bexp_t const scale [[maybe_unused]],
75  bf6_t const data [[maybe_unused]])
76 {
77  // no inf representation for bf6
78  return false;
79 }
80 
92 template <>
93 __host__ __device__ inline bool is_zero<f6_t>(e8m0_bexp_t const scale, f6_t const data)
94 {
95  if(is_nan<f6_t>(scale, data))
96  return false;
97 
98  // no need to check for scale as it does not have a 0 representation
99  f6_t result = (data & 0b00111111) & NumericUtils<f6_t>::set_sign_mask;
100 
101  return result == 0b0;
102 }
103 
115 template <>
116 __host__ __device__ inline bool is_zero<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
117 {
118  if(is_nan<bf6_t>(scale, data))
119  return false;
120 
121  // no need to check for scale as it does not have a 0 representation
122  bf6_t result = (data & 0b00111111) & NumericUtils<bf6_t>::set_sign_mask;
123 
124  return result == 0b0;
125 }
126 
137 template <>
138 __host__ __device__ inline float to_float<f6_t>(e8m0_bexp_t const scale, f6_t const data)
139 {
140  if(is_nan<f6_t>(scale, data))
141  return std::numeric_limits<float>::quiet_NaN();
142 
143  if(is_zero<f6_t>(scale, data))
144  return 0.0f;
145 
146  f6_t prepared_data = data & 0b00111111;
147 
148  int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
149 
150  return convert_to_float<f6_t>(prepared_data, scale_exp);
151 }
152 
163 template <>
164 __host__ __device__ inline float to_float<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
165 {
166  if(is_nan<bf6_t>(scale, data))
167  return std::numeric_limits<float>::quiet_NaN();
168 
169  if(is_zero<bf6_t>(scale, data))
170  return 0.0f;
171 
172  bf6_t prepared_data = data & 0b00111111;
173 
174  int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
175 
176  return convert_to_float<bf6_t>(prepared_data, scale_exp);
177 }
178 
// Converts a float to f6_t with saturation. This summary is taken from the
// Doxygen index entry for sat_convert_to_type<f6_t>.
//
// NOTE(review): this rendering omits original lines 199-200, 204-205 and
// 209-212. As a result, the body of the NaN branch, the body of the
// out-of-range `if`, and whatever post-processing follows convert_to_type
// are not visible here. Do not treat the code below as complete.
189 template <>
190 __host__ __device__ inline f6_t sat_convert_to_type<f6_t>(float value)
191 {
// cvt (mxfp_utils.hpp) exposes the float's bits as value_bitwise. Bit 31 is
// the float32 sign bit. `sign` is presumably consumed by the elided branches
// to pick a signed saturated result -- confirm against the full header.
192  cvt t;
193  t.value_float = value;
194  uint32_t sign = t.value_bitwise >> 31;
195 
// NaN input: the handling (original lines 199-200) is elided here.
196  if(std::isnan(value))
197  {
198 
201  }
202 
// Magnitudes above NumericLimits<f6_t>::Max(), including +/-inf, are
// saturated. The branch body (original lines 204-205) is elided here.
203  if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
206 
// In-range values go through the non-saturating convert_to_type<f6_t>.
207  f6_t res = convert_to_type<f6_t>(value);
208 
// NOTE(review): original lines 209-212 are elided; they may adjust `res`.
213 
214  return res;
215 }
216 
// Converts a float to bf6_t with saturation. This summary is taken from the
// Doxygen index entry for sat_convert_to_type<bf6_t>.
//
// NOTE(review): this rendering omits original lines 237-238, 242-243 and
// 247-250. As a result, the body of the NaN branch, the body of the
// out-of-range `if`, and whatever post-processing follows convert_to_type
// are not visible here. Do not treat the code below as complete.
227 template <>
228 __host__ __device__ inline bf6_t sat_convert_to_type<bf6_t>(float value)
229 {
// cvt (mxfp_utils.hpp) exposes the float's bits as value_bitwise. Bit 31 is
// the float32 sign bit. `sign` is presumably consumed by the elided branches
// to pick a signed saturated result -- confirm against the full header.
230  cvt t;
231  t.value_float = value;
232  uint32_t sign = t.value_bitwise >> 31;
233 
// NaN input: the handling (original lines 237-238) is elided here.
234  if(std::isnan(value))
235  {
236 
239  }
240 
// Magnitudes above NumericLimits<bf6_t>::Max(), including +/-inf, are
// saturated. The branch body (original lines 242-243) is elided here.
241  if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
244 
// In-range values go through the non-saturating convert_to_type<bf6_t>.
245  bf6_t res = convert_to_type<bf6_t>(value);
246 
// NOTE(review): original lines 247-250 are elided; they may adjust `res`.
251 
252  return res;
253 }
254 
// Converts a float to f6_t with saturation and stochastic rounding. This
// summary is taken from the Doxygen index entry for sat_convert_to_type_sr<f6_t>.
// `seed` is forwarded to convert_to_type_sr<f6_t>.
//
// NOTE(review): this rendering omits original lines 273-274, 277-278 and
// 282-285. As a result, the body of the NaN `if`, the body of the
// out-of-range `if`, and whatever post-processing follows convert_to_type_sr
// are not visible here. Do not treat the code below as complete.
265 template <>
266 __host__ __device__ inline f6_t sat_convert_to_type_sr<f6_t>(float value, uint32_t seed)
267 {
// cvt (mxfp_utils.hpp) exposes the float's bits as value_bitwise. Bit 31 is
// the float32 sign bit. `sign` is presumably consumed by the elided branches
// -- confirm against the full header.
268  cvt t;
269  t.value_float = value;
270  uint32_t sign = t.value_bitwise >> 31;
271 
// NaN input: the branch body (original lines 273-274) is elided here.
272  if(std::isnan(value))
275 
// Magnitudes above NumericLimits<f6_t>::Max(), including +/-inf, are
// saturated. The branch body (original lines 277-278) is elided here.
276  if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
279 
// In-range values are converted with stochastic rounding driven by `seed`.
280  f6_t res = convert_to_type_sr<f6_t>(value, seed);
281 
// NOTE(review): original lines 282-285 are elided; they may adjust `res`.
286 
287  return res;
288 }
289 
// Converts a float to bf6_t with saturation and stochastic rounding.
// `seed` is forwarded to convert_to_type_sr<bf6_t>.
//
// NOTE(review): this rendering omits original lines 308-309, 312-313 and
// 317-320. As a result, the body of the NaN `if`, the body of the
// out-of-range `if`, and whatever post-processing follows convert_to_type_sr
// are not visible here. Do not treat the code below as complete.
300 template <>
301 __host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint32_t seed)
302 {
// cvt (mxfp_utils.hpp) exposes the float's bits as value_bitwise. Bit 31 is
// the float32 sign bit. `sign` is presumably consumed by the elided branches
// -- confirm against the full header.
303  cvt t;
304  t.value_float = value;
305  uint32_t sign = t.value_bitwise >> 31;
306 
// NaN input: the branch body (original lines 308-309) is elided here.
307  if(std::isnan(value))
310 
// Magnitudes above NumericLimits<bf6_t>::Max(), including +/-inf, are
// saturated. The branch body (original lines 312-313) is elided here.
311  if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
314 
// In-range values are converted with stochastic rounding driven by `seed`.
315  bf6_t res = convert_to_type_sr<bf6_t>(value, seed);
316 
// NOTE(review): original lines 317-320 are elided; they may adjust `res`.
321 
322  return res;
323 }
324 
325 } // namespace ck::utils
Definition: check_err.hpp:24
__host__ __device__ bool is_inf< bf6_t >(e8m0_bexp_t const scale[[maybe_unused]], bf6_t const data[[maybe_unused]])
Checks if a bf6_t value is infinite.
Definition: mxf6_utils.hpp:74
__host__ __device__ f6_t sat_convert_to_type_sr< f6_t >(float value, uint32_t seed)
Converts a float to f6_t with saturation and stochastic rounding.
Definition: mxf6_utils.hpp:266
__host__ __device__ bool is_zero< bf6_t >(e8m0_bexp_t const scale, bf6_t const data)
Checks whether a bf6_t value is zero.
Definition: mxf6_utils.hpp:116
__host__ __device__ bf6_t sat_convert_to_type_sr< bf6_t >(float value, uint32_t seed)
Converts a float to bf6_t with saturation and stochastic rounding.
Definition: mxf6_utils.hpp:301
__host__ __device__ f6_t sat_convert_to_type< f6_t >(float value)
Converts a float to f6_t with saturation.
Definition: mxf6_utils.hpp:190
__host__ __device__ bf6_t sat_convert_to_type< bf6_t >(float value)
Converts a float to bf6_t with saturation.
Definition: mxf6_utils.hpp:228
__host__ __device__ float to_float< bf6_t >(e8m0_bexp_t const scale, bf6_t const data)
Converts a bf6_t value to a float based on an e8m0_bexp_t scale factor.
Definition: mxf6_utils.hpp:164
__host__ __device__ bool is_nan< bf6_t >(e8m0_bexp_t const scale, bf6_t const dataBytes[[maybe_unused]])
Checks if a bf6_t value is NaN based on the provided scale.
Definition: mxf6_utils.hpp:40
__host__ __device__ int get_exponent_value< e8m0_bexp_t >(e8m0_bexp_t x)
Definition: e8m0.hpp:73
__host__ __device__ float to_float< f6_t >(e8m0_bexp_t const scale, f6_t const data)
Converts an f6_t value to a float based on an e8m0_bexp_t scale factor.
Definition: mxf6_utils.hpp:138
__host__ __device__ bool is_inf< f6_t >(e8m0_bexp_t const scale[[maybe_unused]], f6_t const data[[maybe_unused]])
Checks if an f6_t value is infinite.
Definition: mxf6_utils.hpp:57
__host__ __device__ bool is_zero< f6_t >(e8m0_bexp_t const scale, f6_t const data)
Checks whether an f6_t value is zero.
Definition: mxf6_utils.hpp:93
__host__ __device__ bool is_nan< f6_t >(e8m0_bexp_t const scale, f6_t const dataBytes[[maybe_unused]])
Checks if an f6_t value is NaN based on the provided scale.
Definition: mxf6_utils.hpp:22
_BitInt(6) f6_t
Definition: data_type.hpp:28
unsigned _BitInt(6) bf6_t
Definition: data_type.hpp:29
Definition: data_type.hpp:2831
Definition: data_type.hpp:3078
Unsigned representation of a conventional biased Float32 exponent.
Definition: e8m0.hpp:25
__host__ constexpr __device__ bool is_nan() const
Definition: e8m0.hpp:64
Definition: mxfp_utils.hpp:9
float value_float
Definition: mxfp_utils.hpp:10
uint32_t value_bitwise
Definition: mxfp_utils.hpp:11