/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/float8.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/float8.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/float8.hpp Source File
float8.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
12 #include <stdint.h>
13 #include <type_traits>
14 
15 #pragma once
16 
// CK_TILE_FP8_CVT_DEVICE is 1 only while compiling HIP device code for a gfx94x or gfx12 target;
// it selects the hardware fp8/bf8 conversion builtins. Everywhere else it is 0.
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__ && \
    (defined(__gfx94__) || defined(__gfx12__))
#define CK_TILE_FP8_CVT_DEVICE 1
#else
#define CK_TILE_FP8_CVT_DEVICE 0
#endif
22 
23 namespace ck_tile {
24 
25 // fp8 rounding modes
26 // use standard for rounding to nearest, the faster one
27 // use stochastic for stochastic rounding, helps to avoid error accumulation
// float_to_fp8_raw / float_to_bf8_raw dispatch on this value at compile time:
//   standard   -> float_to_fp8_rtn_raw (round to nearest)
//   stochastic -> float_to_fp8_sr_raw  (stochastic rounding)
// NOTE(review): the enum header line and the `stochastic` enumerator are not shown in this
// listing; `stochastic` is referenced by float_to_fp8_raw below.
29 {
30  standard = 0,
32 };
33 
// FP8 interpretation used by the conversion algorithms: the exponent/mantissa split together with
// the special-value convention (OCP vs FNUZ, see the table below). Each 8-bit type selects one via
// numeric_traits<T>::f8_interpret, depending on CK_TILE_USE_OCP_FP8.
38 {
39  E4M3_OCP = 0, // OCP FP8 E4M3
40  E5M2_OCP = 1, // OCP BF8 E5M2
41  E4M3_FNUZ = 2, // FNUZ FP8 E4M3
42  E5M2_FNUZ = 3, // FNUZ BF8 E5M2
43 };
44 
45 /*
46  * ______________FNUZ_________________ | ______________OCP________________
47  * e4m3 e5m2 | e4m3 e5m2
48  * bias : 8 16 | 7 15
49  * inf : N/A N/A | N/A s.11111.00
50  * NaN : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
51  * zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
52  * Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
53  * Max(snorm): s.0000.111 s.00000.11 | s.0000.111 s.00000.11
54  * 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
55  * Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
56  * 2^-7(0.0078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
57  * Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01
58  * 2^-10(0.0009765625) 2^-17(7.62939e-06) | 2^-9(0.001953125) 2^-16(1.52588e-05)
59  */
60 
// Forward declarations of the float -> raw 8-bit encoders. The default rounding mode comes from
// CK_TILE_FLOAT_TO_FP8_DEFAULT; the definitions further below omit the default argument.
// float_to_fp8_raw encodes to fp8_t (E4M3), float_to_bf8_raw encodes to bf8_t (E5M2).
61 template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
62 CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant<rounding> = {});
63 
64 template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
65 CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant<rounding> = {});
66 
69 
70 #if CK_TILE_USE_CUSTOM_DATA_TYPE
// One-byte wrapper for the 8-bit E4M3 float, used as fp8_t when CK_TILE_USE_CUSTOM_DATA_TYPE is set.
// The exponent bias is 7 with CK_TILE_USE_OCP_FP8 (OCP E4M3) and 8 otherwise (FNUZ E4M3).
// `data` holds the raw encoding; numeric conversions go through float_to_fp8_raw / fp8_to_float_raw.
71 struct alignas(1) float8_e4m3_t
72 {
73  static constexpr int exponent = 4;
74  static constexpr int mantissa = 3;
75 #if CK_TILE_USE_OCP_FP8
76  static constexpr int bias = 7; // OCP
77 #else
78  static constexpr int bias = 8; // FNUZ
79 #endif
80  using raw_type = uint8_t;
81  raw_type data;
82 
// Reinterprets an existing raw encoding as a float8_e4m3_t (no numeric conversion).
84  static constexpr float8_e4m3_t bit_cast(raw_type x)
85  {
86  float8_e4m3_t y;
87  y.data = x;
88  return y;
89  }
90 
91  // constructor
// Value-initializes `data`, i.e. raw encoding 0x00 (+0.0).
92  constexpr float8_e4m3_t() : data() {}
93 
94  // construct from float
// The converting constructors encode with the default rounding mode (CK_TILE_FLOAT_TO_FP8_DEFAULT).
// NOTE(review): they are declared constexpr, but the visible declaration of float_to_fp8_raw is not
// constexpr, so they cannot be evaluated in a constant expression — confirm this is intended.
96  explicit constexpr float8_e4m3_t(const float& x) : data(float_to_fp8_raw(x)) {}
97 
98  // construct from int
// Goes through float first: static_cast<float>(x) may itself round for very large ints.
100  explicit constexpr float8_e4m3_t(const int& x) : data(float_to_fp8_raw(static_cast<float>(x)))
101  {
102  }
103 
104  // construct from unsigned int
106  explicit constexpr float8_e4m3_t(const unsigned int& x)
107  : data(float_to_fp8_raw(static_cast<float>(x)))
108  {
109  }
110 
111  // cast to float
113  explicit constexpr operator float() const { return fp8_to_float_raw(data); }
114 
115  // cast to int
// Decodes to float, then truncates toward zero.
// NOTE(review): if fp8_to_float_raw returns NaN for a NaN encoding, static_cast<int> of that NaN is
// undefined behaviour in C++.
117  explicit constexpr operator int() const { return static_cast<int>(fp8_to_float_raw(data)); }
118 
119  // internal access
// Mutable reference to the raw encoding (non-const overload) / raw encoding by value (const overload).
121  constexpr raw_type& get() { return data; }
122 
124  constexpr raw_type get() const { return data; }
125 };
126 using fp8_t = float8_e4m3_t;
127 using fp8_raw_t = typename fp8_t::raw_type;
128 
// One-byte wrapper for the 8-bit E5M2 float, used as bf8_t when CK_TILE_USE_CUSTOM_DATA_TYPE is set.
// The exponent bias is 15 with CK_TILE_USE_OCP_FP8 (OCP E5M2) and 16 otherwise (FNUZ E5M2).
// `data` holds the raw encoding; numeric conversions go through float_to_bf8_raw / bf8_to_float_raw.
129 struct alignas(1) float8_e5m2_t
130 {
131  static constexpr int exponent = 5;
132  static constexpr int mantissa = 2;
133 #if CK_TILE_USE_OCP_FP8
134  static constexpr int bias = 15; // OCP
135 #else
136  static constexpr int bias = 16; // FNUZ
137 #endif
138  using raw_type = uint8_t;
139  raw_type data;
140 
// Reinterprets an existing raw encoding as a float8_e5m2_t (no numeric conversion).
142  static constexpr float8_e5m2_t bit_cast(raw_type x)
143  {
144  float8_e5m2_t y;
145  y.data = x;
146  return y;
147  }
148 
149  // constructor
// Value-initializes `data`, i.e. raw encoding 0x00 (+0.0).
150  constexpr float8_e5m2_t() : data() {}
151 
152  // construct from float
// The converting constructors encode with the default rounding mode (CK_TILE_FLOAT_TO_FP8_DEFAULT).
// NOTE(review): they are declared constexpr, but the visible declaration of float_to_bf8_raw is not
// constexpr, so they cannot be evaluated in a constant expression — confirm this is intended.
154  explicit constexpr float8_e5m2_t(const float& x) : data(float_to_bf8_raw(x)) {}
155 
156  // construct from int
// Goes through float first: static_cast<float>(x) may itself round for very large ints.
158  explicit constexpr float8_e5m2_t(const int& x) : data(float_to_bf8_raw(static_cast<float>(x)))
159  {
160  }
161 
162  // construct from unsigned int
164  explicit constexpr float8_e5m2_t(const unsigned int& x)
165  : data(float_to_bf8_raw(static_cast<float>(x)))
166  {
167  }
168 
169  // cast to float
171  explicit constexpr operator float() const { return bf8_to_float_raw(data); }
172 
173  // cast to int
// Decodes to float, then truncates toward zero.
// NOTE(review): if bf8_to_float_raw returns NaN/Inf for such encodings, static_cast<int> of that value
// is undefined behaviour in C++.
175  explicit constexpr operator int() const { return static_cast<int>(bf8_to_float_raw(data)); }
176 
177  // internal access
// Mutable reference to the raw encoding (non-const overload) / raw encoding by value (const overload).
179  constexpr raw_type& get() { return data; }
180 
182  constexpr raw_type get() const { return data; }
183 };
184 using bf8_t = float8_e5m2_t;
185 using bf8_raw_t = typename bf8_t::raw_type;
186 
187 template <typename>
188 struct native_t;
189 
190 template <>
191 struct native_t<fp8_t>
192 {
193  using type = _BitInt(8);
194 };
195 
196 template <>
197 struct native_t<bf8_t>
198 {
199  using type = unsigned _BitInt(8);
200 };
201 
202 #else
203 
// Without the custom wrappers, fp8_t is a signed 8-bit _BitInt and bf8_t an unsigned one
// (the same types native_t reports in the custom-type configuration).
// NOTE(review): per the cross-reference, the fp8_raw_t / bf8_raw_t (uint8_t) aliases for this
// configuration sit on lines 205 and 207, which are not shown in this listing.
204 using fp8_t = _BitInt(8);
206 using bf8_t = unsigned _BitInt(8);
208 #endif
209 
// Layout traits of fp8_t (E4M3), read by the conversion routines in namespace impl and by abs().
// NOTE(review): the specialization header (line 211) and line 213 are not shown in this listing. The
// converters also read numeric_traits<DstT>::bitwise_type, presumably declared there — confirm.
210 template <>
212 {
214 
215  static constexpr int exp = 4;
216  static constexpr int mant = 3;
// OCP E4M3 (bias 7) or FNUZ E4M3 (bias 8), selected by CK_TILE_USE_OCP_FP8; f8_interpret picks the
// matching hardware-conversion variant in cast_to_f8_from_f32.
217 #if CK_TILE_USE_OCP_FP8
218  static constexpr int bias = 7;
219  static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_OCP;
220 #else
221  static constexpr int bias = 8;
222  static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_FNUZ;
223 #endif
// Mask that clears the sign bit.
224  static constexpr uint8_t abs_mask = 0x7F;
225  static constexpr int PackedSize = 1;
226 };
227 
// Layout traits of bf8_t (E5M2), read by the conversion routines in namespace impl and by abs().
// NOTE(review): the specialization header (line 229) and line 231 are not shown in this listing.
228 template <>
230 {
232 
233  static constexpr int exp = 5;
234  static constexpr int mant = 2;
// OCP E5M2 (bias 15) or FNUZ E5M2 (bias 16), selected by CK_TILE_USE_OCP_FP8.
235 #if CK_TILE_USE_OCP_FP8
236  static constexpr int bias = 15;
237  static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_OCP;
238 #else
239  static constexpr int bias = 16;
240  static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_FNUZ;
241 #endif
// Mask that clears the sign bit.
242  static constexpr uint8_t abs_mask = 0x7F;
243  static constexpr int PackedSize = 1;
244 };
245 
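246 // Software fp8/bf8 conversion routines (bit manipulation, no hardware conversion instructions). The hardware path, cast_to_f8_from_f32, is defined later in this namespace and is compiled only when CK_TILE_FP8_CVT_DEVICE is set.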
247 namespace impl {
248 
// Software conversion of a half_t or float value to an fp8_t / bf8_t encoding.
//   SrcT  : half_t or float (enforced by the static_assert below).
//   DstT  : fp8_t or bf8_t; exponent/mantissa widths and bias come from numeric_traits<DstT>.
//   clip  : true  -> +-Inf and out-of-range finite inputs map to the largest finite value of the same sign;
//           false -> they map to the format's Inf (OCP E5M2) or NaN encoding (OCP E4M3, FNUZ).
//   stoch : true  -> `rng` is added to the bits being truncated (stochastic rounding);
//           false -> round-to-nearest-even.
// Returns the 8-bit encoding converted to DstT.
249 template <typename SrcT, typename DstT, bool clip = true, bool stoch = false>
250 CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
251 {
253  "DstT type must be fp8 or bf8.");
254 
255  constexpr bool is_half = std::is_same<SrcT, half_t>::value;
256  constexpr bool is_float = std::is_same<SrcT, float>::value;
257  static_assert(is_half || is_float, "Only half and float can be cast to f8");
258 
259  // fp8/bf8 type exponent/mantissa layout
260  constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
261  constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
262  constexpr int DstT_bias = numeric_traits<DstT>::bias;
263  constexpr bool is_fnuz =
266 
267  constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
268  constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
269  constexpr int bias = numeric_traits<SrcT>::bias;
270  constexpr unsigned int fInf = numeric_traits<SrcT>::Inf;
271  constexpr unsigned int abs_mask = numeric_traits<SrcT>::abs_mask;
272 
273  using SrcT_bitwise = typename numeric_traits<SrcT>::bitwise_type;
274  SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
275 
276  unsigned int head, mantissa;
277  int exponent;
278  unsigned int sign;
279 
280  head = src_bitwise & numeric_traits<SrcT>::head_mask;
281  mantissa = src_bitwise & numeric_traits<SrcT>::mant_mask;
282  exponent = (head >> SrcT_mant) & numeric_traits<SrcT>::exp_mask;
283  sign = head >> (SrcT_exp + SrcT_mant);
284 
// Encodings returned for Inf/NaN/overflow. signed_inf is the signed max finite value when clip is true.
// Otherwise it is 0x80 (NaN) for FNUZ, s.1111.111 (NaN) for OCP E4M3, or s.11111.00 (Inf) for OCP E5M2.
285  unsigned int signed_inf = 0;
286  unsigned int nan = 0;
287  if constexpr(is_fnuz)
288  {
289  signed_inf = clip ? ((sign << (DstT_exp + DstT_mant)) + 0x7f) : 0x80;
290  nan = 0x80;
291  }
292  else
293  {
294  if constexpr(DstT_exp == 4)
295  { // e4m3
296  signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7e : 0x7f);
297  }
298  else
299  { // e5m2
300  signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7b : 0x7c);
301  }
302  nan = (sign << (DstT_exp + DstT_mant)) + 0x7f;
303  }
// ifmax is the largest finite DstT magnitude written as a SrcT bit pattern:
// float: 0x47600000 = 57344, 0x43700000 = 240, 0x43E00000 = 448;
// half : 0x7B00 = 57344, 0x5B80 = 240, 0x5F00 = 448.
304  // Max values
305  unsigned int ifmax = 0;
306  if constexpr(is_float)
307  {
308  if constexpr(DstT_exp == 5)
309  {
310  ifmax = 0x47600000;
311  }
312  else
313  {
314  if constexpr(is_fnuz)
315  {
316  ifmax = 0x43700000;
317  }
318  else
319  {
320  ifmax = 0x43E00000;
321  }
322  }
323  }
324  else if constexpr(is_half)
325  {
326  if constexpr(DstT_exp == 5)
327  {
328  ifmax = 0x7B00;
329  }
330  else
331  {
332  if constexpr(is_fnuz)
333  {
334  ifmax = 0x5B80;
335  }
336  else
337  {
338  ifmax = 0x5F00;
339  }
340  }
341  }
342 
343  // Deal with inf and NaNs
344  if((src_bitwise & fInf) == fInf)
345  {
346  return mantissa != 0 ? nan : signed_inf;
347  }
348 
// NOTE(review): every finite magnitude strictly above ifmax returns signed_inf here, before rounding.
// With clip == false this includes values within half an ulp of the max that RNE would round down to
// the max — confirm this is intended.
349  if((src_bitwise & abs_mask) > ifmax)
350  {
351  return signed_inf;
352  }
353 
354  // First need to check if it is normal or denorm as there is a difference of
355  // implicit 1 Then need to adjust the exponent to align with the F8 exponent,
356  // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
357  // to mantissa and truncate. And for RNE, no need to add rng. Then probably
358  // need to check whether there is carry and adjust exponent and mantissa again
359 
360  // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
361  // bits
362  constexpr int f8_denormal_act_exponent = 1 - DstT_bias; // actual exponent of f8 denormal
363  // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
364  // f8_exponent is the converted f8 exponent with bias encoding
365  // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
366  // the difference needs to be adjusted and mantissa shifted
367  int act_exponent, f8_exponent, exponent_diff;
368 
369  if(exponent == 0)
370  { // fp32/fp16 is in denormal.
371  /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
372  mostly concern fp16 here. In this case, f8 is usually in denormal. But there
373  could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
374  exponent bias 16. It means that there are some numbers in fp16 denormal but they
375  are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
376  where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
377  (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
378  act_exponent = exponent - bias + 1;
379  exponent_diff = f8_denormal_act_exponent -
380  act_exponent; // actual exponent is exponent-bias+1 as it is denormal
381  }
382  else
383  { // fp32/fp16 is normal with implicit 1
384  act_exponent = exponent - bias;
385  if(act_exponent <= f8_denormal_act_exponent)
386  {
387  /* This is the case where fp32/fp16 is normal but it is in f8 denormal
388  range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
389  actual exponent is -7, it is actually larger due to the implicit 1,
390  Therefore it needs to be adjust to -6 and mantissa shift right by 1.
391  So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
392  exponent_diff = f8_denormal_act_exponent - act_exponent;
393  }
394  else
395  { // both fp32/fp16 and f8 are in normal range
396  exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
397  // for this case, act_exponent could be larger. Just
398  // that it does not need shift mantissa
399  }
400  mantissa += (1u << SrcT_mant); // Add the implicit 1 into mantissa
401  }
402  // The value is smaller than min f8 denormal and results in zero (the early exit also prevents
403  // an undefined behavior of bit shifts >= type width).
// FNUZ has no negative zero, so it returns +0; OCP keeps the sign.
404  if(exponent_diff > DstT_mant)
405  {
406  return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
407  }
408  bool midpoint = (mantissa & ((1u << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
409  (1u << (SrcT_mant - DstT_mant + exponent_diff - 1));
410  /* This part is a bit tricky. The judgment of whether it is a tie needs to be
411  done before we shift right as shift right could rip off some residual part and
412  make something not midpoint look like midpoint. For example, the fp16 number
413  0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
414  by 4 bits, it would look like midpoint.
415  */
416 
417  if(exponent_diff > 0)
418  mantissa >>= exponent_diff;
419  else if(exponent_diff == -1)
420  mantissa <<= -exponent_diff;
421  bool implicit_one = mantissa & (1u << SrcT_mant);
422  // if there is no implicit 1, it means the f8 is denormal and need to adjust
423  // to denorm exponent
424  f8_exponent =
425  (act_exponent + exponent_diff) /*actual f8 exponent*/ + DstT_bias - (implicit_one ? 0 : 1);
426 
427  // Now we have the exponent and mantissa adjusted
428  unsigned int drop_mask = (1u << (SrcT_mant - DstT_mant)) - 1;
429  bool odd =
430  mantissa &
431  (1u << (SrcT_mant - DstT_mant)); // if the least significant bit that is not truncated is 1
// Branch-free RNE: adding the to-be-dropped bits to themselves carries into the kept bits exactly when
// they exceed half an ulp. At an exact midpoint, adding (bits - 1) for an even kept LSB suppresses the
// carry, which gives ties-to-even. With stoch, rng & drop_mask is added instead.
432  mantissa +=
433  (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1u) : mantissa)) & drop_mask;
434 
435  // Now we deal with overflow
436  if(f8_exponent == 0)
437  {
438  if((1u << SrcT_mant) & mantissa)
439  {
440  f8_exponent = 1; // denormal overflow to become normal, promote exponent
441  }
442  }
443  else
444  {
445  if((1u << (SrcT_mant + 1)) & mantissa)
446  {
447  mantissa >>= 1;
448  f8_exponent++;
449  }
450  }
451 
452  mantissa >>= (SrcT_mant - DstT_mant);
453 
454  // above range: quantize to maximum possible float of the same sign
455  const int max_exp = (1 << DstT_exp) - 1;
456  if(f8_exponent > max_exp)
457  {
458  if constexpr(clip)
459  {
460  mantissa = (1 << DstT_mant) - 1;
461  f8_exponent = max_exp;
462  }
463  else
464  {
465  return signed_inf;
466  }
467  }
468 
// Rounded to zero: FNUZ returns +0 (it has no -0), OCP keeps the sign.
469  if(f8_exponent == 0 && mantissa == 0)
470  return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
471  mantissa &= (1 << DstT_mant) - 1;
472  return (sign << (DstT_exp + DstT_mant)) | (f8_exponent << DstT_mant) | mantissa;
473 }
474 
// Software decoding of an fp8_t / bf8_t encoding to half_t or float.
// Per the cross-reference, the signature (line 476, not shown here) is
//   CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
//   clip : only affects OCP E5M2 infinities. When true, s.11111.00 decodes to +-57344 (the largest
//          finite E5M2 value); otherwise it decodes to +-Inf.
475 template <typename SrcT, typename DstT, bool clip = true>
477 {
479  "SrcT type must be fp8 or bf8.");
480  constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
481  constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
482  constexpr uint8_t SrcT_abs_mask = numeric_traits<SrcT>::abs_mask;
483  constexpr bool is_fnuz =
486 
487  constexpr bool is_half = std::is_same<DstT, half_t>::value;
488  constexpr bool is_float = std::is_same<DstT, float>::value;
489  static_assert(is_half || is_float, "DstT type must be half_t or float.");
490 
491  // destination type exponent/mantissa layout
492  constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
493  constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
494 
495  constexpr DstT fInf = bit_cast<DstT>(numeric_traits<DstT>::Inf);
496  constexpr DstT fNegInf = bit_cast<DstT>(numeric_traits<DstT>::NegInf);
497  constexpr DstT fNaN = bit_cast<DstT>(numeric_traits<DstT>::NaN);
498  constexpr DstT fNeg0 = bit_cast<DstT>(numeric_traits<DstT>::Neg0);
499 
// fmax/fmin are +-57344 in DstT. They are returned only by the OCP E5M2 infinity branch below when clip is true.
500  DstT fmax{0}, fmin{0};
501  // Max number in e5m2 57344
502  if constexpr(is_half)
503  {
504  fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x7B00));
505  fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xFB00));
506  }
507  else if constexpr(is_float)
508  {
509  fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x47600000));
510  fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xC7600000));
511  }
512 
// +0 decodes to +0 in every format.
513  if(x == 0)
514  {
515  return 0;
516  }
517 
518  unsigned int sign = x >> (SrcT_exp + SrcT_mant);
519  unsigned int mantissa = x & ((1 << SrcT_mant) - 1);
520  int exponent = (x & SrcT_abs_mask) >> SrcT_mant;
// Special encodings: FNUZ uses 0x80 as its only NaN. OCP uses 0x80 as -0, s.1111.111 as the E4M3 NaN,
// and s.11111.{00 = Inf, 01/10/11 = NaN} for E5M2.
521  if constexpr(is_fnuz)
522  {
523  if((x & 0xff) == 0x80)
524  {
525  return fNaN;
526  }
527  }
528  else
529  {
530  if(x == SrcT(0x80))
531  {
532  return fNeg0;
533  }
534  if constexpr(SrcT_exp == 4)
535  { // e4m3
536  if((x & 0x7F) == 0x7F)
537  {
538  return fNaN;
539  }
540  }
541  else if((x & 0x7C) == 0x7C)
542  { // e5m2
543  if((x & 0x3) == 0)
544  {
545  if constexpr(clip)
546  {
547  return sign ? fmin : fmax;
548  }
549  return sign ? fNegInf : fInf;
550  }
551  return fNaN;
552  }
553  }
554 
555  typename numeric_traits<DstT>::bitwise_type retval;
556 
// Fast path: OCP E5M2 has the same exponent width (5) and bias (15) as IEEE half, so its byte is the
// upper byte of the half_t bit pattern.
557  if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
558  {
559  retval = static_cast<typename numeric_traits<DstT>::bitwise_type>(x) << 8;
560  return bit_cast<DstT>(retval);
561  }
562 
// Difference between the destination and source exponent biases, minus one extra for FNUZ.
563  const int exp_low_cutoff =
564  (1 << (DstT_exp - 1)) - (1 << (SrcT_exp - 1)) + 1 - (is_fnuz ? 1 : 0);
565 
566  // subnormal input
// mantissa is non-zero here: both zero encodings (0x00, and 0x80 as NaN or -0) returned above.
// NOTE(review): the cross-reference lists clz only as CK_TILE_HOST (math.hpp:264), while this function is
// CK_TILE_HOST_DEVICE and is used by fp8_to_float_raw on device targets without CK_TILE_FP8_CVT_DEVICE —
// confirm a device-callable clz overload exists.
567  if(exponent == 0)
568  {
569  int sh = 1 + clz(mantissa) - (32 - SrcT_mant);
570  mantissa <<= sh;
571  exponent += 1 - sh;
572  mantissa &= ((1ull << SrcT_mant) - 1);
573  }
574  exponent += exp_low_cutoff - 1;
575  mantissa <<= DstT_mant - SrcT_mant;
576 
577  // subnormal output (occurs when DstT is half_t, we=5, is_fnuz=true)
578  if(exponent <= 0)
579  {
580  mantissa |= 1 << DstT_mant;
581  mantissa >>= 1 - exponent;
582  exponent = 0;
583  }
584 
585  retval = (sign << (DstT_exp + DstT_mant)) | (exponent << DstT_mant) | mantissa;
586 
587  return bit_cast<DstT>(retval);
588 }
589 
// Software conversion entry point: runs run_cast_to_f8<X, Y, clip, stoch> and bit-casts its result to Y.
// Per the cross-reference, the signature (line 591, not shown here) is
//   CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
590 template <typename X, typename Y, bool clip, bool stoch>
592 {
593  return bit_cast<Y>(run_cast_to_f8<X, Y, clip, stoch>(x, rng));
594 }
595 
596 #if CK_TILE_FP8_CVT_DEVICE
600 template <fp8_interpretation interpret, bool saturate, bool stochastic_rounding = false>
601 CK_TILE_DEVICE uint8_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
602 {
603  uint8_t i8data;
604  union
605  {
606  float fval;
607  unsigned int i32val;
608  unsigned char i8val[4]; // NOTE: not endian independent
609  } val;
610 
611  unsigned int ival = 0;
612  val.fval = v;
613 
614  if constexpr(saturate)
615  {
616  if constexpr(interpret == fp8_interpretation::E4M3_FNUZ)
617  {
618  if((val.i32val & 0x7F800000) != 0x7F800000)
619  {
620  val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
621  }
622  }
623  else if constexpr(interpret == fp8_interpretation::E4M3_OCP)
624  { // OCP type
625  if((val.i32val & 0x7F800000) != 0x7F800000)
626  {
627  val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
628  }
629  }
630  else
631  {
632  if((val.i32val & 0x7F800000) != 0x7F800000)
633  {
634  val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
635  }
636  }
637  }
638 
639  if constexpr(stochastic_rounding)
640  {
641  ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
642  (interpret == fp8_interpretation::E4M3_OCP)
643  ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
644  : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
645  val.i32val = ival;
646  i8data = val.i8val[0]; // little endian
647  }
648  else
649  { // RNE CVT
650  ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
651  (interpret == fp8_interpretation::E4M3_OCP)
652  ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
653  : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
654  val.fval,
655  ival,
656  false); // false -> WORD0
657  val.i32val = ival;
658  i8data = val.i8val[0];
659  }
660  return i8data;
661 }
662 #endif // CK_TILE_FP8_CVT_DEVICE
663 
664 } // namespace impl
665 
// Encodes x (SrcT) as the raw bits of DstT (fp8_t or bf8_t) with stochastic rounding and saturation
// (clip = true). The random bits come from prand_generator_t<SrcT, 42>, called with the address and
// value of x. With CK_TILE_FP8_CVT_DEVICE the hardware conversion is used; otherwise the software
// impl::cast_to_f8 path is used.
// Per the cross-reference, the signature (line 680, not shown here) is
//   CK_TILE_HOST_DEVICE numeric_traits<DstT>::bitwise_type float_to_fp8_sr_raw(SrcT x)
// NOTE(review): x is taken by value, so &x is the address of this function's own parameter copy rather
// than of the caller's data — confirm that is the intended input to the generator.
679 template <typename SrcT, typename DstT>
681 {
682  constexpr bool clip = true;
683  constexpr int seed = 42;
684  uint32_t rng = prand_generator_t<SrcT, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
685 #if CK_TILE_FP8_CVT_DEVICE
686  return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, true>(x, rng);
687 #else
688  return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
689  impl::cast_to_f8<SrcT, DstT, clip, true>(x, rng));
690 #endif
691 }
692 
// Encodes x (SrcT) as the raw bits of DstT (fp8_t or bf8_t) with round-to-nearest-even and saturation
// (clip = true); rng is fixed at 0. With CK_TILE_FP8_CVT_DEVICE the hardware conversion is used;
// otherwise the software impl::cast_to_f8 path is used.
// Per the cross-reference, the signature (line 706, not shown here) is
//   CK_TILE_HOST_DEVICE numeric_traits<DstT>::bitwise_type float_to_fp8_rtn_raw(SrcT x)
705 template <typename SrcT, typename DstT>
707 {
708  constexpr bool clip = true;
709 #if CK_TILE_FP8_CVT_DEVICE
710  return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, false>(x, 0);
711 #else
712  return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
713  impl::cast_to_f8<SrcT, DstT, clip, false>(x, 0));
714 #endif
715 }
716 
// Definition of float_to_fp8_raw (declared above with the default rounding mode): encodes a float as
// the raw fp8_t byte. The rounding routine is chosen at compile time from `rounding`.
717 template <fp8_rounding_mode rounding>
719 {
720  if constexpr(rounding == fp8_rounding_mode::standard)
721  {
722  return float_to_fp8_rtn_raw<float, fp8_t>(x);
723  }
724  else if constexpr(rounding == fp8_rounding_mode::stochastic)
725  {
726  return float_to_fp8_sr_raw<float, fp8_t>(x);
727  }
// NOTE(review): this branch is unreachable for the known modes. An unsupported mode would silently
// yield 0x00 (+0) instead of a compile-time error; a dependent static_assert would be safer.
728  else
729  {
730  return fp8_raw_t{0};
731  }
732 }
733 
// Definition of float_to_bf8_raw (declared above with the default rounding mode): encodes a float as
// the raw bf8_t byte. The rounding routine is chosen at compile time from `rounding`.
734 template <fp8_rounding_mode rounding>
736 {
737  if constexpr(rounding == fp8_rounding_mode::standard)
738  {
739  return float_to_fp8_rtn_raw<float, bf8_t>(x);
740  }
741  else if constexpr(rounding == fp8_rounding_mode::stochastic)
742  {
743  return float_to_fp8_sr_raw<float, bf8_t>(x);
744  }
// NOTE(review): this branch is unreachable for the known modes. An unsupported mode would silently
// yield 0x00 (+0) instead of a compile-time error; a dependent static_assert would be safer.
745  else
746  {
747  return bf8_raw_t{0};
748  }
749 }
750 
// Body of fp8_to_float_raw (per the cross-reference: CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t);
// the signature line is not shown). It decodes a raw fp8_t byte to float. Builds with
// CK_TILE_FP8_CVT_DEVICE use the hardware builtin on byte 0; all others use impl::run_cast_from_f8.
752 {
753 #if CK_TILE_FP8_CVT_DEVICE
754  float fval;
755  uint32_t i32val = static_cast<uint32_t>(x);
756  fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
757  // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
758  return fval;
759 #else
760  return impl::run_cast_from_f8<fp8_t, float>(bit_cast<fp8_t>(x));
761 #endif
762 }
763 
// Body of bf8_to_float_raw (per the cross-reference: CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t);
// the signature line is not shown). It decodes a raw bf8_t byte to float. Builds with
// CK_TILE_FP8_CVT_DEVICE use the hardware builtin on byte 0; all others use impl::run_cast_from_f8.
765 {
766 #if CK_TILE_FP8_CVT_DEVICE
767  float fval;
768  uint32_t i32val = static_cast<uint32_t>(x);
769  fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
770  // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
771  return fval;
772 #else
773  return impl::run_cast_from_f8<bf8_t, float>(bit_cast<bf8_t>(x));
774 #endif
775 }
776 
// Typed wrapper: encodes x with float_to_fp8_raw (same rounding mode) and bit-casts the byte to fp8_t.
// Per the cross-reference, the signature (line 778, not shown) is
//   CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant<rounding> = {})
777 template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
779 {
780  return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{}));
781 }
782 
// Typed wrapper: encodes x with float_to_bf8_raw (same rounding mode) and bit-casts the byte to bf8_t.
// Per the cross-reference, the signature (line 784, not shown) is
//   CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant<rounding> = {})
783 template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
785 {
786  return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{}));
787 }
788 
789 CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x) { return fp8_to_float_raw(bit_cast<fp8_raw_t>(x)); }
790 
791 CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x) { return bf8_to_float_raw(bit_cast<bf8_raw_t>(x)); }
792 
// Primary template for per-type numeric limits (min/lowest/max/epsilon/...). It is only declared;
// the fp8_t and bf8_t specializations follow, in either an OCP or a FNUZ variant.
793 template <class T>
794 struct numeric;
795 
796 #if CK_TILE_USE_OCP_FP8
797 template <>
798 struct numeric<fp8_t>
799 {
800  // minimum finite value, or minimum positive normal value
801  CK_TILE_HOST_DEVICE static constexpr fp8_t min()
802  {
803  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08)); // 0b00001000 = 2^-6
804  }
805 
806  // minumum finite value
807  CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
808  {
809  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xfe)); // 0b11111110 = -448
810  }
811 
812  // maximum finite value
813  CK_TILE_HOST_DEVICE static constexpr fp8_t max()
814  {
815  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7e)); // 0b01111110 = 448
816  }
817 
818  // difference between 1.0 and next representable f8 value (1.125)
819  // returns fp8_t(0.125)
820  CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
821  {
822  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20)); // 0.125
823  }
824 
825  // rounding error (0.0625)
826  // half of epsilon
827  CK_TILE_HOST_DEVICE static constexpr fp8_t round_error()
828  {
829  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x18)); // 0.0625
830  }
831 
832  // quiet NaN
833  CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN()
834  {
835  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7F)); // 0b01111111
836  }
837 
838  // signaling NaN
839  CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN()
840  {
841  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xFF)); // 0b11111111
842  }
843 
844  // smallest positive subnormal value
845  CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min()
846  {
847  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
848  }
849 
850  CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
851  {
852  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
853  }
854 };
855 
856 template <>
857 struct numeric<bf8_t>
858 {
859  // minimum finite value, or minimum positive normalized value for float
860  CK_TILE_HOST_DEVICE static constexpr bf8_t min()
861  {
862  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04)); // 0b00000100 = 2^-14
863  }
864 
865  // minumum finite value
866  CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
867  {
868  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xfb)); // 0b11111011 = -57344
869  }
870 
871  // maximum finite value
872  CK_TILE_HOST_DEVICE static constexpr bf8_t max()
873  {
874  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7b)); // 0b01111011 = 57344
875  }
876 
877  // difference between 1.0 and next representable bf8 value (1.25)
878  CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
879  {
880  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34)); // 0.25
881  }
882 
883  // rounding error (0.125)
884  // half of epsilon
885  CK_TILE_HOST_DEVICE static constexpr bf8_t round_error()
886  {
887  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x30)); // 0.125
888  }
889 
890  // positive infinity value
891  CK_TILE_HOST_DEVICE static constexpr bf8_t infinity()
892  {
893  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7c)); // 0b01111100
894  }
895 
896  // quiet NaN
897  CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN()
898  {
899  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7F)); // 0b01111111
900  }
901 
902  // signaling NaN
903  CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN()
904  {
905  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xFF));
906  }
907 
908  // smallest positive subnormal value
909  CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min()
910  {
911  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
912  }
913 
914  CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
915  {
916  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
917  }
918 };
919 #else
920 template <>
921 struct numeric<fp8_t>
922 {
923  // minimum finite value, or minimum positive normalized value for float
924  CK_TILE_HOST_DEVICE static constexpr fp8_t min()
925  {
926  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08));
927  }
928 
929  // minumum finite value
930  CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
931  {
932  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xff));
933  }
934 
935  // maximum finite value
936  CK_TILE_HOST_DEVICE static constexpr fp8_t max()
937  {
938  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7f));
939  }
940 
941  // difference between 1.0 and next value representable by float
942  CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
943  {
944  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20));
945  }
946 
947  // maximum rounding error
948  // bin : 7 6543 210
949  // bits: s eeee mmm
950  // 0 0110 000 (0.5)
951  //
953  {
954  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x30));
955  }
956 
957  // positive infinity value
958  CK_TILE_HOST_DEVICE static constexpr fp8_t infinity()
959  {
960  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
961  }
962 
963  // quiet NaN
965  {
966  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
967  }
968 
969  // signaling NaN
971  {
972  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
973  }
974 
975  // smallest positive subnormal value
977  {
978  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
979  }
980 
981  CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
982  {
983  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
984  }
985 };
986 
987 template <>
988 struct numeric<bf8_t>
989 {
990  // minimum finite value, or minimum positive normalized value for float
991  CK_TILE_HOST_DEVICE static constexpr bf8_t min()
992  {
993  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04));
994  }
995 
996  // minumum finite value
997  CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
998  {
999  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xff));
1000  }
1001 
1002  // maximum finite value
1003  CK_TILE_HOST_DEVICE static constexpr bf8_t max()
1004  {
1005  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7f));
1006  }
1007 
1008  // difference between 1.0 and next value representable by float
1009  CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
1010  {
1011  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34));
1012  }
1013 
1014  // maximum rounding error
1015  // bin : 7 65432 10
1016  // bits: s eeeee mm
1017  // 0 01110 00 (0.5)
1018  //
1020  {
1021  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x38));
1022  }
1023 
1024  // positive infinity value
1026  {
1027  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
1028  }
1029 
1030  // quiet NaN
1032  {
1033  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
1034  }
1035 
1036  // signaling NaN
1038  {
1039  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
1040  }
1041 
1042  // smallest positive subnormal value
1044  {
1045  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
1046  }
1047 
1048  CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
1049  {
1050  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
1051  }
1052 };
1053 #endif
1054 
1055 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1058 #endif
1059 
1060 // math
// abs for fp8_t / bf8_t: clears the sign bit with numeric_traits<T>::abs_mask (0x7F for both types).
// NOTE(review): in FNUZ mode (CK_TILE_USE_OCP_FP8 == 0) the only NaN encoding is 0x80, and 0x80 & 0x7F
// == 0x00, so abs(NaN) currently returns +0 instead of NaN. Consider returning x unchanged when its raw
// byte is 0x80 in that mode. OCP NaNs keep their NaN magnitude under the mask.
1061 template <typename T>
1063 {
1064  static_assert(std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>,
1065  "Only fp8_t and bf8_t are supported");
1066  return bit_cast<T>(static_cast<uint8_t>(bit_cast<uint8_t>(x) & numeric_traits<T>::abs_mask));
1067 }
1068 
1070 bool isnan(const fp8_t& x)
1071 {
1072  uint8_t xx = bit_cast<fp8_raw_t>(x);
1073 
1074 #if CK_TILE_USE_OCP_FP8
1075  return (xx & 0x7f) == 0x7f;
1076 #else
1077  return xx == 0x80;
1078 #endif
1079 }
// Elementwise math for the custom fp8_t wrapper (only when CK_TILE_USE_CUSTOM_DATA_TYPE is set).
// Each function widens x to float (explicit operator float), evaluates in float, and narrows back
// through the explicit float constructor, which uses the default rounding mode.
// NOTE(review): the ';' after each function body is an empty declaration; it is harmless, but
// -Wextra-semi warns about it.
1080 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1082 fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
1083 
1085 fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__ocml_exp_f32(static_cast<float>(x))); };
1086 
1088 fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };
1089 
1091 fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };
1092 #endif
1093 
1095 bool isnan(const bf8_t& x)
1096 {
1097  uint8_t xx = bit_cast<bf8_raw_t>(x);
1098 
1099 #if CK_TILE_USE_OCP_FP8
1100  return (xx & 0x7f) > 0x7c;
1101 #else
1102  return xx == 0x80;
1103 #endif
1104 }
1105 
// Elementwise math for the custom bf8_t wrapper (only when CK_TILE_USE_CUSTOM_DATA_TYPE is set).
// Each function widens x to float (explicit operator float), evaluates in float, and narrows back
// through the explicit float constructor, which uses the default rounding mode.
// NOTE(review): the ';' after each function body is an empty declaration; it is harmless, but
// -Wextra-semi warns about it.
1106 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1108 bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
1109 
1111 bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__ocml_exp_f32(static_cast<float>(x))); };
1112 
1114 bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };
1115 
1117 bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); };
1118 #endif
1119 
1120 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_FLOAT_TO_FP8_DEFAULT
Definition: config.hpp:79
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng=0)
Definition: float8.hpp:250
CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
Definition: float8.hpp:476
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
Definition: float8.hpp:591
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:432
fp8_interpretation
FP8 interpretation used in conversion algorithms.
Definition: float8.hpp:38
_BitInt(8) fp8_t
Definition: float8.hpp:204
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant< rounding >={})
Definition: float8.hpp:778
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t)
Definition: float8.hpp:751
constexpr CK_TILE_HOST_DEVICE Y bit_cast(const X &x)
Definition: bit_cast.hpp:11
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t)
Definition: float8.hpp:764
fp8_rounding_mode
Definition: float8.hpp:29
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:417
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant< rounding >={})
Definition: float8.hpp:718
uint8_t fp8_raw_t
Definition: float8.hpp:205
CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x)
Definition: float8.hpp:791
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_sr_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with stochastic rounding.
Definition: float8.hpp:680
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:423
CK_TILE_HOST int clz(uint32_t x)
Definition: math.hpp:264
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:404
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
uint8_t bf8_raw_t
Definition: float8.hpp:207
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant< rounding >={})
Definition: float8.hpp:784
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition: bfloat16.hpp:410
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_rtn_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with rounding to nearest even.
Definition: float8.hpp:706
CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x)
Definition: float8.hpp:789
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant< rounding >={})
Definition: float8.hpp:735
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:429
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
_W64 unsigned int uintptr_t
Definition: stdint.h:165
unsigned int uint32_t
Definition: stdint.h:126
unsigned char uint8_t
Definition: stdint.h:124
Definition: integral_constant.hpp:13
remove_cvref_t< T > type
Definition: vector_type.hpp:26
static constexpr CK_TILE_HOST_DEVICE bf8_t min()
Definition: float8.hpp:991
static constexpr CK_TILE_HOST_DEVICE bf8_t quiet_NaN()
Definition: float8.hpp:1031
static constexpr CK_TILE_HOST_DEVICE bf8_t lowest()
Definition: float8.hpp:997
static constexpr CK_TILE_HOST_DEVICE bf8_t round_error()
Definition: float8.hpp:1019
static constexpr CK_TILE_HOST_DEVICE bf8_t signaling_NaN()
Definition: float8.hpp:1037
static constexpr CK_TILE_HOST_DEVICE bf8_t denorm_min()
Definition: float8.hpp:1043
static constexpr CK_TILE_HOST_DEVICE bf8_t epsilon()
Definition: float8.hpp:1009
static constexpr CK_TILE_HOST_DEVICE bf8_t infinity()
Definition: float8.hpp:1025
static constexpr CK_TILE_HOST_DEVICE bf8_t max()
Definition: float8.hpp:1003
static constexpr CK_TILE_HOST_DEVICE bf8_t zero()
Definition: float8.hpp:1048
static constexpr CK_TILE_HOST_DEVICE fp8_t signaling_NaN()
Definition: float8.hpp:970
static constexpr CK_TILE_HOST_DEVICE fp8_t zero()
Definition: float8.hpp:981
static constexpr CK_TILE_HOST_DEVICE fp8_t min()
Definition: float8.hpp:924
static constexpr CK_TILE_HOST_DEVICE fp8_t lowest()
Definition: float8.hpp:930
static constexpr CK_TILE_HOST_DEVICE fp8_t epsilon()
Definition: float8.hpp:942
static constexpr CK_TILE_HOST_DEVICE fp8_t quiet_NaN()
Definition: float8.hpp:964
static constexpr CK_TILE_HOST_DEVICE fp8_t max()
Definition: float8.hpp:936
static constexpr CK_TILE_HOST_DEVICE fp8_t denorm_min()
Definition: float8.hpp:976
static constexpr CK_TILE_HOST_DEVICE fp8_t round_error()
Definition: float8.hpp:952
static constexpr CK_TILE_HOST_DEVICE fp8_t infinity()
Definition: float8.hpp:958
bf8_raw_t bitwise_type
Definition: float8.hpp:231
fp8_raw_t bitwise_type
Definition: float8.hpp:213
Definition: numeric.hpp:81
static constexpr int PackedSize
Definition: numeric.hpp:82
Definition: numeric.hpp:18
static constexpr CK_TILE_HOST_DEVICE T lowest()
Definition: numeric.hpp:23
static constexpr CK_TILE_HOST_DEVICE T min()
Definition: numeric.hpp:20
static constexpr CK_TILE_HOST_DEVICE T quiet_NaN()
Definition: numeric.hpp:41
static constexpr CK_TILE_HOST_DEVICE T signaling_NaN()
Definition: numeric.hpp:47
static constexpr CK_TILE_HOST_DEVICE T max()
Definition: numeric.hpp:26
static constexpr CK_TILE_HOST_DEVICE T round_error()
Definition: numeric.hpp:32
static constexpr CK_TILE_HOST_DEVICE T zero()
Definition: numeric.hpp:58
static constexpr CK_TILE_HOST_DEVICE T denorm_min()
Definition: numeric.hpp:53
static constexpr CK_TILE_HOST_DEVICE T epsilon()
Definition: numeric.hpp:29
static constexpr CK_TILE_HOST_DEVICE T infinity()
Definition: numeric.hpp:38
Definition: random.hpp:17
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition: numeric.hpp:106