/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/mxfp_convert.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/mxfp_convert.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/mxfp_convert.hpp Source File
mxfp_convert.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 namespace ck_tile {
7 // modify from include/ck/utility/mxfp_utils.hpp
8 
// Bit-level helpers for decoding the raw encoding of a narrow floating-point
// type T (used as numeric_utils<T> by convert_to_float). Field widths come from
// `traits::exp` (exponent bits) and `traits::mant` (mantissa bits); `raw_type`
// is the integer storage type holding the encoding.
// NOTE(review): this rendered listing omits source lines 10, 13 and 14 (the
// struct header and, presumably, the `traits` / `_numeric` aliases used below);
// confirm against the real header before editing.
9 template <typename T>
11 {
12 
15  using raw_type = typename traits::bitwise_type;
16 
// All-ones mask as wide as the exponent field.
17  static constexpr int exp_mask = (1 << traits::exp) - 1;
18 
// Returns the biased exponent field: shift out the mantissa bits, then mask.
19  static constexpr raw_type get_exponent(raw_type x)
20  {
21  // TODO: check if repeated calls are optimized.
22  return (x >> traits::mant) & exp_mask;
23  }
// Same as above, but first reinterprets the bits of a T value as raw_type.
24  static constexpr raw_type get_exponent(const T& x)
25  {
26  return get_exponent(bit_cast<raw_type>(x));
27  }
// True when every bit at or above position (traits::exp + traits::mant), i.e.
// the sign bit and anything above it, is zero; assumes no stray high bits in x.
28  static constexpr bool is_positive(raw_type x)
29  {
30  return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero;
31  }
// True when the exponent field is zero (this also matches encodings of +/-0).
32  static constexpr bool is_subnormal(raw_type x)
33  {
34  return get_exponent(x) == _numeric::binary_zero;
35  }
// Returns the significand as a double in [0, 2): an implicit leading 1 for
// normal encodings (0 for subnormals) plus the explicit mantissa bits, where
// bit i (counting from the LSB) contributes 2^-(traits::mant - i).
// NOTE(review): if raw_type is an unsigned type at least as wide as int,
// `-(traits::mant - i)` is evaluated in unsigned arithmetic and relies on the
// wrap-around conversion back to ldexp's int parameter.
36  // TODO: replace double with template arg?
37  static constexpr double get_mantissa(raw_type x)
38  {
39  double mantissa = is_subnormal(x) ? 0.0f : 1.0f;
40  for(raw_type i = 0; i < traits::mant; ++i)
41  {
42  mantissa += std::ldexp(static_cast<float>(x & 0b1), -(traits::mant - i));
43  x >>= 1;
44  }
45  return mantissa;
46  }
47 };
48 
49 template <typename T>
50 CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, float scale = 1.f)
51 {
52  using utils = numeric_utils<T>;
53  float sign = utils::is_positive(data) ? 1.0 : -1.0;
54  int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
55  float mant = utils::get_mantissa(data);
56 
57  return std::ldexp(sign * mant * scale, exp);
58 }
59 
// Encode `value / scale` as the raw bit pattern of narrow floating-point type T.
//  - Magnitudes above numeric<T>::max() are either rounded to max (when below
//    the midpoint `actual_max` between max and one step above it) or treated
//    as overflow. Without infinity, overflow sets every exponent and mantissa
//    bit. With infinity, `exp` is incremented and the mantissa is dropped.
//  - Other inputs are rounded to nearest with ties-to-even (midpoint/odd
//    logic below), including inputs that fall into T's subnormal range.
// @param value  Input value; divided by `scale` before encoding.
// @param scale  Divisor applied to `value` (defaults to 1).
// @return       Raw encoding of the result as T::raw_type.
// NOTE(review): this rendered listing omits source lines 76, 78-79, 81, 87, 89,
// 125, 127-129, 174 and 210 (the initializers of sign/exp/mantissa/prev_bit,
// the extraction of head/sign/exponent/bias, and two return statements);
// confirm against the real header before relying on the details below.
// NOTE(review): no NaN/Inf check on `value` appears in the visible lines.
60 template <typename T>
61 CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value, float scale = 1.f)
62 {
63  using bitwise_type = typename numeric_traits<T>::bitwise_type;
64 
// Everything below encodes the unscaled value.
65  value /= scale;
66 
// Out-of-range magnitude: decide between rounding to max and overflowing.
67  if(std::abs(value) > float(numeric<T>::max()))
68  {
69  float max_value = numeric<T>::max();
70 
71  // cppcheck-suppress redundantAssignment
72  uint32_t max_bitwise = bit_cast<uint32_t>(max_value);
73 
74  // cppcheck-suppress redundantAssignment
75  bitwise_type sign =
77  bitwise_type exp =
80  bitwise_type mantissa =
82 
// Take max's mantissa field at T's width and step it down by one; it is used
// (on lines partly missing here) to build the float value just below max.
83  uint32_t mant_prev = max_bitwise >> (numeric_traits<float>::mant - numeric_traits<T>::mant);
84  mant_prev &= ((1 << numeric_traits<T>::mant) - 1);
85  mant_prev--;
86 
88  uint32_t prev_bit =
90  mant_prev;
91 
92  float prev_val = bit_cast<float>(prev_bit);
93  float diff = max_value - prev_val;
94 
// Rounding threshold: halfway between max and max + diff.
95  float actual_max = max_value + (diff / 2);
96 
// Below the threshold: round down to max (sign | exp | mantissa of max).
97  if(std::abs(value) < actual_max)
98  {
99  return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
100  (exp << numeric_traits<T>::mant) | mantissa;
101  }
102  else
103  {
104  if constexpr(!numeric<T>::has_inf())
105  {
106 
// Sets every exponent and mantissa bit.
// NOTE(review): the sign bit is left clear here, so negative out-of-range
// inputs encode as a positive value; confirm this is intended.
107  return (1 << (numeric_traits<T>::mant + numeric_traits<T>::exp)) - 1;
108  }
109  else
110  {
// T has infinity: bump `exp` by one and emit it with a zero mantissa.
111  exp++;
112  return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
113  (exp << numeric_traits<T>::mant);
114  }
115  }
116  }
// In-range path: decompose the float's bits (head/sign/exponent/bias are
// assigned on lines missing from this listing).
117  const int mfmt = numeric_traits<float>::mant;
118  uint32_t x;
119  x = bit_cast<uint32_t>(value);
120 
121  uint32_t head, mantissa;
122  int32_t exponent, bias;
123  uint32_t sign;
124 
126  mantissa = x & numeric_traits<float>::mant_mask;
130 
// Only +0.0f has all-zero bits; -0.0f takes the general path below.
131  if(x == 0)
132  {
133  return 0b0;
134  }
135 
136  const int mini_bias = numeric_traits<T>::bias;
137  const int mini_denormal_act_exponent = 1 - mini_bias;
138 
139  int act_exponent, out_exponent, exponent_diff;
140 
141  bool is_subnorm = false;
142 
// Compute the float's unbiased exponent and how far the mantissa must be
// shifted right to land in T's subnormal range (0 when it stays normal).
143  if(exponent == 0)
144  {
145  act_exponent = exponent - bias + 1;
146  exponent_diff = mini_denormal_act_exponent - act_exponent;
147  is_subnorm = true;
148  }
149  else
150  {
151  act_exponent = exponent - bias;
152  if(act_exponent <= mini_denormal_act_exponent)
153  {
154  exponent_diff = mini_denormal_act_exponent - act_exponent;
155  is_subnorm = true;
156  }
157  else
158  {
159  exponent_diff = 0;
160  }
// Make the implicit leading one of a normal float explicit.
161  mantissa += (1UL << mfmt);
162  }
163 
// shift_amount = number of low mantissa bits that will be discarded in total;
// midpoint is true when those bits are exactly half of T's last kept bit.
// NOTE(review): the clamp to 63 assumes a 64-bit unsigned long; with a 32-bit
// unsigned long, shifting 1UL by >= 32 would be undefined behavior.
164  auto shift_amount = (mfmt - numeric_traits<T>::mant + exponent_diff);
165  shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
166  bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
167 
// NOTE(review): epsilon() is used here as T's smallest positive subnormal;
// confirm numeric<T>::epsilon() has that meaning for every supported T.
168  float min_subnorm = float(numeric<T>::epsilon()) * (sign ? -1 : 1);
169 
// Below T's smallest subnormal: choose between zero (return on a line missing
// from this listing) and the smallest subnormal (mantissa 1), keeping the sign.
170  if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
171  {
172  // closer to 0
173  if(std::abs(value) <= std::abs(min_subnorm - value))
175  else
176  return 1 | (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant));
177  }
178 
// Align the mantissa to T's exponent range.
179  if(exponent_diff > 0)
180  mantissa >>= exponent_diff;
181  else if(exponent_diff == -1)
182  mantissa <<= -exponent_diff;
// A leading one at bit mfmt means the value is normal in T; otherwise the
// biased exponent is reduced by one (subnormal encoding).
183  bool implicit_one = mantissa & (1 << mfmt);
184  out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
185 
// Round to nearest, ties to even. Adding the dropped bits carries into the
// kept bits when they are at least half; at an exact midpoint, an even value
// adds one less, so it does not carry.
186  uint32_t drop_mask = (1UL << (mfmt - numeric_traits<T>::mant)) - 1;
187  bool odd = mantissa & (1UL << (mfmt - numeric_traits<T>::mant));
188  mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
189 
// Renormalize after rounding: a subnormal may round up into the smallest
// normal, and a normal may carry into the next binade.
190  if(out_exponent == 0)
191  {
192  if((1UL << mfmt) & mantissa)
193  {
194  out_exponent = 1;
195  }
196  }
197  else
198  {
199  if((1UL << (mfmt + 1)) & mantissa)
200  {
201  mantissa >>= 1;
202  out_exponent++;
203  }
204  }
205 
// Keep only T's mantissa width.
206  mantissa >>= (mfmt - numeric_traits<T>::mant);
207 
// Rounded to zero; the statement inside this branch is missing from this listing.
208  if(out_exponent == 0 && mantissa == 0)
209  {
211  }
212 
// Strip the implicit bit and assemble sign | exponent | mantissa.
213  mantissa &= (1UL << numeric_traits<T>::mant) - 1;
214  return (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant)) |
215  (out_exponent << numeric_traits<T>::mant) | mantissa;
216 }
217 
218 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
__host__ T exp(T x)
Definition: math_v2.hpp:391
__host__ __device__ bool is_subnormal(T x)
Definition: mxfp_utils.hpp:45
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, float scale=1.f)
Definition: mxfp_convert.hpp:50
CK_TILE_HOST_DEVICE T::raw_type convert_to_type(float value, float scale=1.f)
Definition: mxfp_convert.hpp:61
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:423
int32_t int32_t
Definition: integer.hpp:10
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
unsigned int uint32_t
Definition: stdint.h:126
Definition: numeric.hpp:81
Definition: mxfp_convert.hpp:11
static constexpr bool is_positive(raw_type x)
Definition: mxfp_convert.hpp:28
static constexpr raw_type get_exponent(raw_type x)
Definition: mxfp_convert.hpp:19
static constexpr double get_mantissa(raw_type x)
Definition: mxfp_convert.hpp:37
static constexpr int exp_mask
Definition: mxfp_convert.hpp:17
typename traits::bitwise_type raw_type
Definition: mxfp_convert.hpp:15
static constexpr bool is_subnormal(raw_type x)
Definition: mxfp_convert.hpp:32
static constexpr raw_type get_exponent(const T &x)
Definition: mxfp_convert.hpp:24
Definition: numeric.hpp:18
static constexpr CK_TILE_HOST_DEVICE T max()
Definition: numeric.hpp:26