include/ck/utility/amd_ck_fp8.hpp Source File

include/ck/utility/amd_ck_fp8.hpp Source File#

Composable Kernel: include/ck/utility/amd_ck_fp8.hpp Source File
amd_ck_fp8.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/ck.hpp"
9 #include "ck/utility/type.hpp"
10 
11 #ifdef CK_USE_FNUZ_FP8
12 #define CK_USE_FNUZ_FP8 1
13 #else
14 #define CK_USE_FNUZ_FP8 0
15 #endif
16 
17 #ifdef CK_USE_OCP_FP8
18 #define CK_USE_OCP_FP8 1
19 #else
20 #define CK_USE_OCP_FP8 0
21 #endif
22 
23 #if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
24  defined(__gfx1201__) || defined(__gfx950__)) && \
25  __HIP_DEVICE_COMPILE__
26 #define CK_FP8_CVT_FAST_PATH 1
27 #else
28 #define CK_FP8_CVT_FAST_PATH 0
29 #endif
30 
31 #if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
32 #define CK_OCP_FP8_CVT_FAST_PATH 1
33 #else
34 #define CK_OCP_FP8_CVT_FAST_PATH 0
35 #endif
36 
37 namespace ck {
38 
39 using f8_fnuz_t = _BitInt(8);
40 using bf8_fnuz_t = unsigned _BitInt(8);
41 
42 typedef unsigned char fp8_storage_t;
43 
/// Describes FP8 interpretation (which 8-bit encoding the raw byte represents).
// NOTE: the `enum class ck_fp8_interpretation_t` declaration line was lost in
// extraction; it is restored here per the file's own cross-reference index
// ("ck_fp8_interpretation_t ... Definition: amd_ck_fp8.hpp:48").
enum class ck_fp8_interpretation_t
{
    CK_E4M3_OCP  = 0, // OCP E4M3
    CK_E5M2_OCP  = 1, // OCP E5M2
    CK_E4M3_FNUZ = 2, // FP8
    CK_E5M2_FNUZ = 3, // BF8
};
54 
/// Describes saturation behavior when a value is out of the fp8 range.
enum class ck_saturation_t
{
    CK_NOSAT     = 0, // No saturation - replace with NaN or Inf
    CK_SATFINITE = 1, // Saturate to finite
};
63 
64 namespace fp8_impl {
65 
66 typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2)));
67 typedef float float2_t __attribute__((ext_vector_type(2)));
68 
// FNUZ e4m3: the single bit pattern 0x80 encodes NaN (FNUZ has no -0).
__host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a)
{
    const unsigned char bits = static_cast<unsigned char>(a);
    return bits == 0x80;
}
// FNUZ e5m2: same NaN convention as FNUZ e4m3 — only 0x80 is NaN.
__host__ __device__ static inline constexpr bool fnuz_bf8_is_nan(bf8_fnuz_t a)
{
    constexpr unsigned char nan_pattern = 0x80;
    return static_cast<unsigned char>(a) == nan_pattern;
}
77 
// OCP e4m3: NaN is all-ones exponent+mantissa (low 7 bits == 0x7f), either sign.
// Setting the sign bit and comparing against 0xff tests exactly that.
__host__ __device__ static inline constexpr bool ocp_f8_is_nan(fp8_storage_t a)
{
    return (a | 0x80) == 0xff;
}
// OCP e5m2: NaN has exponent all ones plus a nonzero mantissa, i.e. a magnitude
// of 0x7d, 0x7e or 0x7f (0x7c itself is infinity).
__host__ __device__ static inline constexpr bool ocp_bf8_is_nan(fp8_storage_t a)
{
    const fp8_storage_t magnitude = a & 0x7f;
    return magnitude >= 0x7d;
}
86 
87 // The conversion function is from rocblas
88 // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
89 // This has been modified to handle double types as well
// Decode an 8-bit fp8 value with layout (1 sign | `we` exponent | `wm` mantissa)
// into T (half / float / double).
//  - is_fnuz: FNUZ encoding — 0x80 is NaN and there are no infinities;
//    otherwise OCP/IEEE-like — 0x80 is -0, and e5m2 carries Inf/NaN encodings.
//  - clip: when decoding an e5m2 infinity, return the largest finite T of the
//    same sign instead of +/-Inf.
template <typename T, int wm, int we, bool is_fnuz, bool clip = false>
__host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
{
    constexpr bool is_half   = __hip_internal::is_same<T, _Float16>::value;
    constexpr bool is_float  = __hip_internal::is_same<T, float>::value;
    constexpr bool is_double = __hip_internal::is_same<T, double>::value;
    static_assert(is_half || is_float || is_double, "only half, float and double are supported");

    // Exponent / mantissa widths of the destination type T.
    constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
    constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);

    // Special values of T, materialized from their exact bit patterns.
    T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
    if constexpr(is_half)
    {
        const unsigned short int ihInf    = 0x7C00;
        const unsigned short int ihNegInf = 0xFC00;
        const unsigned short int ihNaN    = 0x7C01;
        const unsigned short int ihNeg0   = 0x8000;
        /* Max number in e5m2 57344*/
        const unsigned short int ifmax = 0x7B00;
        const unsigned short int ifmin = 0xFB00;

        fInf    = bit_cast<_Float16>(ihInf);
        fNegInf = bit_cast<_Float16>(ihNegInf);
        fNaN    = bit_cast<_Float16>(ihNaN);
        fNeg0   = bit_cast<_Float16>(ihNeg0);
        fmax    = bit_cast<_Float16>(ifmax);
        fmin    = bit_cast<_Float16>(ifmin);
    }
    else if constexpr(is_float)
    {
        const unsigned int ifInf    = 0x7F800000;
        const unsigned int ifNegInf = 0xFF800000;
        const unsigned int ifNaN    = 0x7F800001;
        const unsigned int ifNeg0   = 0x80000000;
        /* Max number in e5m2 57344*/
        const unsigned int ifmax = 0x47600000;
        const unsigned int ifmin = 0xC7600000;

        fInf    = bit_cast<float>(ifInf);
        fNegInf = bit_cast<float>(ifNegInf);
        fNaN    = bit_cast<float>(ifNaN);
        fNeg0   = bit_cast<float>(ifNeg0);
        fmax    = bit_cast<float>(ifmax);
        fmin    = bit_cast<float>(ifmin);
    }
    else if constexpr(is_double)
    {
        const unsigned long long ifInf    = 0x7FF0000000000000ull;
        const unsigned long long ifNegInf = 0xFFF0000000000000ull;
        const unsigned long long ifNaN    = 0x7FF0000000000001ull;
        const unsigned long long ifNeg0   = 0x8000000000000000ull;
        /* Max number in e5m2 57344*/
        const unsigned long long ifmax = 0x40EC000000000000ull;
        const unsigned long long ifmin = 0xC0EC000000000000ull;

        fInf    = bit_cast<double>(ifInf);
        fNegInf = bit_cast<double>(ifNegInf);
        fNaN    = bit_cast<double>(ifNaN);
        fNeg0   = bit_cast<double>(ifNeg0);
        fmax    = bit_cast<double>(ifmax);
        fmin    = bit_cast<double>(ifmin);
    }

    // +0 decodes to +0 in every interpretation.
    if(x == 0)
    {
        return 0;
    }

    // Unpack sign, mantissa and biased exponent of the fp8 value.
    unsigned long long sign     = x >> 7;
    unsigned long long mantissa = x & ((1 << wm) - 1);
    int exponent                = (x & 0x7F) >> wm;
    if constexpr(is_fnuz)
    {
        if(x == 0x80)
        {
            return fNaN; // FNUZ: 0x80 is the unique NaN encoding
        }
    }
    else
    {
        if(x == 0x80)
        {
            return fNeg0; // OCP: 0x80 is -0
        }
        if constexpr(we == 4)
        { // e4m3
            if((x & 0x7F) == 0x7F)
            {
                return fNaN;
            }
        }
        else if((x & 0x7C) == 0x7C)
        { // e5m2
            if((x & 0x3) == 0)
            {
                // Infinity: optionally clip to the largest finite value.
                if constexpr(clip)
                {
                    return sign ? fmin : fmax;
                }
                return sign ? fNegInf : fInf;
            }
            return fNaN;
        }
    }

    // Integer type with the same width as T, used to assemble the result bits.
    typename std::conditional<
        sizeof(T) == 2,
        unsigned short int,
        typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type
        retval;

    // OCP e5m2 -> half is a pure left shift: same exponent width and bias.
    if constexpr(we == 5 && is_half && !is_fnuz)
    {
        retval = x << 8;
        return bit_cast<T>(retval);
    }

    // Difference between T's exponent bias and the fp8 bias (FNUZ bias is one larger).
    const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);

    // subnormal input: renormalize by shifting the mantissa up to the implicit-1 position.
    if(exponent == 0)
    {
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
        int sh = 1 + __clz(mantissa) - (32 - wm);
#else
        int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
#endif
        mantissa <<= sh;
        exponent += 1 - sh;
        mantissa &= ((1ull << wm) - 1);
    }
    // Rebias the exponent and widen the mantissa to T's field width.
    exponent += exp_low_cutoff - 1;
    mantissa <<= wmo - wm;

    // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
    if(exponent <= 0)
    {
        mantissa |= 1 << wmo;
        mantissa >>= 1 - exponent;
        exponent = 0;
    }

    // Pack sign/exponent/mantissa at T's field offsets.
    if constexpr(sizeof(T) == 2)
        retval = (sign << 15) | (exponent << 10) | mantissa;
    else if constexpr(sizeof(T) == 4)
        retval = (sign << 31) | (exponent << 23) | mantissa;
    else
        retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;

    return bit_cast<T>(retval);
}
243 
244 #if CK_FP8_CVT_FAST_PATH
/// Fast-path decode of one fp8 byte to float using the hardware conversion
/// builtins (only compiled on targets where CK_FP8_CVT_FAST_PATH is set).
/// NOTE: the static_assert/if-condition continuation lines were lost in
/// extraction and are restored here to cover all four interpretations.
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8(fp8_storage_t v)
{
    union
    {
        unsigned int i32val;
        unsigned char i8val[4];
    } val;
    val.i8val[0] = v; // builtin reads byte 0 of the 32-bit operand

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only FNUZ and OCP interpretations are supported");

    if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                 (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
    {
        return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
    }
}
271 
/// Fast-path decode of a packed pair of fp8 bytes to two floats using the
/// packed hardware conversion builtins.
/// NOTE: the static_assert/if-condition continuation lines were lost in
/// extraction and are restored here to cover all four interpretations.
template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
{
    const auto i16val = bit_cast<uint16_t>(v);

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only FNUZ and OCP interpretations are supported");

    if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                 (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
    {
        return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false);
    }
    else
    {
        return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
    }
}
293 #endif
294 
295 } // namespace fp8_impl
296 
/// OCP E4M3 8-bit float wrapper: 1 sign, 4 exponent, 3 mantissa bits.
/// NOTE: the `data_type`/`data` members and the `default_interpret` initializer
/// were lost in extraction; restored per the file's cross-reference index
/// (hpp:299, hpp:300, hpp:303-304).
struct f8_ocp_t
{
    using data_type = fp8_storage_t;
    data_type data;

