include/ck_tile/core/numeric/float8.hpp Source File

include/ck_tile/core/numeric/float8.hpp Source File#

Composable Kernel: include/ck_tile/core/numeric/float8.hpp Source File
float8.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
12 #include <stdint.h>
13 #include <type_traits>
14 
15 #pragma once
16 
// CK_TILE_FP8_CVT_DEVICE is 1 only during the device compilation pass for gfx94/gfx12 targets;
// the fp8/bf8 converters in this file then use the AMDGPU conversion builtins. In every other
// case (host pass, other targets) it is 0 and the software conversion routines are used.
#if !((defined(__gfx94__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__)
#define CK_TILE_FP8_CVT_DEVICE 0
#else
#define CK_TILE_FP8_CVT_DEVICE 1
#endif
22 
23 namespace ck_tile {
24 
25 // fp8 rounding modes
26 // use standard for rounding to nearest, the faster one
27 // use stochastic for stochastic rounding, helps to avoid error accumulation
29 {
// standard: round-to-nearest; float_to_fp8_raw / float_to_bf8_raw dispatch it to float_to_fp8_rtn_raw.
30  standard = 0,
32 };
33 
38 {
// OCP formats use biases 7 (E4M3) / 15 (E5M2) and keep signed zero; FNUZ formats use biases
// 8 / 16, have no infinity and no negative zero, and encode their single NaN as 0x80.
39  E4M3_OCP = 0, // OCP FP8 E4M3
40  E5M2_OCP = 1, // OCP BF8 E5M2
41  E4M3_FNUZ = 2, // FNUZ FP8 E4M3
42  E5M2_FNUZ = 3, // FNUZ BF8 E5M2
43 };
44 
45 /*
46  * ______________FNUZ_________________ | ______________OCP________________
47  * e4m3 e5m2 | e4m3 e5m2
48  * bias : 8 16 | 7 15
49  * inf : 1.0000.000 1.00000.00 | N/A s.11111.00
50  * Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
51  * zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
52  * Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
53  * Max(snorm): s.0000.111 s.00000.11 | s.0000.111 s.00000.11
54  * 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
55  * Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
56  * 2^-7(0.0078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
57  * Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01
58  * 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05)
59  */
60 
61 template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
62 CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant<rounding> = {});
63 
64 template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
65 CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant<rounding> = {});
66 
69 
70 #if CK_TILE_USE_CUSTOM_DATA_TYPE
// Custom one-byte E4M3 (fp8) storage type. `data` holds the raw bit pattern; exponent/mantissa
// widths are fixed, the bias follows CK_TILE_USE_OCP_FP8 (7 = OCP, 8 = FNUZ).
71 struct alignas(1) float8_e4m3_t
72 {
73  static constexpr int exponent = 4;
74  static constexpr int mantissa = 3;
75 #if CK_TILE_USE_OCP_FP8
76  static constexpr int bias = 7; // OCP
77 #else
78  static constexpr int bias = 8; // FNUZ
79 #endif
80  using raw_type = uint8_t;
81  raw_type data;
82 
// Wraps the raw bits as-is (no numeric conversion).
84  static constexpr float8_e4m3_t bit_cast(raw_type x)
85  {
86  float8_e4m3_t y;
87  y.data = x;
88  return y;
89  }
90 
91  // constructor
// Value-initializes data to 0, i.e. +0.0.
92  constexpr float8_e4m3_t() : data() {}
93 
94  // construct from float
// Numeric constructors round through float_to_fp8_raw with its default rounding mode
// (CK_TILE_FLOAT_TO_FP8_DEFAULT); integers are converted to float first.
96  explicit constexpr float8_e4m3_t(const float& x) : data(float_to_fp8_raw(x)) {}
97 
98  // construct from int
100  explicit constexpr float8_e4m3_t(const int& x) : data(float_to_fp8_raw(static_cast<float>(x)))
101  {
102  }
103 
104  // construct from unsigned int
106  explicit constexpr float8_e4m3_t(const unsigned int& x)
107  : data(float_to_fp8_raw(static_cast<float>(x)))
108  {
109  }
110 
111  // cast to float
113  explicit constexpr operator float() const { return fp8_to_float_raw(data); }
114 
115  // cast to int
// Decodes to float, then truncates toward zero via static_cast<int>.
117  explicit constexpr operator int() const { return static_cast<int>(fp8_to_float_raw(data)); }
118 
119  // internal access
121  constexpr raw_type& get() { return data; }
122 
124  constexpr raw_type get() const { return data; }
125 };
126 using fp8_t = float8_e4m3_t;
127 using fp8_raw_t = typename fp8_t::raw_type;
128 
// Custom one-byte E5M2 (bf8) storage type. `data` holds the raw bit pattern; exponent/mantissa
// widths are fixed, the bias follows CK_TILE_USE_OCP_FP8 (15 = OCP, 16 = FNUZ).
129 struct alignas(1) float8_e5m2_t
130 {
131  static constexpr int exponent = 5;
132  static constexpr int mantissa = 2;
133 #if CK_TILE_USE_OCP_FP8
134  static constexpr int bias = 15; // OCP
135 #else
136  static constexpr int bias = 16; // FNUZ
137 #endif
138  using raw_type = uint8_t;
139  raw_type data;
140 
// Wraps the raw bits as-is (no numeric conversion).
142  static constexpr float8_e5m2_t bit_cast(raw_type x)
143  {
144  float8_e5m2_t y;
145  y.data = x;
146  return y;
147  }
148 
149  // constructor
// Value-initializes data to 0, i.e. +0.0.
150  constexpr float8_e5m2_t() : data() {}
151 
152  // construct from float
// Numeric constructors round through float_to_bf8_raw with its default rounding mode
// (CK_TILE_FLOAT_TO_FP8_DEFAULT); integers are converted to float first.
154  explicit constexpr float8_e5m2_t(const float& x) : data(float_to_bf8_raw(x)) {}
155 
156  // construct from int
158  explicit constexpr float8_e5m2_t(const int& x) : data(float_to_bf8_raw(static_cast<float>(x)))
159  {
160  }
161 
162  // construct from unsigned int
164  explicit constexpr float8_e5m2_t(const unsigned int& x)
165  : data(float_to_bf8_raw(static_cast<float>(x)))
166  {
167  }
168 
169  // cast to float
171  explicit constexpr operator float() const { return bf8_to_float_raw(data); }
172 
173  // cast to int
// Decodes to float, then truncates toward zero via static_cast<int>.
175  explicit constexpr operator int() const { return static_cast<int>(bf8_to_float_raw(data)); }
176 
177  // internal access
179  constexpr raw_type& get() { return data; }
180 
182  constexpr raw_type get() const { return data; }
183 };
184 using bf8_t = float8_e5m2_t;
185 using bf8_raw_t = typename bf8_t::raw_type;
186 
187 template <typename>
188 struct native_t;
189 
190 template <>
191 struct native_t<fp8_t>
192 {
193  using type = _BitInt(8);
194 };
195 
196 template <>
197 struct native_t<bf8_t>
198 {
199  using type = unsigned _BitInt(8);
200 };
201 
202 #else
203 
204 using fp8_t = _BitInt(8);
205 using fp8_raw_t = uint8_t;
206 using bf8_t = unsigned _BitInt(8);
207 using bf8_raw_t = uint8_t;
208 #endif
209 
// Primary template: per-type bit-layout description. The fp8/bf8 specializations follow; the
// float / half_t ones used by the converters are presumably defined in other headers.
template <class T>
struct numeric_traits;
212 
213 template <>
215 {
// Bit layout of fp8_t (E4M3): 1 sign, 4 exponent, 3 mantissa bits. bias / f8_interpret select
// OCP or FNUZ according to CK_TILE_USE_OCP_FP8; abs_mask clears the sign bit.
217 
218  static constexpr int exp = 4;
219  static constexpr int mant = 3;
220 #if CK_TILE_USE_OCP_FP8
221  static constexpr int bias = 7;
222  static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_OCP;
223 #else
224  static constexpr int bias = 8;
225  static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_FNUZ;
226 #endif
227  static constexpr uint8_t abs_mask = 0x7F;
228 };
229 
230 template <>
232 {
// Bit layout of bf8_t (E5M2): 1 sign, 5 exponent, 2 mantissa bits. bias / f8_interpret select
// OCP or FNUZ according to CK_TILE_USE_OCP_FP8; abs_mask clears the sign bit.
234 
235  static constexpr int exp = 5;
236  static constexpr int mant = 2;
237 #if CK_TILE_USE_OCP_FP8
238  static constexpr int bias = 15;
239  static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_OCP;
240 #else
241  static constexpr int bias = 16;
242  static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_FNUZ;
243 #endif
244  static constexpr uint8_t abs_mask = 0x7F;
245 };
246 
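247 // below: software (bit-manipulation) fp8/bf8 conversions; this namespace also holds the hardware-instruction path cast_to_f8_from_f32, compiled only when CK_TILE_FP8_CVT_DEVICE is set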
248 namespace impl {
249 
// Software conversion of a float or half_t value to the fp8/bf8 bit pattern of DstT.
//   clip  : saturate out-of-range inputs to the largest finite value of the same sign instead
//           of producing the format's inf/NaN encoding (see signed_inf below).
//   stoch : add `rng` to the dropped mantissa bits (stochastic rounding) instead of
//           round-to-nearest-even.
// The result is assembled as an unsigned bit pattern and returned converted to DstT.
250 template <typename SrcT, typename DstT, bool clip = true, bool stoch = false>
251 CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
252 {
253  static_assert(std::is_same<DstT, fp8_t>::value || std::is_same<DstT, bf8_t>::value,
254  "DstT type must be fp8 or bf8.");
255 
256  constexpr bool is_half = std::is_same<SrcT, half_t>::value;
257  constexpr bool is_float = std::is_same<SrcT, float>::value;
258  static_assert(is_half || is_float, "Only half and float can be cast to f8");
259 
260  // fp8/bf8 type exponent/mantissa layout
261  constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
262  constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
263  constexpr bool is_fnuz =
266 
267  constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
268  constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
269 
270  using SrcT_bitwise = typename numeric_traits<SrcT>::bitwise_type;
271  SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
272 
273  unsigned long long head, mantissa;
274  int exponent, bias;
275  unsigned int sign;
276  unsigned long long fInf, abs_mask;
277 
278  head = src_bitwise & numeric_traits<SrcT>::head_mask;
279  mantissa = src_bitwise & numeric_traits<SrcT>::mant_mask;
280  exponent = (head >> SrcT_mant) & numeric_traits<SrcT>::exp_mask;
281  sign = head >> (SrcT_exp + SrcT_mant);
285 
// signed_inf: result for overflow/infinite input (max finite value when clip, otherwise
// the format's inf/NaN pattern); nan: result for NaN input.
286  unsigned int signed_inf = 0;
287  unsigned int nan = 0;
288  if constexpr(is_fnuz)
289  {
290  signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
291  nan = 0x80;
292  }
293  else
294  {
295  if constexpr(DstT_exp == 4)
296  { // e4m3
297  signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
298  }
299  else
300  { // e5m2
301  signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
302  }
303  nan = (sign << 7) + 0x7f;
304  }
305  // Max values
// ifmax: SrcT bit pattern of the largest finite f8 magnitude (57344 for E5M2, 240 for FNUZ
// E4M3, 448 for OCP E4M3). Larger inputs map straight to signed_inf, before any rounding.
306  unsigned long long ifmax = 0;
307  if constexpr(is_float)
308  {
309  if constexpr(DstT_exp == 5)
310  {
311  ifmax = 0x47600000;
312  }
313  else
314  {
315  if constexpr(is_fnuz)
316  {
317  ifmax = 0x43700000;
318  }
319  else
320  {
321  ifmax = 0x43E00000;
322  }
323  }
324  }
325  else if constexpr(is_half)
326  {
327  if constexpr(DstT_exp == 5)
328  {
329  ifmax = 0x7B00;
330  }
331  else
332  {
333  if constexpr(is_fnuz)
334  {
335  ifmax = 0x5B80;
336  }
337  else
338  {
339  ifmax = 0x5F00;
340  }
341  }
342  }
343 
344  // Deal with inf and NaNs
345  if((src_bitwise & fInf) == fInf)
346  {
347  if constexpr(is_fnuz)
348  return signed_inf;
349 
350  return mantissa != 0 ? nan : signed_inf;
351  }
352 
353  if((src_bitwise & abs_mask) > ifmax)
354  {
355  return signed_inf;
356  }
357 
358  if(src_bitwise == 0)
359  {
360  return 0;
361  }
362 
363  // First need to check if it is normal or denorm as there is a difference of
364  // implicit 1 Then need to adjust the exponent to align with the F8 exponent,
365  // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
366  // to mantissa and truncate. And for RNE, no need to add rng. Then probably
367  // need to check whether there is carry and adjust exponent and mantissa again
368 
369  // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
370  // bits
371  const int f8_bias = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0);
372  const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
373  // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
374  // f8_exponent is the converted f8 exponent with bias encoding
375  // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
376  // the difference needs to be adjusted and mantissa shifted
377  int act_exponent, f8_exponent, exponent_diff;
378 
379  if(exponent == 0)
380  { // fp32/fp16 is in denormal.
381  /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
382  mostly concern fp16 here. In this case, f8 is usually in denormal. But there
383  could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
384  exponent bias 16. It means that there are some numbers in fp16 denormal but they
385  are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
386  where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
387  (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
388  act_exponent = exponent - bias + 1;
389  exponent_diff = f8_denormal_act_exponent -
390  act_exponent; // actual exponent is exponent-bias+1 as it is denormal
391  }
392  else
393  { // fp32/fp16 is normal with implicit 1
394  act_exponent = exponent - bias;
395  if(act_exponent <= f8_denormal_act_exponent)
396  {
397  /* This is the case where fp32/fp16 is normal but it is in f8 denormal
398  range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
399  actual exponent is -7, it is actually larger due to the implicit 1,
400  Therefore it needs to be adjust to -6 and mantissa shift right by 1.
401  So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
402  exponent_diff = f8_denormal_act_exponent - act_exponent;
403  }
404  else
405  { // both fp32/fp16 and f8 are in normal range
406  exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
407  // for this case, act_exponent could be larger. Just
408  // that it does not need shift mantissa
409  }
410  mantissa += (1ull << SrcT_mant); // Add the implicit 1 into mantissa
411  }
412 
// NOTE(review): exponent_diff is not bounded. For very small float inputs (roughly
// |src| < 2^-50 for E4M3) SrcT_mant - DstT_mant + exponent_diff reaches 64 or more, so the
// `1ull << (...)` shifts below and `mantissa >>= exponent_diff` shift by >= 64 bits, which
// is undefined behavior. Such inputs can only round to zero (or, with stoch, to the smallest
// subnormal); clamping exponent_diff before these shifts would make this well-defined — TODO confirm.
413  bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
414  (1ull << (SrcT_mant - DstT_mant + exponent_diff - 1));
415  /* This part is a bit tricky. The judgment of whether it is a tie needs to be
416  done before we shift right as shift right could rip off some residual part and
417  make something not midpoint look like midpoint. For example, the fp16 number
418  0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
419  by 4 bits, it would look like midpoint.
420  */
421 
422  if(exponent_diff > 0)
423  mantissa >>= exponent_diff;
424  else if(exponent_diff == -1)
425  mantissa <<= -exponent_diff;
426  bool implicit_one = mantissa & (1ull << SrcT_mant);
427  // if there is no implicit 1, it means the f8 is denormal and need to adjust
428  // to denorm exponent
429  f8_exponent =
430  (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
431 
432  // Now we have the exponent and mantissa adjusted
433  unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1;
434  bool odd =
435  mantissa & (1ull << (SrcT_mant -
436  DstT_mant)); // if the least significant bit that is not truncated is 1
// RNE trick: adding the dropped bits (minus one on an even tie) carries into the kept bits
// exactly when the value must round up; stochastic mode adds rng instead.
437  mantissa +=
438  (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
439 
440  // Now we deal with overflow
441  if(f8_exponent == 0)
442  {
443  if((1ull << SrcT_mant) & mantissa)
444  {
445  f8_exponent = 1; // denormal overflow to become normal, promote exponent
446  }
447  }
448  else
449  {
450  if((1ull << (SrcT_mant + 1)) & mantissa)
451  {
452  mantissa >>= 1;
453  f8_exponent++;
454  }
455  }
456 
457  mantissa >>= (SrcT_mant - DstT_mant);
458 
459  // above range: quantize to maximum possible float of the same sign
460  const int max_exp = (1 << DstT_exp) - 1;
461  if(f8_exponent > max_exp)
462  {
463  if constexpr(clip)
464  {
465  mantissa = (1 << DstT_mant) - 1;
466  f8_exponent = max_exp;
467  }
468  else
469  {
470  return signed_inf;
471  }
472  }
473 
// A result that rounded to zero keeps its sign only for OCP (FNUZ has no -0).
474  if(f8_exponent == 0 && mantissa == 0)
475  return is_fnuz ? 0 : (sign << 7);
476  mantissa &= (1 << DstT_mant) - 1;
477  return (sign << 7) | (f8_exponent << DstT_mant) | mantissa;
478 }
479 
// Software conversion of an fp8/bf8 value (SrcT) to half_t or float (DstT).
//   clip : for OCP E5M2, decode +/-inf as +/-57344 (the largest finite value) instead of +/-inf.
// FNUZ 0x80 decodes to NaN; OCP 0x80 decodes to -0.
480 template <typename SrcT, typename DstT, bool clip = true>
482 {
483  static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
484  "SrcT type must be fp8 or bf8.");
485  constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
486  constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
487  constexpr bool is_fnuz =
490 
491  constexpr bool is_half = std::is_same<DstT, half_t>::value;
492  constexpr bool is_float = std::is_same<DstT, float>::value;
493  static_assert(is_half || is_float, "DstT type must be half_t or float.");
494 
495  // destination type exponent/mantissa layout
496  constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
497  constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
498 
499  constexpr DstT fInf = bit_cast<DstT>(numeric_traits<DstT>::Inf);
500  constexpr DstT fNegInf = bit_cast<DstT>(numeric_traits<DstT>::NegInf);
501  constexpr DstT fNaN = bit_cast<DstT>(numeric_traits<DstT>::NaN);
502  constexpr DstT fNeg0 = bit_cast<DstT>(numeric_traits<DstT>::Neg0);
503 
504  DstT fmax{0}, fmin{0};
505  // Max number in e5m2 57344
506  if constexpr(is_half)
507  {
508  fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x7B00));
509  fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xFB00));
510  }
511  else if constexpr(is_float)
512  {
513  fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x47600000));
514  fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xC7600000));
515  }
516 
// NOTE(review): the tests below compare the SrcT value itself with int constants. When fp8_t
// is the signed `_BitInt(8)` alias, byte 0x80 holds the value -128, so `x == 0x80` (128) never
// matches: FNUZ NaN and OCP -0 then fall through to the subnormal path (calling clz(0)), and
// `x >> 7` yields -1 rather than 1 for negative inputs (harmless only if the final store into
// bitwise_type truncates the high bits). Comparing `bit_cast<uint8_t>(x)`, as isnan() does,
// would avoid this — confirm.
517  if(x == 0)
518  {
519  return 0;
520  }
521 
522  unsigned long long sign = x >> 7;
523  unsigned long long mantissa = x & ((1 << SrcT_mant) - 1);
524  int exponent = (x & 0x7F) >> SrcT_mant;
525  if constexpr(is_fnuz)
526  {
527  if(x == 0x80)
528  {
529  return fNaN;
530  }
531  }
532  else
533  {
534  if(x == 0x80)
535  {
536  return fNeg0;
537  }
538  if constexpr(SrcT_exp == 4)
539  { // e4m3
540  if((x & 0x7F) == 0x7F)
541  {
542  return fNaN;
543  }
544  }
545  else if((x & 0x7C) == 0x7C)
546  { // e5m2
547  if((x & 0x3) == 0)
548  {
549  if constexpr(clip)
550  {
551  return sign ? fmin : fmax;
552  }
553  return sign ? fNegInf : fInf;
554  }
555  return fNaN;
556  }
557  }
558 
559  typename numeric_traits<DstT>::bitwise_type retval;
560 
// OCP E5M2 shares fp16's exponent width and bias, so it is exactly the top byte of a half.
561  if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
562  {
563  retval = x << 8;
564  return bit_cast<DstT>(retval);
565  }
566 
// Offset that re-biases an f8 exponent field into DstT's exponent field.
567  const int exp_low_cutoff =
568  (1 << (DstT_exp - 1)) - (1 << (SrcT_exp - 1)) + 1 - (is_fnuz ? 1 : 0);
569 
570  // subnormal input
571  if(exponent == 0)
572  {
573  int sh = 1 + clz(mantissa) - (32 - SrcT_mant);
574  mantissa <<= sh;
575  exponent += 1 - sh;
576  mantissa &= ((1ull << SrcT_mant) - 1);
577  }
578  exponent += exp_low_cutoff - 1;
579  mantissa <<= DstT_mant - SrcT_mant;
580 
581  // subnormal output (occurs when DstT is half_t, we=5, is_fnuz=true)
582  if(exponent <= 0)
583  {
584  mantissa |= 1 << DstT_mant;
585  mantissa >>= 1 - exponent;
586  exponent = 0;
587  }
588 
589  retval = (sign << (DstT_exp + DstT_mant)) | (exponent << DstT_mant) | mantissa;
590 
591  return bit_cast<DstT>(retval);
592 }
593 
594 template <typename X, typename Y, bool clip, bool stoch>
595 CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
596 {
597  return bit_cast<Y>(run_cast_to_f8<X, Y, clip, stoch>(x, rng));
598 }
599 
600 #if CK_TILE_FP8_CVT_DEVICE
604 template <fp8_interpretation interpret, bool saturate, bool stochastic_rounding = false>
605 CK_TILE_DEVICE uint8_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
606 {
607  uint8_t i8data;
608  union
609  {
610  float fval;
611  unsigned int i32val;
612  unsigned char i8val[4]; // NOTE: not endian independent
613  } val;
614 
615  unsigned int ival = 0;
616  val.fval = v;
617 
618  if constexpr(saturate)
619  {
620  if constexpr(interpret == fp8_interpretation::E4M3_FNUZ)
621  {
622  if((val.i32val & 0x7F800000) != 0x7F800000)
623  {
624  val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
625  }
626  }
627  else if constexpr(interpret == fp8_interpretation::E4M3_OCP)
628  { // OCP type
629  if((val.i32val & 0x7F800000) != 0x7F800000)
630  {
631  val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
632  }
633  }
634  else
635  {
636  if((val.i32val & 0x7F800000) != 0x7F800000)
637  {
638  val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
639  }
640  }
641  }
642 
643  if constexpr(stochastic_rounding)
644  {
645  ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
646  (interpret == fp8_interpretation::E4M3_OCP)
647  ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
648  : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
649  val.i32val = ival;
650  i8data = val.i8val[0]; // little endian
651  }
652  else
653  { // RNE CVT
654  ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
655  (interpret == fp8_interpretation::E4M3_OCP)
656  ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
657  : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
658  val.fval,
659  ival,
660  false); // false -> WORD0
661  val.i32val = ival;
662  i8data = val.i8val[0];
663  }
664  return i8data;
665 }
666 #endif // CK_TILE_FP8_CVT_DEVICE
667 
668 } // namespace impl
669 
// Stochastic-rounding conversion of x to the raw byte pattern of DstT (fp8_t or bf8_t),
// always saturating (clip = true). Random bits come from prand_generator_t with the fixed
// seed 42, fed the address and the value of x. The hardware instruction is used when
// CK_TILE_FP8_CVT_DEVICE, the software impl::cast_to_f8 otherwise.
683 template <typename SrcT, typename DstT>
685 {
686  constexpr bool clip = true;
687  constexpr int seed = 42;
688  uint32_t rng = prand_generator_t<SrcT, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
689 #if CK_TILE_FP8_CVT_DEVICE
690  return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, true>(x, rng);
691 #else
692  return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
693  impl::cast_to_f8<SrcT, DstT, clip, true>(x, rng));
694 #endif
695 }
696 
// Round-to-nearest conversion of x to the raw byte pattern of DstT (fp8_t or bf8_t), always
// saturating (clip = true). The hardware instruction is used when CK_TILE_FP8_CVT_DEVICE,
// the software impl::cast_to_f8 otherwise; rng is unused (0).
709 template <typename SrcT, typename DstT>
711 {
712  constexpr bool clip = true;
713 #if CK_TILE_FP8_CVT_DEVICE
714  return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, false>(x, 0);
715 #else
716  return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
717  impl::cast_to_f8<SrcT, DstT, clip, false>(x, 0));
718 #endif
719 }
720 
// Converts a float to a raw fp8 (E4M3) byte using the compile-time rounding mode.
// NOTE(review): any rounding mode other than standard/stochastic silently yields 0 (+0.0);
// a static_assert in the last branch would surface such misuse at compile time — consider.
721 template <fp8_rounding_mode rounding>
723 {
724  if constexpr(rounding == fp8_rounding_mode::standard)
725  {
726  return float_to_fp8_rtn_raw<float, fp8_t>(x);
727  }
728  else if constexpr(rounding == fp8_rounding_mode::stochastic)
729  {
730  return float_to_fp8_sr_raw<float, fp8_t>(x);
731  }
732  else
733  {
734  return fp8_raw_t{0};
735  }
736 }
737 
// Converts a float to a raw bf8 (E5M2) byte using the compile-time rounding mode.
// NOTE(review): any rounding mode other than standard/stochastic silently yields 0 (+0.0);
// a static_assert in the last branch would surface such misuse at compile time — consider.
738 template <fp8_rounding_mode rounding>
740 {
741  if constexpr(rounding == fp8_rounding_mode::standard)
742  {
743  return float_to_fp8_rtn_raw<float, bf8_t>(x);
744  }
745  else if constexpr(rounding == fp8_rounding_mode::stochastic)
746  {
747  return float_to_fp8_sr_raw<float, bf8_t>(x);
748  }
749  else
750  {
751  return bf8_raw_t{0};
752  }
753 }
754 
756 {
// Decodes a raw fp8 (E4M3) byte to float: hardware cvt builtin (byte 0) when
// CK_TILE_FP8_CVT_DEVICE, otherwise the software impl::run_cast_from_f8.
757 #if CK_TILE_FP8_CVT_DEVICE
758  float fval;
759  uint32_t i32val = static_cast<uint32_t>(x);
760  fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
761  // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
762  return fval;
763 #else
764  return impl::run_cast_from_f8<fp8_t, float>(bit_cast<fp8_t>(x));
765 #endif
766 }
767 
769 {
// Decodes a raw bf8 (E5M2) byte to float: hardware cvt builtin (byte 0) when
// CK_TILE_FP8_CVT_DEVICE, otherwise the software impl::run_cast_from_f8.
770 #if CK_TILE_FP8_CVT_DEVICE
771  float fval;
772  uint32_t i32val = static_cast<uint32_t>(x);
773  fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
774  // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
775  return fval;
776 #else
777  return impl::run_cast_from_f8<bf8_t, float>(bit_cast<bf8_t>(x));
778 #endif
779 }
780 
// Converts a float to fp8_t: rounds via float_to_fp8_raw, then reinterprets the byte as fp8_t.
781 template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
783 {
784  return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{}));
785 }
786 
// Converts a float to bf8_t: rounds via float_to_bf8_raw, then reinterprets the byte as bf8_t.
787 template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
789 {
790  return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{}));
791 }
792 
793 CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x) { return fp8_to_float_raw(bit_cast<fp8_raw_t>(x)); }
794 
795 CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x) { return bf8_to_float_raw(bit_cast<bf8_raw_t>(x)); }
796 
// numeric<T>: numeric_limits-style constants (min, lowest, max, epsilon, NaN, ...) per type.
template <typename T>
struct numeric;
799 
800 #if CK_TILE_USE_OCP_FP8
801 template <>
802 struct numeric<fp8_t>
803 {
804  // minimum finite value, or minimum positive normal value
805  CK_TILE_HOST_DEVICE static constexpr fp8_t min()
806  {
807  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08)); // 0b00001000 = 2^-6
808  }
809 
810  // minumum finite value
811  CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
812  {
813  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xfe)); // 0b11111110 = -448
814  }
815 
816  // maximum finite value
817  CK_TILE_HOST_DEVICE static constexpr fp8_t max()
818  {
819  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7e)); // 0b01111110 = 448
820  }
821 
822  // difference between 1.0 and next representable f8 value (1.125)
823  // returns fp8_t(0.125)
824  CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
825  {
826  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20)); // 0.125
827  }
828 
829  // rounding error (0.0625)
830  // half of epsilon
831  CK_TILE_HOST_DEVICE static constexpr fp8_t round_error()
832  {
833  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x18)); // 0.0625
834  }
835 
836  // quiet NaN
837  CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN()
838  {
839  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7F)); // 0b01111111
840  }
841 
842  // signaling NaN
843  CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN()
844  {
845  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xFF)); // 0b11111111
846  }
847 
848  // smallest positive subnormal value
849  CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min()
850  {
851  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
852  }
853 
854  CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
855  {
856  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
857  }
858 };
859 
860 template <>
861 struct numeric<bf8_t>
862 {
863  // minimum finite value, or minimum positive normalized value for float
864  CK_TILE_HOST_DEVICE static constexpr bf8_t min()
865  {
866  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04)); // 0b00000100 = 2^-14
867  }
868 
869  // minumum finite value
870  CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
871  {
872  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xfb)); // 0b11111011 = -57344
873  }
874 
875  // maximum finite value
876  CK_TILE_HOST_DEVICE static constexpr bf8_t max()
877  {
878  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7b)); // 0b01111011 = 57344
879  }
880 
881  // difference between 1.0 and next representable bf8 value (1.25)
882  CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
883  {
884  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34)); // 0.25
885  }
886 
887  // rounding error (0.125)
888  // half of epsilon
889  CK_TILE_HOST_DEVICE static constexpr bf8_t round_error()
890  {
891  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x30)); // 0.125
892  }
893 
894  // positive infinity value
895  CK_TILE_HOST_DEVICE static constexpr bf8_t infinity()
896  {
897  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7c)); // 0b01111100
898  }
899 
900  // quiet NaN
901  CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN()
902  {
903  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7F)); // 0b01111111
904  }
905 
906  // signaling NaN
907  CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN()
908  {
909  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xFF));
910  }
911 
912  // smallest positive subnormal value
913  CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min()
914  {
915  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
916  }
917 
918  CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
919  {
920  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
921  }
922 };
923 #else
// numeric_limits-style constants for FNUZ E4M3 (bias 8), as raw bit patterns s.eeee.mmm.
924 template <>
925 struct numeric<fp8_t>
926 {
927  // smallest positive normal value: 0.0001.000 = 2^-7
928  CK_TILE_HOST_DEVICE static constexpr fp8_t min()
929  {
930  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08));
931  }
932 
933  // most negative finite value: 1.1111.111 = -240
934  CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
935  {
936  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xff));
937  }
938 
939  // largest finite value: 0.1111.111 = 240
940  CK_TILE_HOST_DEVICE static constexpr fp8_t max()
941  {
942  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7f));
943  }
944 
945  // intended: difference between 1.0 and the next representable fp8 value
// NOTE(review): with the FNUZ bias of 8, 0x20 (0.0100.000) decodes to 2^-4 = 0.0625, whereas
// 1.0 (0x40) and the next value 1.125 (0x41) differ by 0.125, encoded as 0x28. The constant
// matches the OCP bias of 7 — confirm before relying on epsilon().
946  CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
947  {
948  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20));
949  }
950 
951  // maximum rounding error
952  // bin : 7 6543 210
953  // bits: s eeee mmm
954  // 0 0110 000 (0.25 with the FNUZ bias of 8; 0.5 would need bias 7)
955  //
// NOTE(review): if 0.5 is intended (as for std::numeric_limits::round_error), the FNUZ
// encoding is 0x38 (0.0111.000) — confirm.
957  {
958  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x30));
959  }
960 
961  // FNUZ has no infinity encoding: returns 0x80, the single NaN pattern
962  CK_TILE_HOST_DEVICE static constexpr fp8_t infinity()
963  {
964  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
965  }
966 
967  // quiet NaN (FNUZ has exactly one NaN: 0x80)
969  {
970  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
971  }
972 
973  // signaling NaN (same 0x80 pattern as quiet_NaN)
975  {
976  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
977  }
978 
979  // smallest positive subnormal value: 0.0000.001 = 2^-10
981  {
982  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
983  }
984 
985  CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
986  {
987  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
988  }
989 };
990 
// numeric_limits-style constants for FNUZ E5M2 (bias 16), as raw bit patterns s.eeeee.mm.
991 template <>
992 struct numeric<bf8_t>
993 {
994  // smallest positive normal value: 0.00001.00 = 2^-15
995  CK_TILE_HOST_DEVICE static constexpr bf8_t min()
996  {
997  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04));
998  }
999 
1000  // most negative finite value: 1.11111.11 = -57344
1001  CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
1002  {
1003  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xff));
1004  }
1005 
1006  // largest finite value: 0.11111.11 = 57344
1007  CK_TILE_HOST_DEVICE static constexpr bf8_t max()
1008  {
1009  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7f));
1010  }
1011 
1012  // intended: difference between 1.0 and the next representable bf8 value
// NOTE(review): with the FNUZ bias of 16, 0x34 (0.01101.00) decodes to 2^-3 = 0.125, whereas
// 1.0 (0x40) and the next value 1.25 (0x41) differ by 0.25, encoded as 0x38. The constant
// matches the OCP bias of 15 — confirm before relying on epsilon().
1013  CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
1014  {
1015  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34));
1016  }
1017 
1018  // maximum rounding error
1019  // bin : 7 65432 10
1020  // bits: s eeeee mm
1021  // 0 01110 00 (0.25 with the FNUZ bias of 16; 0.5 would need bias 15)
1022  //
// NOTE(review): if 0.5 is intended (as for std::numeric_limits::round_error), the FNUZ
// encoding is 0x3C (0.01111.00) — confirm.
1024  {
1025  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x38));
1026  }
1027 
1028  // FNUZ has no infinity encoding: returns 0x80, the single NaN pattern
1030  {
1031  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
1032  }
1033 
1034  // quiet NaN (FNUZ has exactly one NaN: 0x80)
1036  {
1037  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
1038  }
1039 
1040  // signaling NaN (same 0x80 pattern as quiet_NaN)
1042  {
1043  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
1044  }
1045 
1046  // smallest positive subnormal value: 0.00000.01 = 2^-17
1048  {
1049  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
1050  }
1051 
1052  CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
1053  {
1054  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
1055  }
1056 };
1057 #endif
1058 
1059 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1062 #endif
1063 
1064 // math
// Absolute value of an fp8_t / bf8_t by clearing the sign bit of the raw byte.
// NOTE(review): in FNUZ mode the single NaN is 0x80, so masking it with 0x7F yields 0x00 (+0):
// abs(NaN) returns zero instead of NaN. Consider returning x unchanged when isnan(x) — confirm.
1065 template <typename T>
1067 {
1068  static_assert(std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>,
1069  "Only fp8_t and bf8_t are supported");
1070  return bit_cast<T>(static_cast<uint8_t>(bit_cast<uint8_t>(x) & numeric_traits<T>::abs_mask));
1071 }
1072 
// NaN test on the raw byte: OCP E4M3 NaN is s.1111.111 (either sign); FNUZ has one NaN, 0x80.
1074 bool isnan(const fp8_t& x)
1075 {
1076  uint8_t xx = bit_cast<fp8_raw_t>(x);
1077 
1078 #if CK_TILE_USE_OCP_FP8
1079  return (xx & 0x7f) == 0x7f;
1080 #else
1081  return xx == 0x80;
1082 #endif
1083 }
1084 #if CK_TILE_USE_CUSTOM_DATA_TYPE
// fp8_t math helpers: widen to float (explicit operator float), apply the float routine, then
// narrow back through the explicit fp8_t(float) constructor, i.e. float_to_fp8_raw with the
// default rounding mode. The trailing ';' after each body is a redundant empty declaration.
1086 fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
1087 
1089 fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__ocml_exp_f32(static_cast<float>(x))); };
1090 
1092 fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };
1093 
1095 fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };
1096 #endif
1097 
// NaN test on the raw byte: OCP E5M2 NaN has an all-ones exponent and nonzero mantissa
// (magnitude byte 0x7d..0x7f; 0x7c is infinity); FNUZ has one NaN, 0x80.
1099 bool isnan(const bf8_t& x)
1100 {
1101  uint8_t xx = bit_cast<bf8_raw_t>(x);
1102 
1103 #if CK_TILE_USE_OCP_FP8
1104  return (xx & 0x7f) > 0x7c;
1105 #else
1106  return xx == 0x80;
1107 #endif
1108 }
1109 
1110 #if CK_TILE_USE_CUSTOM_DATA_TYPE
// bf8_t math helpers: widen to float (explicit operator float), apply the float routine, then
// narrow back through the explicit bf8_t(float) constructor, i.e. float_to_bf8_raw with the
// default rounding mode. The trailing ';' after each body is a redundant empty declaration.
1112 bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
1113 
1115 bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__ocml_exp_f32(static_cast<float>(x))); };
1116 
1118 bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };
1119 
1121 bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); };
1122 #endif
1123 
1124 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_FLOAT_TO_FP8_DEFAULT
Definition: config.hpp:77
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng=0)
Definition: float8.hpp:251
CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
Definition: float8.hpp:481
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
Definition: float8.hpp:595
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:423
fp8_interpretation
FP8 interpretation used in conversion algorithms.
Definition: float8.hpp:38
_BitInt(8) fp8_t
Definition: float8.hpp:204
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant< rounding >={})
Definition: float8.hpp:782
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t)
Definition: float8.hpp:755
constexpr CK_TILE_HOST_DEVICE Y bit_cast(const X &x)
Definition: bit_cast.hpp:11
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t)
Definition: float8.hpp:768
fp8_rounding_mode
Definition: float8.hpp:29
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:408
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant< rounding >={})
Definition: float8.hpp:722
uint8_t fp8_raw_t
Definition: float8.hpp:205
CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x)
Definition: float8.hpp:795
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_sr_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with stochastic rounding.
Definition: float8.hpp:684
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:414
CK_TILE_HOST int clz(uint32_t x)
Definition: math.hpp:264
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:395
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
uint8_t bf8_raw_t
Definition: float8.hpp:207
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant< rounding >={})
Definition: float8.hpp:788
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition: bfloat16.hpp:401
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_rtn_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with rounding to nearest even.
Definition: float8.hpp:710
CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x)
Definition: float8.hpp:793
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant< rounding >={})
Definition: float8.hpp:739
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:420
Definition: integral_constant.hpp:13
remove_cvref_t< T > type
Definition: vector_type.hpp:25
static constexpr CK_TILE_HOST_DEVICE bf8_t min()
Definition: float8.hpp:995
static constexpr CK_TILE_HOST_DEVICE bf8_t quiet_NaN()
Definition: float8.hpp:1035
static constexpr CK_TILE_HOST_DEVICE bf8_t lowest()
Definition: float8.hpp:1001
static constexpr CK_TILE_HOST_DEVICE bf8_t round_error()
Definition: float8.hpp:1023
static constexpr CK_TILE_HOST_DEVICE bf8_t signaling_NaN()
Definition: float8.hpp:1041
static constexpr CK_TILE_HOST_DEVICE bf8_t denorm_min()
Definition: float8.hpp:1047
static constexpr CK_TILE_HOST_DEVICE bf8_t epsilon()
Definition: float8.hpp:1013
static constexpr CK_TILE_HOST_DEVICE bf8_t infinity()
Definition: float8.hpp:1029
static constexpr CK_TILE_HOST_DEVICE bf8_t max()
Definition: float8.hpp:1007
static constexpr CK_TILE_HOST_DEVICE bf8_t zero()
Definition: float8.hpp:1052
static constexpr CK_TILE_HOST_DEVICE fp8_t signaling_NaN()
Definition: float8.hpp:974
static constexpr CK_TILE_HOST_DEVICE fp8_t zero()
Definition: float8.hpp:985
static constexpr CK_TILE_HOST_DEVICE fp8_t min()
Definition: float8.hpp:928
static constexpr CK_TILE_HOST_DEVICE fp8_t lowest()
Definition: float8.hpp:934
static constexpr CK_TILE_HOST_DEVICE fp8_t epsilon()
Definition: float8.hpp:946
static constexpr CK_TILE_HOST_DEVICE fp8_t quiet_NaN()
Definition: float8.hpp:968
static constexpr CK_TILE_HOST_DEVICE fp8_t max()
Definition: float8.hpp:940
static constexpr CK_TILE_HOST_DEVICE fp8_t denorm_min()
Definition: float8.hpp:980
static constexpr CK_TILE_HOST_DEVICE fp8_t round_error()
Definition: float8.hpp:956
static constexpr CK_TILE_HOST_DEVICE fp8_t infinity()
Definition: float8.hpp:962
bf8_raw_t bitwise_type
Definition: float8.hpp:233
fp8_raw_t bitwise_type
Definition: float8.hpp:216
Definition: bfloat16.hpp:380
Definition: numeric.hpp:18
static constexpr CK_TILE_HOST_DEVICE T lowest()
Definition: numeric.hpp:23
static constexpr CK_TILE_HOST_DEVICE T min()
Definition: numeric.hpp:20
static constexpr CK_TILE_HOST_DEVICE T quiet_NaN()
Definition: numeric.hpp:41
static constexpr CK_TILE_HOST_DEVICE T signaling_NaN()
Definition: numeric.hpp:47
static constexpr CK_TILE_HOST_DEVICE T max()
Definition: numeric.hpp:26
static constexpr CK_TILE_HOST_DEVICE T round_error()
Definition: numeric.hpp:32
static constexpr CK_TILE_HOST_DEVICE T zero()
Definition: numeric.hpp:58
static constexpr CK_TILE_HOST_DEVICE T denorm_min()
Definition: numeric.hpp:53
static constexpr CK_TILE_HOST_DEVICE T epsilon()
Definition: numeric.hpp:29
static constexpr CK_TILE_HOST_DEVICE T infinity()
Definition: numeric.hpp:38
Definition: random.hpp:17
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition: numeric.hpp:102