include/ck/utility/mxfp_utils.hpp Source File

include/ck/utility/mxfp_utils.hpp Source File

Composable Kernel: include/ck/utility/mxfp_utils.hpp Source File
mxfp_utils.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 namespace ck::utils {
7 
// Views the same 32 bits either as a float or as an unsigned integer: write one
// member, then read the other.
// NOTE(review): reading the member that was not written last is type punning. GCC and
// Clang define it for unions; ISO C++ does not (bit_cast is the portable alternative).
union cvt
{
    float value_float;      // floating-point view
    uint32_t value_bitwise; // raw bit-pattern view of the same storage
};
13 
// Host-side query: reports the `hasInf` flag of DTYPE's static `dataInfo` record.
template <typename DTYPE>
inline bool getDataHasInf()
{
    const bool has_inf = DTYPE::dataInfo.hasInf;
    return has_inf;
}
19 
20 template <typename T>
21 __host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data);
22 
23 template <typename T>
24 __host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data);
25 
26 template <typename T>
27 __host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data);
28 
29 template <typename T>
30 __host__ __device__ inline int get_exponent_value(T x)
31 {
33 
34  x &= ((1 << NumericUtils<T>::exp) - 1);
35 
36  return static_cast<int>(x);
37 }
38 
39 template <typename T>
40 __host__ __device__ inline bool is_subnormal(T x)
41 {
42  return get_exponent_value<T>(x) == 0;
43 }
44 
45 template <typename T>
46 __host__ __device__ inline double get_mantissa_value(T x)
47 {
48  double mantissa = is_subnormal<T>(x) ? 0.0f : 1.0f;
49 
50  for(uint i = 0; i < NumericUtils<T>::mant; i++)
51  {
52 
53  mantissa += std::pow(2, -int32_t((NumericUtils<T>::mant - i))) * (x & 0b1);
54 
55  x >>= 1;
56  }
57 
58  return mantissa;
59 }
60 
61 template <typename T>
62 __host__ __device__ inline bool get_data_has_inf()
63 {
65 }
66 
67 template <typename T>
68 __host__ __device__ float convert_to_float(T data, int scale_exp)
69 {
70  float d_sign =
71  std::pow(-1, static_cast<float>(data >> (NumericUtils<T>::exp + NumericUtils<T>::mant)));
72 
73  float d_exp;
74  if(is_subnormal<T>(data))
75  d_exp = std::pow(2, 1 - static_cast<int>(NumericUtils<T>::bias));
76  else
77  d_exp = std::pow(2, get_exponent_value<T>(data) - static_cast<int>(NumericUtils<T>::bias));
78  float d_mant = get_mantissa_value<T>(data);
79 
80  float data_value = d_sign * d_exp * d_mant;
81  float scale_value = std::pow(
82  2, static_cast<float>((scale_exp - static_cast<int>(NumericUtils<e8m0_bexp_t>::bias))));
83 
84  return data_value * scale_value;
85 }
86 
87 template <typename T>
88 __host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data);
89 
90 template <typename T>
91 __host__ __device__ T sat_convert_to_type(float value);
92 
93 template <typename T>
94 __host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);
95 
// Converts `value` to the narrow floating-point format T and returns T's bit pattern.
// In-range inputs are rounded to nearest, ties to even (see the midpoint/odd logic
// below). Inputs with |value| > NumericLimits<T>::Max() take the first branch.
// NOTE(review): this rendering omits source lines 114, 116-117, 124, 162 and 164-166:
// the initializers of `sign` and `exp`, one statement between `mant_prev--` and
// `prev_bit`, and what presumably assigns head/exponent/sign/bias. The comments below
// describe only the visible code.
96 template <typename T>
97 inline T convert_to_type(float value)
98 {
99  using bitwise_type = typename NumericUtils<T>::bitwise_type;
100 
// Overflow region: choose between rounding down to Max() and overflowing. The cut-off
// `actual_max` is Max() plus half the gap between Max() and `prev_val`, which is
// rebuilt below from Max() with T's mantissa field decremented.
101  if(std::abs(value) > NumericLimits<T>::Max())
102  {
103  float max_value = NumericLimits<T>::Max();
104 
105  cvt t;
106 
107  // cppcheck-suppress redundantAssignment
108  t.value_float = max_value;
109  uint32_t max_bitwise = t.value_bitwise;
110 
111  // cppcheck-suppress redundantAssignment
112  t.value_float = value;
113  bitwise_type sign =
115  bitwise_type exp =
118  bitwise_type mantissa = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
119 
// T's mantissa field of Max(), minus one.
120  uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
121  mant_prev &= ((1 << NumericUtils<T>::mant) - 1);
122  mant_prev--;
123 
// Keep Max()'s float sign/exponent bits and splice in the decremented mantissa.
125  uint32_t prev_bit =
126  ((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
127 
128  t.value_bitwise = prev_bit;
129  float prev_val = t.value_float;
130  float diff = max_value - prev_val;
131 
132  float actual_max = max_value + (diff / 2);
133 
// Below the cut-off: return sign | exp | mantissa. This is presumably the signed Max()
// pattern; `sign` and `exp` come from initializers not visible here.
134  if(std::abs(value) < actual_max)
135  {
136  return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
137  (exp << NumericUtils<T>::mant) | mantissa;
138  }
139  else
140  {
141  if(!get_data_has_inf<T>())
142  {
143 
// T has no infinity: return all exponent and mantissa bits set, with the sign bit clear.
// NOTE(review): the sign of `value` is not applied on this path, so large negative
// inputs yield a positive pattern. Confirm this is intended.
144  return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
145  }
146  else
147  {
// T has infinity: bump the exponent and return it with a zero mantissa field.
148  exp++;
149  return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
150  (exp << NumericUtils<T>::mant);
151  }
152  }
153  }
// In-range path. `mfmt` is the float mantissa width.
154  const int mfmt = NumericUtils<float>::mant;
155  uint32_t x;
156  x = bit_cast<uint32_t>(value);
157 
158  uint32_t head, mantissa;
159  int32_t exponent, bias;
160  uint32_t sign;
161 
163  mantissa = x & NumericUtils<float>::mant_mask;
167 
// Only +0.0f (all bits clear) is caught here; -0.0f continues below.
168  if(x == 0)
169  {
170  return 0b0;
171  }
172 
173  const int mini_bias = NumericUtils<T>::bias;
174  const int mini_denormal_act_exponent = 1 - mini_bias;
175 
176  int act_exponent, out_exponent, exponent_diff;
177 
178  bool is_subnorm = false;
179 
// Compute the unbiased input exponent. When the result lands at or below T's smallest
// normal exponent (1 - T's bias), exponent_diff is how far the significand must shift
// right. Normal float inputs also get their implicit leading 1 added to `mantissa`.
180  if(exponent == 0)
181  {
182  act_exponent = exponent - bias + 1;
183  exponent_diff = mini_denormal_act_exponent - act_exponent;
184  is_subnorm = true;
185  }
186  else
187  {
188  act_exponent = exponent - bias;
189  if(act_exponent <= mini_denormal_act_exponent)
190  {
191  exponent_diff = mini_denormal_act_exponent - act_exponent;
192  is_subnorm = true;
193  }
194  else
195  {
196  exponent_diff = 0;
197  }
198  mantissa += (1UL << mfmt);
199  }
200 
// Number of low `mantissa` bits that rounding will drop. `midpoint` is true when those
// bits are exactly half of T's last place.
// NOTE(review): the clamp to 63 and the `1UL` shifts assume a 64-bit unsigned long. On
// LLP64 targets a shift of 32 or more would be undefined.
201  auto shift_amount = (mfmt - NumericUtils<T>::mant + exponent_diff);
202  shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
203  bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
204 
205  float min_subnorm = NumericLimits<T>::DataMinSubnorm() * (sign ? -1 : 1);
206 
// Below T's smallest subnormal: return 0 (unsigned) or the signed smallest subnormal,
// whichever is nearer. Exact ties go to zero (`<=`).
207  if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
208  {
209  // closer to 0
210  if(std::abs(value) <= std::abs(min_subnorm - value))
211  return 0;
212  else
213  return 1 | (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant));
214  }
215 
// NOTE(review): exponent_diff looks non-negative on every visible path (given that
// `bias` is the float bias), which would make the `== -1` branch unreachable. Confirm.
216  if(exponent_diff > 0)
217  mantissa >>= exponent_diff;
218  else if(exponent_diff == -1)
219  mantissa <<= -exponent_diff;
220  bool implicit_one = mantissa & (1 << mfmt);
221  out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
222 
// Round to nearest, ties to even. Adding the dropped bits to themselves carries into
// the kept LSB exactly when they are at least half an ulp. At an exact tie with an even
// kept LSB (`odd` false), one less is added, so no carry happens.
223  uint32_t drop_mask = (1UL << (mfmt - NumericUtils<T>::mant)) - 1;
224  bool odd = mantissa & (1UL << (mfmt - NumericUtils<T>::mant));
225  mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
226 
// Renormalize after rounding:
// - a subnormal that carried into the implicit-bit position becomes exponent 1;
// - a normal that carried past it is shifted right and its exponent is bumped.
227  if(out_exponent == 0)
228  {
229  if((1UL << mfmt) & mantissa)
230  {
231  out_exponent = 1;
232  }
233  }
234  else
235  {
236  if((1UL << (mfmt + 1)) & mantissa)
237  {
238  mantissa >>= 1;
239  out_exponent++;
240  }
241  }
242 
// Drop the rounded-off low bits so only T's mantissa width remains (plus implicit bit).
243  mantissa >>= (mfmt - NumericUtils<T>::mant);
244 
// A result that rounded to zero is returned as unsigned 0.
245  if(out_exponent == 0 && mantissa == 0)
246  {
247  return 0;
248  }
249 
// Strip the implicit bit and assemble sign | exponent | mantissa.
250  mantissa &= (1UL << NumericUtils<T>::mant) - 1;
251  return (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant)) |
252  (out_exponent << NumericUtils<T>::mant) | mantissa;
253 }
254 
// Converts `value` to the narrow floating-point format T using stochastic rounding
// driven by `seed`, and returns T's bit pattern.
// NOTE(review): this rendering omits source lines 270-272, 278, 301-302, 340, 345-346
// and 373. Among them are the declarations of `sign`/`exp` used in the overflow branch,
// the statement controlled by the `if` at line 299, and the conditions of the blocks
// opening at lines 341 and 347. The comments below describe only the visible code.
255 template <typename T>
256 inline T convert_to_type_sr(float value, uint32_t seed)
257 {
// Overflow region. `actual_max` is Max() plus half the gap to `prev_val`, which is
// rebuilt from Max() with T's mantissa field decremented.
258  if(std::abs(value) > NumericLimits<T>::Max())
259  {
260  float max_value = NumericLimits<T>::Max();
261 
262  cvt t;
263 
264  // cppcheck-suppress redundantAssignment
265  t.value_float = max_value;
// NOTE(review): `uint` is a non-standard (POSIX/glibc) typedef; uint32_t is portable.
266  uint max_bitwise = t.value_bitwise;
267 
268  // cppcheck-suppress redundantAssignment
269  t.value_float = value;
273 
274  uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
275  mant_prev &= ((1UL << NumericUtils<T>::mant) - 1);
276  mant_prev--;
277 
279  uint32_t prev_bit =
280  ((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
281 
282  t.value_bitwise = prev_bit;
283  float prev_val = t.value_float;
284  float diff = max_value - prev_val;
285 
286  float actual_max = max_value + (diff / 2);
287 
288  if(std::abs(value) < actual_max)
289  {
// Between Max() and the cut-off, the probability of rounding down (d_prob) falls
// linearly from 1 at Max() to 0 at actual_max; d_is equals diff / 2. The seed is
// compared against that fraction of UINT_MAX, which presumes it is uniformly
// distributed over uint32_t.
290  double d_max_value = static_cast<double>(max_value);
291  double d_actual_max = static_cast<double>(actual_max);
292  double d_value = static_cast<double>(value);
293  double d_is = std::abs(d_max_value - d_actual_max);
294  double d_seed = static_cast<double>(seed);
295  double d_prob = 1.0f - (std::abs(d_value - d_max_value) / d_is); // prob to round down
296 
297  double thresh = UINT_MAX * d_prob;
298 
// Types without infinity always take the first branch. Its statement (lines 301-302)
// is not visible; presumably it rounds down to Max().
299  if(!get_data_has_inf<T>() || d_seed <= thresh)
300  // return static_cast<T>(satConvertToType(getDataMax<DTYPE>())); //round down time
303  else
304  {
305  exp++;
306  return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
307  | (exp << NumericUtils<T>::mant);
308  }
309  }
310  else
311  {
// Beyond the cut-off: the all-ones magnitude pattern if T has no infinity, otherwise
// the bumped exponent with a zero mantissa.
// NOTE(review): the sign of `value` is not applied on the no-infinity path, so large
// negative inputs yield a positive pattern. Confirm this is intended.
312  if(!get_data_has_inf<T>())
313  return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
314  else
315  {
316  exp++;
317  return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
318  | (exp << NumericUtils<T>::mant);
319  }
320  }
321  }
322 
// In-range path: split the float into its mantissa field and unbiased exponent, and
// move its sign bit to T's sign position.
323  uint32_t f32 = bit_cast<uint32_t>(value);
324 
325  auto f32_mant = f32 & NumericUtils<float>::mant_mask;
326  auto head = f32 & NumericUtils<float>::head_mask;
327  auto f32_exp = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
328 
329  auto sign_bit = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
330  auto sign = sign_bit << (NumericUtils<T>::exp + NumericUtils<T>::mant);
331 
332  f32_exp = static_cast<int32_t>(f32_exp) - NumericUtils<float>::bias;
333  int32_t exp = f32_exp;
334  auto mant = f32_mant;
335  bool subnorm = false;
336 
// Only +0.0f (all bits clear) returns here.
337  if(f32 == 0)
338  return 0b0;
339 
// (The condition guarding the next block, line 340, is not visible.)
341  {
342  mant = f32_mant;
343  }
344  // if the exponent bit is 8, then the subnormal is exactly the same as f32
// Subnormal result (condition on lines 345-346 not visible). Restore the implicit bit
// and shift right by how far `exp` sits below T's smallest unbiased exponent; flush to
// zero when that shift would be 32 or more.
347  {
348  subnorm = true;
349  auto diff = static_cast<uint32_t>(NumericUtils<T>::unbiased_exp_min - exp);
350  if(diff >= 32)
351  {
352  mant = 0;
353  f32_mant = 0;
354  }
355  else
356  {
357  f32_mant |= static_cast<uint32_t>(1) << NumericUtils<float>::mant;
358  f32_mant >>= diff;
359  }
360  exp = 0;
361  mant = f32_mant;
362  }
363 
364  uint32_t sr_shift = NumericUtils<T>::sr_shift;
365 
366  // For stochastic-rounding we add the aligned random value to the
367  // mantissa and then truncate (RTZ).
368  mant += seed >> sr_shift;
369 
370  // Increment exponent when mantissa overflows due to rounding
371  if(mant >= static_cast<uint32_t>(1) << NumericUtils<float>::mant)
372  ++exp;
// (Line 373 is not visible; presumably it is the truncating shift described above.)
374  mant &= ((1 << NumericUtils<T>::mant) - 1);
375 
// Subnormals keep `exp` as the raw field (0, or 1 after a rounding carry); normals add
// T's bias. The field is then masked to T's exponent width before assembly.
376  auto biased_exp = static_cast<uint32_t>(exp);
377  if(!subnorm)
378  biased_exp = static_cast<uint32_t>(exp + NumericUtils<T>::bias);
379  biased_exp &= ((1 << NumericUtils<T>::exp) - 1);
380  auto val = sign | biased_exp << NumericUtils<T>::mant | mant;
381  return val;
382 }
383 
384 } // namespace ck::utils
__host__ T exp(T x)
Definition: math_v2.hpp:391
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
Definition: check_err.hpp:24
T convert_to_type_sr(float value, uint32_t seed)
Definition: mxfp_utils.hpp:256
__host__ __device__ T sat_convert_to_type(float value)
__host__ __device__ bool is_subnormal(T x)
Definition: mxfp_utils.hpp:40
__host__ __device__ bool get_data_has_inf()
Definition: mxfp_utils.hpp:62
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed)
__host__ __device__ float convert_to_float(T data, int scale_exp)
Definition: mxfp_utils.hpp:68
T convert_to_type(float value)
Definition: mxfp_utils.hpp:97
__host__ __device__ int get_exponent_value(T x)
Definition: mxfp_utils.hpp:30
__host__ __device__ bool is_zero(e8m0_bexp_t const scale, T const data)
__host__ __device__ bool is_inf(e8m0_bexp_t const scale, T const data)
__host__ __device__ double get_mantissa_value(T x)
Definition: mxfp_utils.hpp:46
__host__ __device__ bool is_nan(e8m0_bexp_t const scale, T const data)
bool getDataHasInf()
Definition: mxfp_utils.hpp:15
__host__ __device__ float to_float(e8m0_bexp_t const scale, T const data)
Definition: data_type.hpp:2831
__host__ static constexpr __device__ T Max()
Definition: data_type.hpp:2833
Definition: data_type.hpp:3078
Unsigned representation of a conventional biased Float32 exponent.
Definition: e8m0.hpp:25
Definition: mxfp_utils.hpp:9
float value_float
Definition: mxfp_utils.hpp:10
uint32_t value_bitwise
Definition: mxfp_utils.hpp:11