    static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
    static constexpr ck_fp8_interpretation_t default_interpret =
        ck_fp8_interpretation_t::CK_E4M3_OCP;

    static constexpr unsigned int we = 4; // exponent width
    static constexpr unsigned int wm = 3; // mantissa width

    __host__ __device__ constexpr bool operator==(const f8_ocp_t& other) const
    {
        return (data == other.data) && (fp8_impl::ocp_f8_is_nan(data) == false); // NaN != NaN
    }

#if CK_USE_OCP_FP8
    __host__ __device__ explicit operator float() const
#else
    __host__ explicit operator float() const
#endif
    {
#if CK_OCP_FP8_CVT_FAST_PATH
        return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
        return fp8_impl::cast_from_f8<float, wm, we, false>(
            this->data); // XXX: clip==false must be consistent with operator _Float16
#endif
    }

#if CK_USE_OCP_FP8
    __host__ __device__ explicit operator _Float16() const
#else
    __host__ explicit operator _Float16() const
#endif
    {
#if CK_OCP_FP8_CVT_FAST_PATH
        return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
        return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
            this->data); // XXX: clip==false must be consistent with operator float
#endif
    }
};
342 
/// OCP E5M2 8-bit float wrapper: 1 sign, 5 exponent, 2 mantissa bits.
/// NOTE: the `data_type`/`data` members and the `default_interpret` initializer
/// were lost in extraction; restored per the file's cross-reference index
/// (hpp:345, hpp:346, hpp:349-350).
struct bf8_ocp_t
{
    using data_type = fp8_storage_t;
    data_type data;

