/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/mxfp_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/mxfp_utils.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/mxfp_utils.hpp Source File
mxfp_utils.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
#ifdef CK_CODE_GEN_RTC
// Runtime-compilation (CK_CODE_GEN_RTC) builds presumably lack <climits>, so
// UINT_MAX is supplied here; it is used below to scale the stochastic-rounding
// threshold in convert_to_type_sr.
// The U suffix gives the macro type `unsigned int`, matching the standard
// definition (an unsuffixed 4294967295 is a signed long / long long).
// The #ifndef guard avoids clobbering a definition that is already present.
#ifndef UINT_MAX
#define UINT_MAX 4294967295U
#endif
#endif
11 namespace ck::utils {
12 
// Type-punning helper. convert_to_type and convert_to_type_sr store a float in
// `value_float` and read its raw bit pattern back through `value_bitwise`,
// and vice versa (write the bits, then read the float).
// NOTE(review): the `uint32_t value_bitwise` member (original line 16) is
// elided from this rendering; the definition index lists it as uint32_t.
// NOTE(review): reading the inactive member of a union is undefined behavior
// in ISO C++ (GCC/Clang accept it as an extension). This file already uses
// bit_cast elsewhere, which would be the portable alternative.
13 union cvt
14 {
15  float value_float;
17 };
18 
// Reads the compile-time flag DTYPE::dataInfo.hasInf. Presumably this reports
// whether the format descriptor DTYPE has an infinity encoding.
template <typename DTYPE>
inline bool getDataHasInf()
{
    const bool format_has_inf = DTYPE::dataInfo.hasInf;
    return format_has_inf;
}
24 
// Classification predicates for an element `data` of narrow format T under
// the e8m0 block scale `scale`.
// These are declarations only; presumably each supported T provides an
// explicit specialization elsewhere. The definitions are not visible here.
25 template <typename T>
26 __host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data);
27 
28 template <typename T>
29 __host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data);
30 
31 template <typename T>
32 __host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data);
33 
// Extracts the biased exponent field of the raw encoding `x` of format T and
// returns it as an int32_t. Callers in this file (is_subnormal,
// convert_to_float) treat the result as the biased exponent.
// NOTE(review): original line 37 is elided from this rendering. As shown, the
// body only masks `x` to its low NumericUtils<T>::exp bits. Under the
// sign | exponent | mantissa layout used by convert_to_type, those low bits
// would be mantissa bits. Presumably the elided line shifts `x` right by
// NumericUtils<T>::mant first — confirm against the real header.
34 template <typename T>
35 __host__ __device__ inline constexpr int32_t get_exponent_value(T x)
36 {
38 
// Keep only the low NumericUtils<T>::exp bits.
39  x &= ((1 << NumericUtils<T>::exp) - 1);
40 
41  return static_cast<int32_t>(x);
42 }
43 
// True when the biased exponent field of the raw encoding `x` is zero.
// Note that an encoding of zero, if stored with an all-zero exponent field,
// also satisfies this test.
template <typename T>
__host__ __device__ inline bool is_subnormal(T x)
{
    const int32_t biased_exponent = get_exponent_value<T>(x);
    return !biased_exponent;
}
49 
// Returns the significand of the raw encoding `x` of format T as a double.
// It is the implicit leading digit (0 when is_subnormal<T>(x), else 1) plus
// the low NumericUtils<T>::mant bits of `x` read as a binary fraction. The
// least-significant bit has weight 2^-mant and the top mantissa bit 2^-1.
template <typename T>
__host__ __device__ inline double get_mantissa_value(T x)
{
    double mantissa = is_subnormal<T>(x) ? 0.0 : 1.0;

    // Weight of the current (least-significant remaining) mantissa bit.
    // It starts at 2^-mant and doubles each step. Doubling a power of two is
    // exact, so the terms match computing 2^-(mant - i) with std::pow on every
    // iteration, with a single pow call instead of one per bit.
    double bit_weight = std::pow(2, -int32_t(NumericUtils<T>::mant));

    // uint32_t instead of the non-standard `uint` typedef; same unsigned width.
    for(uint32_t i = 0; i < NumericUtils<T>::mant; i++)
    {
        mantissa += bit_weight * (x & 0b1);
        bit_weight *= 2;
        x >>= 1;
    }

    return mantissa;
}
65 
// Reports whether format T has an infinity encoding. convert_to_type and
// convert_to_type_sr use it to choose, on overflow, between returning the
// all-ones magnitude pattern and returning an infinity.
// NOTE(review): the body (original line 69) is elided from this rendering.
// As shown, the function has no return statement, and falling off the end of
// a non-void function is undefined behavior. Confirm that the real header
// returns a value (presumably a per-format has-inf flag from NumericUtils<T>).
66 template <typename T>
67 __host__ __device__ inline bool get_data_has_inf()
68 {
70 }
71 
// Decodes the raw encoding `data` of narrow format T and applies a block scale
// whose biased e8m0 exponent is `scale_exp`:
//
//   result = (-1)^s * 2^(E - bias_T) * m * 2^(scale_exp - bias_e8m0)
//
// where:
//   - s is everything above the exponent and mantissa fields;
//   - E is the biased exponent field (replaced by 1 for subnormals);
//   - m is get_mantissa_value<T>(data).
// NOTE(review): no special case is made for NaN/Inf data or for a NaN scale;
// scale_exp = 255 gives a scale factor of 2^128, which overflows float.
template <typename T>
__host__ __device__ float convert_to_float(T data, int scale_exp)
{
    const auto sign_shift   = NumericUtils<T>::exp + NumericUtils<T>::mant;
    const float sign_factor = std::pow(-1, static_cast<float>(data >> sign_shift));

    // Subnormals use the fixed exponent 1 - bias; normals use the stored field.
    const int t_bias = static_cast<int>(NumericUtils<T>::bias);
    const int unbiased_exponent =
        is_subnormal<T>(data) ? 1 - t_bias : get_exponent_value<T>(data) - t_bias;
    const float exponent_factor = std::pow(2, unbiased_exponent);

    const float significand = get_mantissa_value<T>(data);

    const float element = sign_factor * exponent_factor * significand;

    const int scale_shift = scale_exp - static_cast<int>(NumericUtils<e8m0_bexp_t>::bias);
    const float scale_factor = std::pow(2, static_cast<float>(scale_shift));

    return element * scale_factor;
}
91 
// Declarations only; presumably each supported T provides an explicit
// specialization elsewhere, since the definitions are not visible here.
// to_float: decode `data` under the e8m0 block scale `scale`.
// sat_convert_to_type / sat_convert_to_type_sr: float -> T conversions. The
// `_sr` variant takes a random `seed`, matching convert_to_type_sr below,
// whose comments describe stochastic rounding. The `sat_` prefix presumably
// means saturating on overflow — confirm against the specializations.
92 template <typename T>
93 __host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data);
94 
95 template <typename T>
96 __host__ __device__ T sat_convert_to_type(float value);
97 
98 template <typename T>
99 __host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);
100 
// Converts `value` to narrow format T with round-to-nearest, ties-to-even
// (the deterministic counterpart of convert_to_type_sr). The result is the raw
// encoding, laid out as sign << (exp + mant) | exponent << mant | mantissa.
//
// Overflow (|value| > NumericLimits<T>::Max()):
//   - below the threshold `actual_max`, the result is assembled from Max's
//     float bit pattern;
//   - otherwise, formats without infinity return the all-ones magnitude
//     pattern, and formats with infinity return exponent field + 1 with a
//     zero mantissa (infinity).
// NOTE(review): lines 119, 121-122, 129, 167 and 169-171 are elided from this
// rendering. `sign` and `exp` (overflow path) and `head`, `exponent`, `bias`,
// `sign` (main path) are initialized there; presumably `sign` comes from
// `value`.
101 template <typename T>
102 __host__ __device__ inline T convert_to_type(float value)
103 {
104  using bitwise_type = typename NumericUtils<T>::bitwise_type;
105 
106  if(std::abs(value) > NumericLimits<T>::Max())
107  {
108  float max_value = NumericLimits<T>::Max();
109 
110  cvt t;
111 
112  // cppcheck-suppress redundantAssignment
113  t.value_float = max_value;
114  uint32_t max_bitwise = t.value_bitwise;
115 
116  // cppcheck-suppress redundantAssignment
117  t.value_float = value;
118  bitwise_type sign =
120  bitwise_type exp =
123  bitwise_type mantissa = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
124 
// Rounding threshold: Max plus half the gap to `prev_val`, where `prev_val`
// is meant to be the T value one ulp below Max.
// NOTE(review): mant_prev-- wraps if Max's T-mantissa field is zero.
// NOTE(review): as rendered, mant_prev is ORed into the LOW float mantissa
// bits. Unless elided line 129 shifts it back up by (float mant - T mant),
// prev_val is not the previous T value and the threshold is too large.
// Confirm against the real header.
125  uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
126  mant_prev &= ((1 << NumericUtils<T>::mant) - 1);
127  mant_prev--;
128 
130  uint32_t prev_bit =
131  ((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
132 
133  t.value_bitwise = prev_bit;
134  float prev_val = t.value_float;
135  float diff = max_value - prev_val;
136 
137  float actual_max = max_value + (diff / 2);
138 
139  if(std::abs(value) < actual_max)
140  {
141  return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
142  (exp << NumericUtils<T>::mant) | mantissa;
143  }
144  else
145  {
// NOTE(review): this no-infinity return has the sign bit clear, so a
// negative overflow yields a positive pattern. Depending on T, all-ones
// magnitude is Max or a NaN encoding. Confirm this is intended.
146  if(!get_data_has_inf<T>())
147  {
148 
149  return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
150  }
151  else
152  {
153  exp++;
154  return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
155  (exp << NumericUtils<T>::mant);
156  }
157  }
158  }
// In-range path: decompose the float bits and round its mantissa to T's width.
159  const int mfmt = NumericUtils<float>::mant;
160  uint32_t x;
161  x = bit_cast<uint32_t>(value);
162 
163  uint32_t head, mantissa;
164  int32_t exponent, bias;
165  uint32_t sign;
166 
168  mantissa = x & NumericUtils<float>::mant_mask;
172 
// Only +0.0f (all bits zero) takes this shortcut; -0.0f continues below.
173  if(x == 0)
174  {
175  return 0b0;
176  }
177 
178  const int mini_bias = NumericUtils<T>::bias;
179  const int mini_denormal_act_exponent = 1 - mini_bias;
180 
181  int act_exponent, out_exponent, exponent_diff;
182 
183  bool is_subnorm = false;
184 
// act_exponent: unbiased exponent of `value`. exponent_diff: the extra right
// shift needed when the result is subnormal in T. For normal floats the
// implicit leading 1 is made explicit in `mantissa`.
185  if(exponent == 0)
186  {
187  act_exponent = exponent - bias + 1;
188  exponent_diff = mini_denormal_act_exponent - act_exponent;
189  is_subnorm = true;
190  }
191  else
192  {
193  act_exponent = exponent - bias;
194  if(act_exponent <= mini_denormal_act_exponent)
195  {
196  exponent_diff = mini_denormal_act_exponent - act_exponent;
197  is_subnorm = true;
198  }
199  else
200  {
201  exponent_diff = 0;
202  }
203  mantissa += (1UL << mfmt);
204  }
205 
// midpoint: the bits about to be dropped equal exactly half an ulp of T.
// NOTE(review): clamping to 63 assumes a 64-bit unsigned long (LP64); with a
// 32-bit unsigned long, shifts of 32 or more are undefined.
206  auto shift_amount = (mfmt - NumericUtils<T>::mant + exponent_diff);
207  shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
208  bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
209 
210  float min_subnorm = NumericLimits<T>::DataMinSubnorm() * (sign ? -1 : 1);
211 
// Below T's smallest subnormal: round to signed zero or to the signed
// smallest subnormal (mantissa 1). Exact ties (<=) go to zero.
212  if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
213  {
214  // closer to 0
215  if(std::abs(value) <= std::abs(min_subnorm - value))
216  return sign << (NumericUtils<T>::exp + NumericUtils<T>::mant);
217  else
218  return 1 | (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant));
219  }
220 
// NOTE(review): from the visible code exponent_diff is never negative, so the
// `== -1` branch looks unreachable.
221  if(exponent_diff > 0)
222  mantissa >>= exponent_diff;
223  else if(exponent_diff == -1)
224  mantissa <<= -exponent_diff;
225  bool implicit_one = mantissa & (1 << mfmt);
226  out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
227 
// Round to nearest. Adding the dropped bits to themselves carries into the
// kept bits iff they are at least half an ulp. At an exact midpoint with an
// even kept LSB, adding (mantissa - 1) instead suppresses the carry, so ties
// round to even.
228  uint32_t drop_mask = (1UL << (mfmt - NumericUtils<T>::mant)) - 1;
229  bool odd = mantissa & (1UL << (mfmt - NumericUtils<T>::mant));
230  mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
231 
// Renormalize after rounding. A carry can promote a subnormal to the smallest
// normal, or overflow a normal mantissa into the next binade.
232  if(out_exponent == 0)
233  {
234  if((1UL << mfmt) & mantissa)
235  {
236  out_exponent = 1;
237  }
238  }
239  else
240  {
241  if((1UL << (mfmt + 1)) & mantissa)
242  {
243  mantissa >>= 1;
244  out_exponent++;
245  }
246  }
247 
248  mantissa >>= (mfmt - NumericUtils<T>::mant);
249 
// Rounded all the way to zero: return a signed zero.
250  if(out_exponent == 0 && mantissa == 0)
251  {
252  return sign << (NumericUtils<T>::exp + NumericUtils<T>::mant);
253  }
254 
255  mantissa &= (1UL << NumericUtils<T>::mant) - 1;
256  return (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant)) |
257  (out_exponent << NumericUtils<T>::mant) | mantissa;
258 }
259 
// Converts `value` to narrow format T with stochastic rounding driven by
// `seed`. The result is the raw encoding, laid out as
// sign | exponent << mant | mantissa.
//
// Overflow (|value| > NumericLimits<T>::Max()):
//   - below the threshold `actual_max`, d_prob falls linearly from 1 at Max
//     to 0 at actual_max and is compared against `seed` scaled by UINT_MAX;
//   - when `seed <= UINT_MAX * d_prob`, or when T has no infinity, the elided
//     round-down branch is taken (lines 306-307, presumably returning Max);
//   - otherwise the result is infinity;
//   - at or above the threshold, formats without infinity return the all-ones
//     magnitude pattern and formats with infinity return infinity.
// In range: the float mantissa (with implicit 1, pre-shifted for T
// subnormals) gets `seed >> sr_shift` added and is then truncated.
// NOTE(review): lines 275-277, 283, 306-307, 345, 350-351, 369 and 378 are
// elided from this rendering. `sign`, `exp`, `sr_shift` and the conditions
// guarding the blocks at 346 and 352 are defined there.
260 template <typename T>
261 __host__ __device__ inline T convert_to_type_sr(float value, uint32_t seed)
262 {
263  if(std::abs(value) > NumericLimits<T>::Max())
264  {
265  float max_value = NumericLimits<T>::Max();
266 
267  cvt t;
268 
269  // cppcheck-suppress redundantAssignment
270  t.value_float = max_value;
// NOTE(review): `uint` is a non-standard typedef; uint32_t would be portable.
271  uint max_bitwise = t.value_bitwise;
272 
273  // cppcheck-suppress redundantAssignment
274  t.value_float = value;
278 
// Same threshold construction as convert_to_type: Max plus half the gap to
// the T value one ulp below Max.
// NOTE(review): correctness depends on elided line 283 presumably shifting
// mant_prev back into the high float mantissa bits.
279  uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
280  mant_prev &= ((1UL << NumericUtils<T>::mant) - 1);
281  mant_prev--;
282 
284  uint32_t prev_bit =
285  ((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
286 
287  t.value_bitwise = prev_bit;
288  float prev_val = t.value_float;
289  float diff = max_value - prev_val;
290 
291  float actual_max = max_value + (diff / 2);
292 
293  if(std::abs(value) < actual_max)
294  {
295  double d_max_value = static_cast<double>(max_value);
296  double d_actual_max = static_cast<double>(actual_max);
297  double d_value = static_cast<double>(value);
298  double d_is = std::abs(d_max_value - d_actual_max);
299  double d_seed = static_cast<double>(seed);
300  double d_prob = 1.0f - (std::abs(d_value - d_max_value) / d_is); // prob to round down
301 
302  double thresh = UINT_MAX * d_prob;
303 
304  if(!get_data_has_inf<T>() || d_seed <= thresh)
305  // return static_cast<T>(satConvertToType(getDataMax<DTYPE>())); //round down time
308  else
309  {
310  exp++;
311  return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
312  | (exp << NumericUtils<T>::mant);
313  }
314  }
315  else
316  {
// NOTE(review): the no-infinity return has the sign bit clear, so a negative
// overflow yields a positive pattern. Confirm this is intended.
317  if(!get_data_has_inf<T>())
318  return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
319  else
320  {
321  exp++;
322  return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
323  | (exp << NumericUtils<T>::mant);
324  }
325  }
326  }
327 
// In-range path: split the float into sign, unbiased exponent and mantissa.
328  uint32_t f32 = bit_cast<uint32_t>(value);
329 
330  auto f32_mant = f32 & NumericUtils<float>::mant_mask;
331  auto head = f32 & NumericUtils<float>::head_mask;
332  auto f32_exp = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
333 
334  auto sign_bit = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
335  auto sign = sign_bit << (NumericUtils<T>::exp + NumericUtils<T>::mant);
336 
337  f32_exp = static_cast<int32_t>(f32_exp) - NumericUtils<float>::bias;
338  int32_t exp = f32_exp;
339  auto mant = f32_mant;
340  bool subnorm = false;
341 
// Only +0.0f (all bits zero) takes this shortcut.
342  if(f32 == 0)
343  return 0b0;
344 
346  {
347  mant = f32_mant;
348  }
349  // if the exponent bit is 8, then the subnormal is exactly the same as f32
// Result is subnormal in T: make the implicit 1 explicit and shift right by
// the exponent deficit. Shifts of 32 or more flush the mantissa to zero.
352  {
353  subnorm = true;
354  auto diff = static_cast<uint32_t>(NumericUtils<T>::unbiased_exp_min - exp);
355  if(diff >= 32)
356  {
357  mant = 0;
358  f32_mant = 0;
359  }
360  else
361  {
362  f32_mant |= static_cast<uint32_t>(1) << NumericUtils<float>::mant;
363  f32_mant >>= diff;
364  }
365  exp = 0;
366  mant = f32_mant;
367  }
368 
370 
371  // For stochastic-rounding we add the aligned random value to the
372  // mantissa and then truncate (RTZ).
373  mant += seed >> sr_shift;
374 
375  // Increment exponent when mantissa overflows due to rounding
376  if(mant >= static_cast<uint32_t>(1) << NumericUtils<float>::mant)
377  ++exp;
// Presumably the elided line 378 shifts `mant` down to T's width before the
// mask below keeps its low NumericUtils<T>::mant bits.
379  mant &= ((1 << NumericUtils<T>::mant) - 1);
380 
// Subnormal results keep a biased exponent of 0; normals are re-biased for T.
381  auto biased_exp = static_cast<uint32_t>(exp);
382  if(!subnorm)
383  biased_exp = static_cast<uint32_t>(exp + NumericUtils<T>::bias);
384  biased_exp &= ((1 << NumericUtils<T>::exp) - 1);
385  auto val = sign | biased_exp << NumericUtils<T>::mant | mant;
386  return val;
387 }
388 } // namespace ck::utils
__host__ T exp(T x)
Definition: math_v2.hpp:391
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
Definition: check_err.hpp:24
__host__ __device__ T sat_convert_to_type(float value)
__host__ __device__ bool is_subnormal(T x)
Definition: mxfp_utils.hpp:45
__host__ __device__ bool get_data_has_inf()
Definition: mxfp_utils.hpp:67
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed)
__host__ __device__ float convert_to_float(T data, int scale_exp)
Definition: mxfp_utils.hpp:73
__host__ __device__ T convert_to_type_sr(float value, uint32_t seed)
Definition: mxfp_utils.hpp:261
__host__ __device__ bool is_zero(e8m0_bexp_t const scale, T const data)
__host__ __device__ T convert_to_type(float value)
Definition: mxfp_utils.hpp:102
__host__ __device__ bool is_inf(e8m0_bexp_t const scale, T const data)
__host__ constexpr __device__ int32_t get_exponent_value(T x)
Definition: mxfp_utils.hpp:35
__host__ __device__ double get_mantissa_value(T x)
Definition: mxfp_utils.hpp:51
__host__ __device__ bool is_nan(e8m0_bexp_t const scale, T const data)
bool getDataHasInf()
Definition: mxfp_utils.hpp:20
__host__ __device__ float to_float(e8m0_bexp_t const scale, T const data)
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
unsigned int uint32_t
Definition: stdint.h:126
signed int int32_t
Definition: stdint.h:123
Definition: numeric_limits.hpp:309
__host__ static constexpr __device__ T Max()
Definition: numeric_limits.hpp:311
Definition: numeric_utils.hpp:10
Unsigned representation of a conventional biased Float32 exponent.
Definition: e8m0.hpp:25
Definition: mxfp_utils.hpp:14
float value_float
Definition: mxfp_utils.hpp:15
uint32_t value_bitwise
Definition: mxfp_utils.hpp:16