/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/mxf6_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/mxf6_utils.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/mxf6_utils.hpp Source File
mxf6_utils.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #ifndef CK_CODE_GEN_RTC
5 #pragma once
6 
9 
10 namespace ck::utils {
11 
22 template <>
23 __host__ __device__ inline bool is_nan<f6_t>(e8m0_bexp_t const scale,
24  f6_t const dataBytes [[maybe_unused]])
25 {
26  // no need to check for data as it does not have NaN representation
27  return scale.is_nan();
28 }
29 
40 template <>
41 __host__ __device__ inline bool is_nan<bf6_t>(e8m0_bexp_t const scale,
42  bf6_t const dataBytes [[maybe_unused]])
43 {
44  // no need to check for data as it does not have NaN representation
45  return scale.is_nan();
46 }
47 
57 template <>
58 __host__ __device__ inline bool is_inf<f6_t>(e8m0_bexp_t const scale [[maybe_unused]],
59  f6_t const data [[maybe_unused]])
60 {
61  // no inf representation for fp6
62  return false;
63 }
64 
74 template <>
75 __host__ __device__ inline bool is_inf<bf6_t>(e8m0_bexp_t const scale [[maybe_unused]],
76  bf6_t const data [[maybe_unused]])
77 {
78  // no inf representation for bf6
79  return false;
80 }
81 
93 template <>
94 __host__ __device__ inline bool is_zero<f6_t>(e8m0_bexp_t const scale, f6_t const data)
95 {
96  if(is_nan<f6_t>(scale, data))
97  return false;
98 
99  // no need to check for scale as it does not have a 0 representation
100  f6_t result = (data & 0b00111111) & NumericUtils<f6_t>::set_sign_mask;
101 
102  return result == 0b0;
103 }
104 
116 template <>
117 __host__ __device__ inline bool is_zero<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
118 {
119  if(is_nan<bf6_t>(scale, data))
120  return false;
121 
122  // no need to check for scale as it does not have a 0 representation
123  bf6_t result = (data & 0b00111111) & NumericUtils<bf6_t>::set_sign_mask;
124 
125  return result == 0b0;
126 }
127 
// to_float<f6_t>: converts an f6_t element to a float based on its
// e8m0_bexp_t scale factor.
//   scale - e8m0_bexp_t scale; is_nan<f6_t> looks only at this value.
//   data  - f6_t element; only its low six bits reach convert_to_float.
//   returns 0.0f when is_zero<f6_t> reports zero; otherwise the result of
//   convert_to_float<f6_t>(data & 0b00111111, get_exponent_value(scale)).
// NOTE(review): this listing skips source line 142 (the numbering jumps from
// 141 to 143), so the statement guarded by the is_nan check is not shown.
// Taken literally, that `if` would guard the is_zero check below, which is
// not the real control flow. The missing line presumably returns a quiet NaN
// (this page cross-references NumericLimits<...>::QuietNaN). Confirm against
// mxf6_utils.hpp.
138 template <>
139 __host__ __device__ inline float to_float<f6_t>(e8m0_bexp_t const scale, f6_t const data)
140 {
141  if(is_nan<f6_t>(scale, data))
143 
144  if(is_zero<f6_t>(scale, data))
145  return 0.0f;
146 
147  f6_t prepared_data = data & 0b00111111; // keep only the six element bits
148 
149  int scale_exp = get_exponent_value<e8m0_bexp_t>(scale); // exponent value of the scale
150 
151  return convert_to_float<f6_t>(prepared_data, scale_exp);
152 }
153 
// to_float<bf6_t>: converts a bf6_t element to a float based on its
// e8m0_bexp_t scale factor.
//   scale - e8m0_bexp_t scale; is_nan<bf6_t> looks only at this value.
//   data  - bf6_t element; only its low six bits reach convert_to_float.
//   returns 0.0f when is_zero<bf6_t> reports zero; otherwise the result of
//   convert_to_float<bf6_t>(data & 0b00111111, get_exponent_value(scale)).
// NOTE(review): this listing skips source line 168 (the numbering jumps from
// 167 to 169), so the statement guarded by the is_nan check is not shown.
// Taken literally, that `if` would guard the is_zero check below, which is
// not the real control flow. The missing line presumably returns a quiet NaN.
// Confirm against mxf6_utils.hpp.
164 template <>
165 __host__ __device__ inline float to_float<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
166 {
167  if(is_nan<bf6_t>(scale, data))
169 
170  if(is_zero<bf6_t>(scale, data))
171  return 0.0f;
172 
173  bf6_t prepared_data = data & 0b00111111; // keep only the six element bits
174 
175  int scale_exp = get_exponent_value<e8m0_bexp_t>(scale); // exponent value of the scale
176 
177  return convert_to_float<bf6_t>(prepared_data, scale_exp);
178 }
179 
// sat_convert_to_type<f6_t>: converts a float to f6_t with saturation.
//   value - float input. NaN and |value| > NumericLimits<f6_t>::DataMaxNorm()
//           (which, per the inline comment, includes +/-inf) get dedicated
//           branches.
//   returns the f6_t produced by convert_to_type<f6_t>(value) on the normal
//   path.
// NOTE(review): this listing omits source lines 200-201 (NaN branch body),
// 205-206 (out-of-range branch body) and 210-213 (code between the
// conversion and the return). The values returned in those branches, and any
// post-processing of `res`, are not visible here. Presumably `sign` selects
// the saturated value's sign. Confirm against mxf6_utils.hpp.
190 template <>
191 __host__ __device__ inline f6_t sat_convert_to_type<f6_t>(float value)
192 {
193  cvt t; // used to read the float's raw bits (value_float / value_bitwise)
194  t.value_float = value;
195  uint32_t sign = t.value_bitwise >> 31; // bit 31 of an IEEE-754 binary32 is the sign
196 
197  if(std::isnan(value))
198  {
199 
202  }
203 
204  if(std::abs(value) > NumericLimits<f6_t>::DataMaxNorm()) // covers inf case as well
207 
208  f6_t res = convert_to_type<f6_t>(value);
209 
214 
215  return res;
216 }
217 
// sat_convert_to_type<bf6_t>: converts a float to bf6_t with saturation.
//   value - float input. NaN and |value| > NumericLimits<bf6_t>::DataMaxNorm()
//           (which, per the inline comment, includes +/-inf) get dedicated
//           branches.
//   returns the bf6_t produced by convert_to_type<bf6_t>(value) on the normal
//   path.
// NOTE(review): this listing omits source lines 238-239 (NaN branch body),
// 243-244 (out-of-range branch body) and 248-251 (code between the
// conversion and the return). The values returned in those branches, and any
// post-processing of `res`, are not visible here. Presumably `sign` selects
// the saturated value's sign. Confirm against mxf6_utils.hpp.
228 template <>
229 __host__ __device__ inline bf6_t sat_convert_to_type<bf6_t>(float value)
230 {
231  cvt t; // used to read the float's raw bits (value_float / value_bitwise)
232  t.value_float = value;
233  uint32_t sign = t.value_bitwise >> 31; // bit 31 of an IEEE-754 binary32 is the sign
234 
235  if(std::isnan(value))
236  {
237 
240  }
241 
242  if(std::abs(value) > NumericLimits<bf6_t>::DataMaxNorm()) // covers inf case as well
245 
246  bf6_t res = convert_to_type<bf6_t>(value);
247 
252 
253  return res;
254 }
255 
// sat_convert_to_type_sr<f6_t>: converts a float to f6_t with saturation and
// stochastic rounding.
//   value - float input. NaN and |value| > NumericLimits<f6_t>::DataMaxNorm()
//           (which, per the inline comment, includes +/-inf) get dedicated
//           branches.
//   seed  - forwarded unchanged to convert_to_type_sr<f6_t>.
//   returns the f6_t produced by convert_to_type_sr<f6_t>(value, seed) on the
//   normal path.
// NOTE(review): this listing omits source lines 274-275 (NaN branch),
// 278-279 (out-of-range branch) and 283-286 (code between the conversion and
// the return). Taken literally, the rendered `if`s below would guard the
// following statements, which is not the real control flow. Presumably
// `sign` selects the saturated value's sign. Confirm against mxf6_utils.hpp.
266 template <>
267 __host__ __device__ inline f6_t sat_convert_to_type_sr<f6_t>(float value, uint32_t seed)
268 {
269  cvt t; // used to read the float's raw bits (value_float / value_bitwise)
270  t.value_float = value;
271  uint32_t sign = t.value_bitwise >> 31; // bit 31 of an IEEE-754 binary32 is the sign
272 
273  if(std::isnan(value))
276 
277  if(std::abs(value) > NumericLimits<f6_t>::DataMaxNorm()) // covers inf case as well
280 
281  f6_t res = convert_to_type_sr<f6_t>(value, seed);
282 
287 
288  return res;
289 }
290 
// sat_convert_to_type_sr<bf6_t>: converts a float to bf6_t with saturation
// and stochastic rounding.
//   value - float input. NaN and |value| > NumericLimits<bf6_t>::DataMaxNorm()
//           (which, per the inline comment, includes +/-inf) get dedicated
//           branches.
//   seed  - forwarded unchanged to convert_to_type_sr<bf6_t>.
//   returns the bf6_t produced by convert_to_type_sr<bf6_t>(value, seed) on
//   the normal path.
// NOTE(review): this listing omits source lines 309-310 (NaN branch),
// 312-313 (out-of-range branch) and 317-320 (code between the conversion and
// the return). Taken literally, the rendered `if`s below would guard the
// following statements, which is not the real control flow. Presumably
// `sign` selects the saturated value's sign. Confirm against mxf6_utils.hpp.
301 template <>
302 __host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint32_t seed)
303 {
304  cvt t; // used to read the float's raw bits (value_float / value_bitwise)
305  t.value_float = value;
306  uint32_t sign = t.value_bitwise >> 31; // bit 31 of an IEEE-754 binary32 is the sign
307 
308  if(std::isnan(value))
311  if(std::abs(value) > NumericLimits<bf6_t>::DataMaxNorm()) // covers inf case as well
314 
315  bf6_t res = convert_to_type_sr<bf6_t>(value, seed);
316 
321 
322  return res;
323 }
324 } // namespace ck::utils
325 #endif
Definition: check_err.hpp:24
__host__ __device__ bool is_inf< bf6_t >(e8m0_bexp_t const scale[[maybe_unused]], bf6_t const data[[maybe_unused]])
Checks if a bf6_t value is infinite.
Definition: mxf6_utils.hpp:75
__host__ __device__ f6_t sat_convert_to_type_sr< f6_t >(float value, uint32_t seed)
Converts a float to f6_t with saturation and stochastic rounding.
Definition: mxf6_utils.hpp:267
__host__ __device__ bool is_zero< bf6_t >(e8m0_bexp_t const scale, bf6_t const data)
Checks whether a bf6_t value is zero.
Definition: mxf6_utils.hpp:117
__host__ __device__ bf6_t sat_convert_to_type_sr< bf6_t >(float value, uint32_t seed)
Converts a float to bf6_t with saturation and stochastic rounding.
Definition: mxf6_utils.hpp:302
__host__ __device__ f6_t sat_convert_to_type< f6_t >(float value)
Converts a float to f6_t with saturation.
Definition: mxf6_utils.hpp:191
__host__ __device__ bf6_t sat_convert_to_type< bf6_t >(float value)
Converts a float to bf6_t with saturation.
Definition: mxf6_utils.hpp:229
__host__ __device__ float to_float< bf6_t >(e8m0_bexp_t const scale, bf6_t const data)
Converts a bf6_t value to a float based on an e8m0_bexp_t scale factor.
Definition: mxf6_utils.hpp:165
__host__ constexpr __device__ int32_t get_exponent_value< e8m0_bexp_t >(e8m0_bexp_t x)
Definition: e8m0.hpp:73
__host__ __device__ bool is_nan< bf6_t >(e8m0_bexp_t const scale, bf6_t const dataBytes[[maybe_unused]])
Checks if a bf6_t value is NaN based on the provided scale.
Definition: mxf6_utils.hpp:41
__host__ __device__ float to_float< f6_t >(e8m0_bexp_t const scale, f6_t const data)
Converts an f6_t value to a float based on an e8m0_bexp_t scale factor.
Definition: mxf6_utils.hpp:139
__host__ __device__ bool is_inf< f6_t >(e8m0_bexp_t const scale[[maybe_unused]], f6_t const data[[maybe_unused]])
Checks if an f6_t value is infinite.
Definition: mxf6_utils.hpp:58
__host__ __device__ bool is_zero< f6_t >(e8m0_bexp_t const scale, f6_t const data)
Checks whether an f6_t value is zero.
Definition: mxf6_utils.hpp:94
__host__ __device__ bool is_nan< f6_t >(e8m0_bexp_t const scale, f6_t const dataBytes[[maybe_unused]])
Checks if an f6_t value is NaN based on the provided scale.
Definition: mxf6_utils.hpp:23
_BitInt(6) f6_t
Definition: data_type.hpp:33
unsigned _BitInt(6) bf6_t
Definition: data_type.hpp:34
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
unsigned int uint32_t
Definition: stdint.h:126
Definition: numeric_limits.hpp:309
__host__ static constexpr __device__ T QuietNaN()
Definition: numeric_limits.hpp:313
Definition: numeric_utils.hpp:10
Unsigned representation of a conventional biased Float32 exponent.
Definition: e8m0.hpp:25
__host__ constexpr __device__ bool is_nan() const
Definition: e8m0.hpp:64
Definition: mxfp_utils.hpp:14
float value_float
Definition: mxfp_utils.hpp:15
uint32_t value_bitwise
Definition: mxfp_utils.hpp:16