    static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
    static constexpr ck_fp8_interpretation_t default_interpret =
        ck_fp8_interpretation_t::CK_E5M2_OCP;

    static constexpr unsigned int we = 5; // exponent width
    static constexpr unsigned int wm = 2; // mantissa width

    __host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const
    {
        return (data == other.data) && (fp8_impl::ocp_bf8_is_nan(data) == false); // NaN != NaN
    }

#if CK_USE_OCP_FP8
    __host__ __device__ explicit operator float() const

#else
    __host__ explicit operator float() const
#endif
    {
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
        return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
        return fp8_impl::cast_from_f8<float, wm, we, false>(
            this->data); // XXX: clip==false must be consistent with operator _Float16
#endif
    }

#if CK_USE_OCP_FP8
    __host__ __device__ explicit operator _Float16() const
#else
    __host__ explicit operator _Float16() const
#endif
    {
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
        return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
        return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
            this->data); // XXX: clip==false must be consistent with operator float
#endif
    }
};
389 
// NaN test dispatched on the concrete fp8 type. Only the specializations below
// are defined; the primary template is declared without a body so unsupported
// types fail at link/instantiation time.
template <typename T>
__host__ __device__ static inline constexpr bool fp8_is_nan(T);

template <>
__host__ __device__ inline constexpr bool fp8_is_nan(f8_ocp_t a)
{
    return fp8_impl::ocp_f8_is_nan(a.data);
}
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(bf8_ocp_t a)
{
    return fp8_impl::ocp_bf8_is_nan(a.data);
}
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(f8_fnuz_t a)
{
    return fp8_impl::fnuz_f8_is_nan(a);
}
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a)
{
    return fp8_impl::fnuz_bf8_is_nan(a);
}
413 
// Infinity test for fp8 types. Only OCP e5m2 (bf8_ocp_t) has an Inf encoding
// (exponent all ones, mantissa zero); for every other fp8 flavor the generic
// overload returns false.
template <typename T,
          ck::enable_if_t<is_same_v<T, bf8_ocp_t> || is_same_v<T, f8_ocp_t> ||
                              is_same_v<T, bf8_fnuz_t> || is_same_v<T, f8_fnuz_t>,
                          bool> = true>
__host__ __device__ static inline constexpr bool fp8_is_inf(T)
{
    return false;
}
template <>
__host__ __device__ inline constexpr bool fp8_is_inf(bf8_ocp_t a)
{
    // 0x7c = exponent all ones with zero mantissa, ignoring the sign bit.
    return (a.data & 0x7f) == 0x7c;
}
427 
428 namespace fp8_impl {
429 
// Assertions to check for supported conversion types.
// Each macro raises a device-side assertion when `interp` is not one of the two
// interpretations (OCP or FNUZ respectively) handled by the current build.
#define __assert_ocp_support(interp)                                                   \
    {                                                                                  \
        if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP &&                           \
           interp != ck_fp8_interpretation_t::CK_E5M2_OCP)                             \
        {                                                                              \
            __hip_assert(false && "type is unsupported by current target device");     \
        }                                                                              \
    }
#define __assert_fnuz_support(interp)                                                  \
    {                                                                                  \
        if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ &&                          \
           interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ)                            \
        {                                                                              \
            __hip_assert(false && "type is unsupported by current target device");     \
        }                                                                              \
    }
447 
// Device-side sanity check: asserts when `interp` names an fp8 flavor the
// current build does not support. A no-op in host compilation and when neither
// CK_USE_OCP_FP8 nor CK_USE_FNUZ_FP8 is enabled.
__host__ __device__ static inline void
__is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp)
{
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
#if CK_USE_OCP_FP8
    __assert_ocp_support(interp);
#endif
#if CK_USE_FNUZ_FP8
    __assert_fnuz_support(interp);
#endif
#endif
}
460 
461 #if CK_FP8_CVT_FAST_PATH
462 // The conversion function is from rocblas
463 // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
// Fast-path encode of a float into fp8 bits using the hardware conversion
// builtins. `saturate` clamps finite inputs into the representable range;
// `stochastic_rounding` uses the SR builtins with the supplied `rng` bits.
// NOTE: the `|| interpret == CK_E4M3_OCP` condition continuation lines were
// lost in extraction and are restored here.
template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
{
    fp8_storage_t i8data;
    union
    {
        float fval;
        unsigned int i32val;
        unsigned char i8val[4]; // NOTE: not endian independent
    } val;

    unsigned int ival = 0;
    val.fval          = v;

    if constexpr(saturate)
    {
        // Clamp only finite values (exponent bits not all ones) to the max
        // magnitude of the target format: 240 (e4m3 FNUZ), 448 (e4m3 OCP),
        // 57344 (e5m2).
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
        {
            if((val.i32val & 0x7F800000) != 0x7F800000)
            {
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
            }
        }
        else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        { // OCP type
            if((val.i32val & 0x7F800000) != 0x7F800000)
            {
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
            }
        }
        else
        {
            if((val.i32val & 0x7F800000) != 0x7F800000)
            {
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
            }
        }
    }

    if constexpr(stochastic_rounding)
    {
        ival       = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                             (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
                         ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
                         : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
        val.i32val = ival;
        i8data     = val.i8val[0]; // little endian
    }
    else
    { // RNE CVT
        ival       = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                             (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
                         ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
                         : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
                                                           val.fval,
                                                           ival,
                                                           false); // false -> WORD0
        val.i32val = ival;
        i8data     = val.i8val[0];
    }
    return i8data;
}
526 #endif // CK_FP8_CVT_FAST_PATH
527 
528 // The conversion function is from rocblas
529 // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39
530 // This has been modified to add double types conversion as well
// Encode a half/float/double value into fp8 bits with layout
// (1 sign | `we` exponent | `wm` mantissa).
//  - is_fnuz: FNUZ encoding (NaN is 0x80, no infinities, bias one larger);
//    otherwise OCP/IEEE-like encoding.
//  - clip: saturate out-of-range values to the largest finite fp8 value instead
//    of returning the Inf/NaN pattern.
//  - stoch: stochastic rounding using `rng`; otherwise round-to-nearest-even.
template <typename T, int wm, int we, bool is_fnuz, bool clip = false, bool stoch = false>
__host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rng = 0)
{
    constexpr bool is_half   = __hip_internal::is_same<T, _Float16>::value;
    constexpr bool is_float  = __hip_internal::is_same<T, float>::value;
    constexpr bool is_double = __hip_internal::is_same<T, double>::value;
    static_assert(is_half || is_float || is_double,
                  "Only half, float and double can be cast to f8");

    // Mantissa width of the source type T.
    constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);

    // Unsigned integer type matching T's width, for bitwise access.
    using T_bitwise = typename std::conditional<
        sizeof(T) == 2,
        unsigned short int,
        typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;
    T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);

    unsigned long long x{x_bitwise};

    unsigned long long head, mantissa;
    int exponent, bias;
    unsigned int sign;
    unsigned long long fInf, mask;

    // Per-width field extraction: sign+exponent (head), mantissa, bias,
    // the Inf exponent pattern, and the magnitude mask (everything but sign).
    if constexpr(sizeof(T) == 8)
    {
        head     = x & 0xFFF0000000000000ull;
        mantissa = x & 0xFFFFFFFFFFFFFull;
        exponent = (head >> 52) & 0x7FF;
        sign     = head >> 63;
        bias     = 1023;
        fInf     = 0x7FF0000000000000ull;
        mask     = 0x7FFFFFFFFFFFFFFFull;
    }
    else if constexpr(sizeof(T) == 4)
    {
        head     = x & 0xFF800000;
        mantissa = x & 0x7FFFFF;
        exponent = (head >> 23) & 0xFF;
        sign     = head >> 31;
        bias     = 127;
        fInf     = 0x7F800000;
        mask     = 0x7FFFFFFF;
    }
    else
    {
        head     = x & 0xFC00;
        mantissa = x & 0x3FF;
        exponent = (head >> 10) & 0x1F;
        sign     = head >> 15;
        bias     = 15;
        fInf     = 0x7C00;
        mask     = 0x7FFF;
    }
    // Bit patterns returned for out-of-range / Inf / NaN inputs, per encoding.
    unsigned int signed_inf = 0;
    unsigned int nan        = 0;
    if constexpr(is_fnuz)
    {
        signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
        nan        = 0x80;
    }
    else
    {
        if constexpr(we == 4)
        { // e4m3
            signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
        }
        else
        { // e5m2
            signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
        }
        nan = (sign << 7) + 0x7f;
    }
    // Max values representable in the target fp8 format, expressed as T's bits.
    unsigned long long ifmax = 0;
    if constexpr(sizeof(T) == 8)
    {
        if constexpr(we == 5)
        { // 57344
            ifmax = 0x40EC000000000000ull;
        }
        else
        {
            if constexpr(is_fnuz)
            { // 240
                ifmax = 0x406E000000000000ull;
            }
            else
            { // 448
                ifmax = 0x407C000000000000ull;
            }
        }
    }
    // NOTE(review): plain `else if` here (not `else if constexpr`) — the
    // condition is a compile-time constant either way; kept as in upstream.
    else if(sizeof(T) == 4)
    {
        if constexpr(we == 5)
        {
            ifmax = 0x47600000;
        }
        else
        {
            if constexpr(is_fnuz)
            {
                ifmax = 0x43700000;
            }
            else
            {
                ifmax = 0x43E00000;
            }
        }
    }
    else
    {
        if constexpr(we == 5)
        {
            ifmax = 0x7B00;
        }
        else
        {
            if constexpr(is_fnuz)
            {
                ifmax = 0x5B80;
            }
            else
            {
                ifmax = 0x5F00;
            }
        }
    }
    // Deal with inf and NaNs
    if((x & fInf) == fInf)
    {
        if constexpr(is_fnuz)
            return signed_inf;

        return mantissa != 0 ? nan : signed_inf;
    }

    // Above-range finite input: saturate (clip) or emit the Inf/NaN pattern.
    if((x & mask) > ifmax)
    {
        return signed_inf;
    }

    if(x == 0)
    {
        return 0;
    }

    // First need to check if it is normal or denorm as there is a difference of
    // implicit 1 Then need to adjust the exponent to align with the F8 exponent,
    // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
    // to mantissa and truncate. And for RNE, no need to add rng. Then probably
    // need to check whether there is carry and adjust exponent and mantissa again

    // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
    // bits
    const int f8_bias                  = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
    const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
    // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
    // f8_exponent is the converted f8 exponent with bias encoding
    // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
    // the difference needs to be adjusted and mantissa shifted
    int act_exponent, f8_exponent, exponent_diff;

    if(exponent == 0)
    { // fp32/fp16 is in denormal.
        /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
        mostly concern fp16 here. In this case, f8 is usually in denormal. But there
        could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
        exponent bias 16. It means that there are some numbers in fp16 denormal but they
        are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
        where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
        (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
        act_exponent  = exponent - bias + 1;
        exponent_diff = f8_denormal_act_exponent -
                        act_exponent; // actual exponent is exponent-bias+1 as it is denormal
    }
    else
    { // fp32/fp16 is normal with implicit 1
        act_exponent = exponent - bias;
        if(act_exponent <= f8_denormal_act_exponent)
        {
            /* This is the case where fp32/fp16 is normal but it is in f8 denormal
            range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
            actual exponent is -7, it is actually larger due to the implicit 1,
            Therefore it needs to be adjust to -6 and mantissa shift right by 1.
            So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
            exponent_diff = f8_denormal_act_exponent - act_exponent;
        }
        else
        { // both fp32/fp16 and f8 are in normal range
            exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
                               // for this case, act_exponent could be larger. Just
                               // that it does not need shift mantissa
        }
        mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa
    }

    bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
                    (1ull << (mfmt - wm + exponent_diff - 1));
    /* This part is a bit tricky. The judgment of whether it is a tie needs to be
    done before we shift right as shift right could rip off some residual part and
    make something not midpoint look like midpoint. For example, the fp16 number
    0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
    by 4 bits, it would look like midpoint.
    */

    // NOTE(review): only the exponent_diff == -1 case shifts left, matching the
    // upstream rocBLAS implementation this was taken from.
    if(exponent_diff > 0)
        mantissa >>= exponent_diff;
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;
    bool implicit_one = mantissa & (1ull << mfmt);
    // if there is no implicit 1, it means the f8 is denormal and need to adjust
    // to denorm exponent
    f8_exponent =
        (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);

    // Now we have the exponent and mantissa adjusted
    unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
    bool odd =
        mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1
    // RNE: at a midpoint, round so the kept LSB becomes even; SR: add rng bits.
    mantissa +=
        (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;

    // Now we deal with overflow
    if(f8_exponent == 0)
    {
        if((1ull << mfmt) & mantissa)
        {
            f8_exponent = 1; // denormal overflow to become normal, promote exponent
        }
    }
    else
    {
        if((1ull << (mfmt + 1)) & mantissa)
        {
            mantissa >>= 1;
            f8_exponent++;
        }
    }

    mantissa >>= (mfmt - wm);

    // above range: quantize to maximum possible float of the same sign
    const int max_exp = (1 << we) - 1;
    if(f8_exponent > max_exp)
    {
        if constexpr(clip)
        {
            mantissa    = (1 << wm) - 1;
            f8_exponent = max_exp;
        }
        else
        {
            return signed_inf;
        }
    }

    // Zero result: FNUZ has only +0; OCP preserves the sign of zero.
    if(f8_exponent == 0 && mantissa == 0)
        return is_fnuz ? 0 : (sign << 7);
    mantissa &= (1 << wm) - 1;
    return (sign << 7) | (f8_exponent << wm) | mantissa;
}
794 
/// Convert a float to fp8 storage bits according to `interp`, with optional
/// saturation (`sat`) and stochastic rounding. Uses the hardware fast path on
/// supported targets, otherwise the portable software cast.
/// NOTE: the `ck_saturation_t sat = CK_SATFINITE` template parameter line and
/// the `sat == ck_saturation_t::CK_SATFINITE` argument lines were lost in
/// extraction and are restored here.
template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat      = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
{
    __is_interpret_supported(interp);
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        // Deterministic per-value PRNG seeded from the input's address and value.
        constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif
    }
    return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
        f, rng);
#else
#if CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
{
#else
__host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
{
#endif
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif
    }

    // Dispatch to the software cast with (wm, we, is_fnuz) fixed per interpretation.
    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
    {
        return cast_to_f8<float,
                          3,
                          4,
                          true,
                          sat == ck_saturation_t::CK_SATFINITE,
                          stochastic_rounding>(f, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_FNUZ)
    {
        return cast_to_f8<float,
                          2,
                          5,
                          true,
                          sat == ck_saturation_t::CK_SATFINITE,
                          stochastic_rounding>(f, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return cast_to_f8<float,
                          3,
                          4,
                          false,
                          sat == ck_saturation_t::CK_SATFINITE,
                          stochastic_rounding>(f, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return cast_to_f8<float,
                          2,
                          5,
                          false,
                          sat == ck_saturation_t::CK_SATFINITE,
                          stochastic_rounding>(f, rng);
    }
    else
    {
        __hip_assert(false && "FP8 type is not supported by current target device");
        return 0;
    }
#endif // CK_FP8_CVT_FAST_PATH
}
884 
/// Convert a _Float16 to fp8 storage bits by promoting to float and reusing
/// cvt_float_to_fp8 with the same interpretation/saturation/rounding settings.
/// NOTE: the `ck_saturation_t sat = CK_SATFINITE` template parameter line was
/// lost in extraction and is restored here (the body already references `sat`).
template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat      = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
#else
__host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
#endif
{
    return cvt_float_to_fp8<interp, sat, stochastic_rounding>(static_cast<float>(x));
}
905 
906 } // namespace fp8_impl
907 
908 // Declare a template function for fp8 conversion using RNE
// Convert to an fp8 type with round-to-nearest-even. Only the (destination,
// source) pairs specialized below are supported; the primary template has no
// definition.
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_rne(X x);

// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
{
    return f8_ocp_t{
        fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
}

// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, float>(float x)
{
    return bf8_ocp_t{
        fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(x)};
}

// convert _Float16 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, _Float16>(_Float16 x)
{
    return f8_ocp_t{
        fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
}

// convert _Float16 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, _Float16>(_Float16 x)
{
    return bf8_ocp_t{
        fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
            x)};
}
943 
// Declare a template function for fp8 conversion using stochastic rounding
// (the original comment said "RNE"; the specializations below pass
// stochastic_rounding=true). Only specialized pairs are supported.
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);

// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
{
    return f8_ocp_t{
        fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>(
            x)};
}
956 
// convert fp32 to bf8 with stochastic rounding
// NOTE: the `bf8_ocp_t::default_saturation,` argument line was lost in
// extraction and is restored here, matching the f8_ocp_t specialization above.
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, float>(float x)
{
    return bf8_ocp_t{fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret,
                                                bf8_ocp_t::default_saturation,
                                                true>(x)};
}
965 
// convert _Float16 to fp8 with stochastic rounding
// NOTE: the `f8_ocp_t::default_saturation,` argument line was lost in
// extraction and is restored here, matching the RNE specializations.
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, _Float16>(_Float16 x)
{
    return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret,
                                                f8_ocp_t::default_saturation,
                                                true>(x)};
}
974 
// convert _Float16 to bf8 with stochastic rounding
// NOTE: the `bf8_ocp_t::default_saturation,` argument line was lost in
// extraction and is restored here, matching the RNE specializations.
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, _Float16>(_Float16 x)
{
    return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret,
                                                 bf8_ocp_t::default_saturation,
                                                 true>(x)};
}
983 
// Build-wide default fp8/bf8 types: OCP wrappers when CK_USE_OCP_FP8 is set,
// raw FNUZ _BitInt types otherwise.
// NOTE: `using bf8_t = bf8_fnuz_t;` (hpp:991) was lost in extraction and is
// restored here per the file's cross-reference index ("bf8_fnuz_t bf8_t
// Definition: amd_ck_fp8.hpp:991").
#if CK_USE_OCP_FP8
using f8_t  = f8_ocp_t;
using bf8_t = bf8_ocp_t;
#define CK_FP8_TYPE_FNUZ 0
#define CK_FP8_TYPE_OCP 1
#else
using f8_t  = f8_fnuz_t;
using bf8_t = bf8_fnuz_t;
#define CK_FP8_TYPE_FNUZ 1
#define CK_FP8_TYPE_OCP 0
#endif
995 
996 } // namespace ck
#define __assert_ocp_support(interp)
Definition: amd_ck_fp8.hpp:431
#define __assert_fnuz_support(interp)
Definition: amd_ck_fp8.hpp:439
float float2_t
Definition: amd_ck_fp8.hpp:67
fp8_storage_t fp8x2_storage_t
Definition: amd_ck_fp8.hpp:66
Definition: ck.hpp:264
__host__ constexpr __device__ Y bit_cast(const X &x)
Definition: type.hpp:309
__host__ __device__ bf8_ocp_t f8_convert_sr< bf8_ocp_t, _Float16 >(_Float16 x)
Definition: amd_ck_fp8.hpp:977
__host__ __device__ bf8_ocp_t f8_convert_rne< bf8_ocp_t, _Float16 >(_Float16 x)
Definition: amd_ck_fp8.hpp:937
bf8_fnuz_t bf8_t
Definition: amd_ck_fp8.hpp:991
__host__ __device__ f8_ocp_t f8_convert_rne< f8_ocp_t, float >(float x)
Definition: amd_ck_fp8.hpp:914
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:990
ck_fp8_interpretation_t
Describes FP8 interpretation.
Definition: amd_ck_fp8.hpp:48
__host__ __device__ bf8_ocp_t f8_convert_sr< bf8_ocp_t, float >(float x)
Definition: amd_ck_fp8.hpp:959
__host__ constexpr __device__ Y f8_convert_rne(X x)
__host__ __device__ f8_ocp_t f8_convert_rne< f8_ocp_t, _Float16 >(_Float16 x)
Definition: amd_ck_fp8.hpp:930
unsigned _BitInt(8) bf8_fnuz_t
Definition: amd_ck_fp8.hpp:40
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:13
__host__ __device__ f8_ocp_t f8_convert_sr< f8_ocp_t, _Float16 >(_Float16 x)
Definition: amd_ck_fp8.hpp:968
_BitInt(8) f8_fnuz_t
Definition: amd_ck_fp8.hpp:39
__host__ __device__ bf8_ocp_t f8_convert_rne< bf8_ocp_t, float >(float x)
Definition: amd_ck_fp8.hpp:922
ck_saturation_t
Describes saturation behavior.
Definition: amd_ck_fp8.hpp:59
__host__ __device__ f8_ocp_t f8_convert_sr< f8_ocp_t, float >(float x)
Definition: amd_ck_fp8.hpp:950
__host__ constexpr __device__ Y f8_convert_sr(X x)
unsigned char fp8_storage_t
Definition: amd_ck_fp8.hpp:42
Definition: amd_ck_fp8.hpp:344
__host__ constexpr __device__ bool operator==(const bf8_ocp_t &other) const
Definition: amd_ck_fp8.hpp:355
fp8_storage_t data_type
Definition: amd_ck_fp8.hpp:345
data_type data
Definition: amd_ck_fp8.hpp:346
static constexpr ck_fp8_interpretation_t default_interpret
Definition: amd_ck_fp8.hpp:349
static constexpr ck_saturation_t default_saturation
Definition: amd_ck_fp8.hpp:348
Definition: amd_ck_fp8.hpp:298
fp8_storage_t data_type
Definition: amd_ck_fp8.hpp:299
data_type data
Definition: amd_ck_fp8.hpp:300
__host__ constexpr __device__ bool operator==(const f8_ocp_t &other) const
Definition: amd_ck_fp8.hpp:309
static constexpr ck_fp8_interpretation_t default_interpret
Definition: amd_ck_fp8.hpp:303
static constexpr ck_saturation_t default_saturation
Definition: amd_ck_fp8.hpp:302