amd_ck_fp8.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/get_id.hpp"
#include "ck/utility/ignore.hpp"     // ck::ignore
#include "ck/utility/random_gen.hpp" // prand_generator
#include "ck/utility/type.hpp"

#ifndef CK_USE_FNUZ_FP8
#define CK_USE_FNUZ_FP8 0
#endif

#ifndef CK_USE_OCP_FP8
#define CK_USE_OCP_FP8 0
#endif

#if(defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
#define CK_FP8_CVT_FAST_PATH 1
#else
#define CK_FP8_CVT_FAST_PATH 0
#endif

#if(defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
#define CK_OCP_FP8_CVT_FAST_PATH 1
#else
#define CK_OCP_FP8_CVT_FAST_PATH 0
#endif

namespace ck {

struct f8_fnuz_t
{
    using data_type = unsigned char;
    data_type m_data;
    __host__ __device__ explicit constexpr f8_fnuz_t(data_type in_data) : m_data(in_data) {}
    __host__ __device__ explicit constexpr f8_fnuz_t() = default;
    __host__ __device__ bool constexpr operator==(f8_fnuz_t other) const
    {
        return m_data == other.m_data;
    }
    __host__ __device__ explicit constexpr operator data_type() const { return m_data; }
};

struct bf8_fnuz_t
{
    using data_type = unsigned char;
    data_type m_data;
    __host__ __device__ explicit constexpr bf8_fnuz_t(data_type in_data) : m_data(in_data) {}
    __host__ __device__ explicit constexpr bf8_fnuz_t() = default;
    __host__ __device__ bool constexpr operator==(bf8_fnuz_t other) const
    {
        return m_data == other.m_data;
    }
    __host__ __device__ explicit constexpr operator data_type() const { return m_data; }
};

static_assert(1 == sizeof(f8_fnuz_t));
static_assert(1 == sizeof(bf8_fnuz_t));

typedef unsigned char fp8_storage_t;

/// Describes FP8 interpretation.
enum class ck_fp8_interpretation_t
{
    CK_E4M3_OCP  = 0, // OCP E4M3
    CK_E5M2_OCP  = 1, // OCP E5M2
    CK_E4M3_FNUZ = 2, // FP8
    CK_E5M2_FNUZ = 3, // BF8
};

/// Describes saturation behavior.
enum class ck_saturation_t
{
    CK_NOSAT     = 0, // No saturation - replace with NaN or Inf
    CK_SATFINITE = 1, // Saturate to finite
};
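
// Summary of the four interpretations (an illustrative note, derived from the
// constants used in the conversion routines below):
// - CK_E4M3_OCP : 1s/4e/3m, finite max +-448,   no Inf, NaN = 0x7f / 0xff
// - CK_E5M2_OCP : 1s/5e/2m, finite max +-57344, Inf = 0x7c / 0xfc
// - CK_E4M3_FNUZ: 1s/4e/3m, finite max +-240,   no Inf, single NaN at 0x80
// - CK_E5M2_FNUZ: 1s/5e/2m, finite max +-57344, no Inf, single NaN at 0x80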

namespace fp8_impl {

typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2)));
typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
typedef ushort ushortx2_t __attribute__((ext_vector_type(2)));
typedef short shortx2_t __attribute__((ext_vector_type(2)));
typedef float float2_t __attribute__((ext_vector_type(2)));

__host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a)
{
    return static_cast<unsigned char>(a) == 0x80;
}
__host__ __device__ static inline constexpr bool fnuz_bf8_is_nan(bf8_fnuz_t a)
{
    return static_cast<unsigned char>(a) == 0x80;
}

__host__ __device__ static inline constexpr bool ocp_f8_is_nan(fp8_storage_t a)
{
    return (a & 0x7f) == 0x7f;
}
__host__ __device__ static inline constexpr bool ocp_bf8_is_nan(fp8_storage_t a)
{
    return (a & 0x7f) > 0x7c;
}
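
// Illustrative bit patterns for the predicates above:
// - e4m3 (OCP) has no Inf encoding; only 0x7f / 0xff are NaN.
// - e5m2 (OCP) NaNs carry an all-ones exponent and a nonzero mantissa
//   (0x7d, 0x7e, 0x7f and their negative counterparts), while 0x7c / 0xfc
//   encode +/-Inf and are recognized by fp8_is_inf further below.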

// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
// This has been modified to handle double types as well
template <typename T, int wm, int we, bool is_fnuz, bool clip = false>
__host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
{
    constexpr bool is_half   = __hip_internal::is_same<T, _Float16>::value;
    constexpr bool is_float  = __hip_internal::is_same<T, float>::value;
    constexpr bool is_double = __hip_internal::is_same<T, double>::value;
    static_assert(is_half || is_float || is_double, "only half, float and double are supported");

    constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
    constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);

    T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
    if constexpr(is_half)
    {
        const unsigned short int ihInf    = 0x7C00;
        const unsigned short int ihNegInf = 0xFC00;
        const unsigned short int ihNaN    = 0x7C01;
        const unsigned short int ihNeg0   = 0x8000;
        /* Max number in e5m2: 57344 */
        const unsigned short int ifmax = 0x7B00;
        const unsigned short int ifmin = 0xFB00;

        fInf    = bit_cast<_Float16>(ihInf);
        fNegInf = bit_cast<_Float16>(ihNegInf);
        fNaN    = bit_cast<_Float16>(ihNaN);
        fNeg0   = bit_cast<_Float16>(ihNeg0);
        fmax    = bit_cast<_Float16>(ifmax);
        fmin    = bit_cast<_Float16>(ifmin);
    }
    else if constexpr(is_float)
    {
        const unsigned int ifInf    = 0x7F800000;
        const unsigned int ifNegInf = 0xFF800000;
        const unsigned int ifNaN    = 0x7F800001;
        const unsigned int ifNeg0   = 0x80000000;
        /* Max number in e5m2: 57344 */
        const unsigned int ifmax = 0x47600000;
        const unsigned int ifmin = 0xC7600000;

        fInf    = bit_cast<float>(ifInf);
        fNegInf = bit_cast<float>(ifNegInf);
        fNaN    = bit_cast<float>(ifNaN);
        fNeg0   = bit_cast<float>(ifNeg0);
        fmax    = bit_cast<float>(ifmax);
        fmin    = bit_cast<float>(ifmin);
    }
    else if constexpr(is_double)
    {
        const unsigned long long ifInf    = 0x7FF0000000000000ull;
        const unsigned long long ifNegInf = 0xFFF0000000000000ull;
        const unsigned long long ifNaN    = 0x7FF0000000000001ull;
        const unsigned long long ifNeg0   = 0x8000000000000000ull;
        /* Max number in e5m2: 57344 */
        const unsigned long long ifmax = 0x40EC000000000000ull;
        const unsigned long long ifmin = 0xC0EC000000000000ull;

        fInf    = bit_cast<double>(ifInf);
        fNegInf = bit_cast<double>(ifNegInf);
        fNaN    = bit_cast<double>(ifNaN);
        fNeg0   = bit_cast<double>(ifNeg0);
        fmax    = bit_cast<double>(ifmax);
        fmin    = bit_cast<double>(ifmin);
    }

    if(x == 0)
    {
        return 0;
    }

    unsigned long long sign     = x >> 7;
    unsigned long long mantissa = x & ((1 << wm) - 1);
    int exponent                = (x & 0x7F) >> wm;
    if constexpr(is_fnuz)
    {
        if(x == 0x80)
        {
            return fNaN;
        }
    }
    else
    {
        if(x == 0x80)
        {
            return fNeg0;
        }
        if constexpr(we == 4)
        { // e4m3
            if((x & 0x7F) == 0x7F)
            {
                return fNaN;
            }
        }
        else if((x & 0x7C) == 0x7C)
        { // e5m2
            if((x & 0x3) == 0)
            {
                if constexpr(clip)
                {
                    return sign ? fmin : fmax;
                }
                return sign ? fNegInf : fInf;
            }
            return fNaN;
        }
    }

    typename ck::conditional_t<
        sizeof(T) == 2,
        unsigned short int,
        typename ck::conditional_t<sizeof(T) == 4, unsigned int, unsigned long long>>
        retval;

    if constexpr(we == 5 && is_half && !is_fnuz)
    {
        retval = x << 8;
        return bit_cast<T>(retval);
    }

    const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);

    // subnormal input
    if(exponent == 0)
    {
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
        int sh = 1 + __clz(mantissa) - (32 - wm);
#else
        int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
#endif
        mantissa <<= sh;
        exponent += 1 - sh;
        mantissa &= ((1ull << wm) - 1);
    }
    exponent += exp_low_cutoff - 1;
    mantissa <<= wmo - wm;

    // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
    if(exponent <= 0)
    {
        mantissa |= 1 << wmo;
        mantissa >>= 1 - exponent;
        exponent = 0;
    }

    if constexpr(sizeof(T) == 2)
        retval = (sign << 15) | (exponent << 10) | mantissa;
    else if constexpr(sizeof(T) == 4)
        retval = (sign << 31) | (exponent << 23) | mantissa;
    else
        retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;

    return bit_cast<T>(retval);
}
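
// Usage sketch for cast_from_f8 (illustrative): decoding the OCP e4m3 byte
// 0x40, i.e. sign=0, exponent=0b1000, mantissa=0b000, gives 2^(8-7) * 1.0 = 2.0:
//
//   float two = cast_from_f8<float, /*wm=*/3, /*we=*/4, /*is_fnuz=*/false>(0x40); // == 2.0f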

#if CK_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interpret>
static __host__ __device__ float cast_to_f32_from_f8(fp8_storage_t v)
{
    union
    {
        unsigned int i32val;
        unsigned char i8val[4];
    } val;
    val.i8val[0] = v;

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only FNUZ and OCP interpretations are supported");

    if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                 (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
    {
        return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
    }
}

template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32_from_f8(fp8x2_storage_t v)
{
    const auto i16val = bit_cast<uint16_t>(v);

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only FNUZ and OCP interpretations are supported");

    if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                 (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
    {
        return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false);
    }
    else
    {
        return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
    }
}
#endif

} // namespace fp8_impl

struct f8_ocp_t
{
    using data_type = fp8_storage_t;
    data_type data;

    static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
    static constexpr ck_fp8_interpretation_t default_interpret =
        ck_fp8_interpretation_t::CK_E4M3_OCP;

    static constexpr unsigned int we = 4; // exponent width
    static constexpr unsigned int wm = 3; // mantissa width

    __host__ __device__ constexpr bool operator==(const f8_ocp_t& other) const
    {
        return (data == other.data) && (fp8_impl::ocp_f8_is_nan(data) == false); // NaN != NaN
    }

#if CK_USE_OCP_FP8
    __host__ __device__ explicit operator float() const
#else
    __host__ explicit operator float() const
#endif
    {
#if CK_OCP_FP8_CVT_FAST_PATH
        return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
        return fp8_impl::cast_from_f8<float, wm, we, false>(
            this->data); // XXX: clip==false must be consistent with operator _Float16
#endif
    }

#if CK_USE_OCP_FP8
    __host__ __device__ explicit operator _Float16() const
#else
    __host__ explicit operator _Float16() const
#endif
    {
#if CK_OCP_FP8_CVT_FAST_PATH
        return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
        return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
            this->data); // XXX: clip==false must be consistent with operator float
#endif
    }
};

struct bf8_ocp_t
{
    using data_type = fp8_storage_t;
    data_type data;

    static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
    static constexpr ck_fp8_interpretation_t default_interpret =
        ck_fp8_interpretation_t::CK_E5M2_OCP;

    static constexpr unsigned int we = 5; // exponent width
    static constexpr unsigned int wm = 2; // mantissa width

    __host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const
    {
        return (data == other.data) && (fp8_impl::ocp_bf8_is_nan(data) == false); // NaN != NaN
    }

#if CK_USE_OCP_FP8
    __host__ __device__ explicit operator float() const
#else
    __host__ explicit operator float() const
#endif
    {
#if defined(__gfx950__) || defined(__gfx12__)
        return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
        return fp8_impl::cast_from_f8<float, wm, we, false>(
            this->data); // XXX: clip==false must be consistent with operator _Float16
#endif
    }

#if CK_USE_OCP_FP8
    __host__ __device__ explicit operator _Float16() const
#else
    __host__ explicit operator _Float16() const
#endif
    {
#if defined(__gfx950__) || defined(__gfx12__)
        return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
        return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
            this->data); // XXX: clip==false must be consistent with operator float
#endif
    }
};
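
// Usage sketch (illustrative): both OCP wrappers are aggregates, so a raw bit
// pattern can be brace-initialized and widened through the explicit conversion
// operators declared above.
//
//   f8_ocp_t x{0x40};                    // e4m3 bit pattern for 2.0
//   float f    = static_cast<float>(x);  // 2.0f
//   _Float16 h = static_cast<_Float16>(x);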

template <typename T>
__host__ __device__ static inline constexpr bool fp8_is_nan(T);

template <>
__host__ __device__ inline constexpr bool fp8_is_nan(f8_ocp_t a)
{
    return fp8_impl::ocp_f8_is_nan(a.data);
}
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(bf8_ocp_t a)
{
    return fp8_impl::ocp_bf8_is_nan(a.data);
}
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(f8_fnuz_t a)
{
    return fp8_impl::fnuz_f8_is_nan(a);
}
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a)
{
    return fp8_impl::fnuz_bf8_is_nan(a);
}

template <typename T,
          ck::enable_if_t<is_same_v<T, bf8_ocp_t> || is_same_v<T, f8_ocp_t> ||
                              is_same_v<T, bf8_fnuz_t> || is_same_v<T, f8_fnuz_t>,
                          bool> = true>
__host__ __device__ static inline constexpr bool fp8_is_inf(T)
{
    return false;
}
template <>
__host__ __device__ inline constexpr bool fp8_is_inf(bf8_ocp_t a)
{
    return (a.data & 0x7f) == 0x7c;
}

namespace fp8_impl {

// Assertions to check for supported conversion types
#define __fp8_impl_assert_ocp_support(interp)                                      \
    {                                                                              \
        if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP &&                       \
           interp != ck_fp8_interpretation_t::CK_E5M2_OCP)                         \
        {                                                                          \
            __hip_assert(false && "type is unsupported by current target device"); \
        }                                                                          \
    }
#define __fp8_impl_assert_fnuz_support(interp)                                     \
    {                                                                              \
        if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ &&                      \
           interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ)                        \
        {                                                                          \
            __hip_assert(false && "type is unsupported by current target device"); \
        }                                                                          \
    }

__host__ __device__ static inline void
__is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp)
{
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
#if CK_USE_OCP_FP8
    __fp8_impl_assert_ocp_support(interp);
#endif
#if CK_USE_FNUZ_FP8
    __fp8_impl_assert_fnuz_support(interp);
#endif
#endif
}

#if defined(__gfx950__)
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) &&
                              stochastic_rounding,
                          bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
{
    union
    {
        unsigned int i32val;
        half2_t half_vec;
        fp8_storage_t i8val[4];
    } val;

    constexpr unsigned int i32val = 0;
    val.half_vec[0]               = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
        }
    }

    val.i32val =
        __builtin_amdgcn_cvt_scalef32_sr_fp8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0);

    return val.i8val[0];
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) &&
                              stochastic_rounding,
                          bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
{
    // there is no packed conversion with SR, so convert one element at a time
    return fp8x2_storage_t{
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E5M2_OCP) &&
                              stochastic_rounding,
                          bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
{
    union
    {
        unsigned int i32val;
        half2_t half_vec;
        fp8_storage_t i8val[4];
    } val;

    constexpr unsigned int i32val = 0;
    val.half_vec[0]               = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
        }
    }

    val.i32val =
        __builtin_amdgcn_cvt_scalef32_sr_bf8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0);

    return val.i8val[0];
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E5M2_OCP) &&
                              stochastic_rounding,
                          bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
{
    // there is no packed conversion with SR, so convert one element at a time
    return fp8x2_storage_t{
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) &&
                              !stochastic_rounding,
                          bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
{
    ignore = rng;

    union
    {
        unsigned int i32val;
        half2_t half_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;

    constexpr shortx2_t i16x2val = {0, 0};
    val.half_vec[0]              = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
        }
    }

    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);

    return val.i8val[0];
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) &&
                              !stochastic_rounding,
                          bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
{
#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
    return fp8x2_storage_t{
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
#else
    ignore = rng;

    union
    {
        half2_t half_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;

    constexpr shortx2_t i16x2val = {0, 0};
    val.half_vec                 = v;

    if constexpr(saturate)
    {
        if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
        }
        if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 448.0, -448.0);
        }
    }

    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);

    return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
#endif
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E5M2_OCP) &&
                              !stochastic_rounding,
                          bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
{
    ignore = rng;

    union
    {
        unsigned int i32val;
        half2_t half_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;

    constexpr shortx2_t i16x2val = {0, 0};
    val.half_vec[0]              = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
        }
    }

    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);

    return val.i8val[0];
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E5M2_OCP) &&
                              !stochastic_rounding,
                          bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
{
#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
    return fp8x2_storage_t{
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
#else
    ignore = rng;

    union
    {
        half2_t half_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;

    constexpr shortx2_t i16x2val = {0, 0};
    val.half_vec                 = v;

    if constexpr(saturate)
    {
        if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
        }
        if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 57344.0, -57344.0);
        }
    }

    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);

    return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
#endif
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) &&
                              stochastic_rounding,
                          bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
{
    union
    {
        unsigned int i32val;
        ushortx2_t bhalf_vec;
        fp8_storage_t i8val[4];
    } val;

    constexpr unsigned int i32val = 0;
    val.bhalf_vec[0]              = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] =
                ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                            bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
                        16)); // convert to float and back
        }
    }

    val.i32val = __builtin_amdgcn_cvt_scalef32_sr_fp8_bf16(
        i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0);

    return val.i8val[0];
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) &&
                              stochastic_rounding,
                          bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
{
    // there is no packed conversion with SR, so convert one element at a time
    return fp8x2_storage_t{
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E5M2_OCP) &&
                              stochastic_rounding,
                          bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
{
    union
    {
        unsigned int i32val;
        ushortx2_t bhalf_vec;
        fp8_storage_t i8val[4];
    } val;

    constexpr unsigned int i32val = 0;
    val.bhalf_vec[0]              = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] = ushort(
                (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                     bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
                 16)); // convert to float and back
        }
    }

    val.i32val = __builtin_amdgcn_cvt_scalef32_sr_bf8_bf16(
        i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0);

    return val.i8val[0];
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E5M2_OCP) &&
                              stochastic_rounding,
                          bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
{
    // there is no packed conversion with SR, so convert one element at a time
    return fp8x2_storage_t{
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) &&
                              !stochastic_rounding,
                          bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
{
    ignore = rng;

    union
    {
        unsigned int i32val;
        ushortx2_t bhalf_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;

    constexpr shortx2_t i16x2val = {0, 0};
    val.bhalf_vec[0]             = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] =
                ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                            bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
                        16)); // convert to float and back
        }
    }

    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);

    return val.i8val[0];
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) &&
                              !stochastic_rounding,
                          bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
{
#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
    return fp8x2_storage_t{
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
#else
    ignore = rng;

    union
    {
        ushortx2_t bhalf_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;

    constexpr shortx2_t i16x2val = {0, 0};
    val.bhalf_vec                = v;

    if constexpr(saturate)
    {
        if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] =
                ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                            bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
                        16)); // convert to float and back
        }
        if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[1] =
                ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                            bit_cast<float>(uint32_t{val.bhalf_vec[1]} << 16), 448.0, -448.0)) >>
                        16)); // convert to float and back
        }
    }

    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);

    return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
#endif
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E5M2_OCP) &&
                              !stochastic_rounding,
                          bool> = true>
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
{
    ignore = rng;

    union
    {
        unsigned int i32val;
        ushortx2_t bhalf_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;

    constexpr shortx2_t i16x2val = {0, 0};
    val.bhalf_vec[0]             = v;

    if constexpr(saturate)
    {
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] = ushort(
                (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                     bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
                 16)); // convert to float and back
        }
    }

    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);

    return val.i8val[0];
}

template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<(interpret == ck_fp8_interpretation_t::CK_E5M2_OCP) &&
                              !stochastic_rounding,
                          bool> = true>
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
{
    ignore = rng;

    union
    {
        ushortx2_t bhalf_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;

    constexpr shortx2_t i16x2val = {0, 0};
    val.bhalf_vec                = v;

    if constexpr(saturate)
    {
        if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] = ushort(
                (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                     bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
                 16)); // convert to float and back
        }
        if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[1] = ushort(
                (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                     bit_cast<float>(uint32_t{val.bhalf_vec[1]} << 16), 57344.0, -57344.0)) >>
                 16)); // convert to float and back
        }
    }

    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);

    return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
}
#endif // defined(__gfx950__)
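
// Usage sketch for the gfx950 intrinsic wrappers above (illustrative; only
// compiled for __gfx950__): round-to-nearest-even conversion of a half2_t pair
// into two packed OCP e4m3 bytes, with saturation to +-448.
//
//   half2_t h{_Float16(1.0f), _Float16(-3.5f)};
//   fp8x2_storage_t r = cast_to_f8_from_f16<ck_fp8_interpretation_t::CK_E4M3_OCP,
//                                           /*saturate=*/true,
//                                           /*stochastic_rounding=*/false>(h);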

#if CK_FP8_CVT_FAST_PATH
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
{
    fp8_storage_t i8data;
    union
    {
        float fval;
        unsigned int i32val;
        unsigned char i8val[4]; // NOTE: not endian independent
    } val;

    unsigned int ival = 0;
    val.fval          = v;

    if constexpr(saturate)
    {
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
        {
            if((val.i32val & 0x7F800000) != 0x7F800000)
            {
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
            }
        }
        else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        { // OCP type
            if((val.i32val & 0x7F800000) != 0x7F800000)
            {
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
            }
        }
        else
        {
            if((val.i32val & 0x7F800000) != 0x7F800000)
            {
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
            }
        }
    }

    if constexpr(stochastic_rounding)
    {
        ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                       (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
                   ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
                   : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
        val.i32val = ival;
        i8data     = val.i8val[0]; // little endian
    }
    else
    { // RNE CVT
        ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                       (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
                   ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
                   : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
                                                     val.fval,
                                                     ival,
                                                     false); // false -> WORD0
        val.i32val = ival;
        i8data     = val.i8val[0];
    }
    return i8data;
}
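
// The fmed3 clamps above implement CK_SATFINITE: 240.0 is the largest finite
// e4m3 FNUZ value, 448.0 the largest finite e4m3 OCP value, and 57344.0 the
// largest finite e5m2 value. Inputs whose exponent field is all ones (Inf/NaN)
// bypass the clamp and propagate to the conversion instruction unchanged.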

template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
static __device__ fp8x2_storage_t cast_to_f8_from_f32(float2_t v, unsigned int rng = 0)
{
    if constexpr(stochastic_rounding)
    {
        // there is no packed conversion with SR, so convert one element at a time
        return fp8x2_storage_t{
            cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[0], rng),
            cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[1], rng)};
    }
    else
    {
        union
        {
            float fval;
            unsigned int i32val;
            unsigned char i8val[4];
        } val0, val1;

        val0.fval = v[0];
        val1.fval = v[1];

        unsigned int ival = 0;

        if constexpr(saturate)
        {
            if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
            {
                if((val0.i32val & 0x7F800000) != 0x7F800000)
                {
                    val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 240.0, -240.0);
                }
                if((val1.i32val & 0x7F800000) != 0x7F800000)
                {
                    val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 240.0, -240.0);
                }
            }
            else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
            { // OCP type
                if((val0.i32val & 0x7F800000) != 0x7F800000)
                {
                    val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 448.0, -448.0);
                }
                if((val1.i32val & 0x7F800000) != 0x7F800000)
                {
                    val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 448.0, -448.0);
                }
            }
            else
            {
                if((val0.i32val & 0x7F800000) != 0x7F800000)
                {
                    val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 57344.0, -57344.0);
                }
                if((val1.i32val & 0x7F800000) != 0x7F800000)
                {
                    val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 57344.0, -57344.0);
                }
            }
        }

        // RNE CVT
        if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                     (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
        {
            ival = __builtin_amdgcn_cvt_pk_fp8_f32(val0.fval, val1.fval, ival, false);
        }
        else
        {
            ival = __builtin_amdgcn_cvt_pk_bf8_f32(val0.fval, val1.fval, ival, false);
        }

        val0.i32val = ival;

        return fp8x2_storage_t{val0.i8val[0], val0.i8val[1]};
    }
}
#endif // CK_FP8_CVT_FAST_PATH

// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39
// This has been modified to add double type conversion as well
template <typename T, int wm, int we, bool is_fnuz, bool clip = false, bool stoch = false>
__host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rng = 0)
{
    constexpr bool is_half   = __hip_internal::is_same<T, _Float16>::value;
    constexpr bool is_float  = __hip_internal::is_same<T, float>::value;
    constexpr bool is_double = __hip_internal::is_same<T, double>::value;
    static_assert(is_half || is_float || is_double,
                  "Only half, float and double can be cast to f8");

    constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);

    using T_bitwise = typename ck::conditional_t<
        sizeof(T) == 2,
        unsigned short int,
        typename ck::conditional_t<sizeof(T) == 4, unsigned int, unsigned long long>>;
    T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);

    unsigned long long x{x_bitwise};

    unsigned long long head, mantissa;
    int exponent, bias;
    unsigned int sign;
    unsigned long long fInf, mask;

    if constexpr(sizeof(T) == 8)
    {
        head     = x & 0xFFF0000000000000ull;
        mantissa = x & 0xFFFFFFFFFFFFFull;
        exponent = (head >> 52) & 0x7FF;
        sign     = head >> 63;
        bias     = 1023;
        fInf     = 0x7FF0000000000000ull;
        mask     = 0x7FFFFFFFFFFFFFFFull;
    }
    else if constexpr(sizeof(T) == 4)
    {
        head     = x & 0xFF800000;
        mantissa = x & 0x7FFFFF;
        exponent = (head >> 23) & 0xFF;
        sign     = head >> 31;
        bias     = 127;
        fInf     = 0x7F800000;
        mask     = 0x7FFFFFFF;
    }
    else
    {
        head     = x & 0xFC00;
        mantissa = x & 0x3FF;
        exponent = (head >> 10) & 0x1F;
        sign     = head >> 15;
        bias     = 15;
        fInf     = 0x7C00;
        mask     = 0x7FFF;
    }
    unsigned int signed_inf = 0;
    unsigned int nan        = 0;
    if constexpr(is_fnuz)
    {
        signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
        nan        = 0x80;
    }
    else
    {
        if constexpr(we == 4)
        { // e4m3
            signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
        }
        else
        { // e5m2
            signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
        }
        nan = (sign << 7) + 0x7f;
    }
    // Max values
    unsigned long long ifmax = 0;
    if constexpr(sizeof(T) == 8)
    {
        if constexpr(we == 5)
        { // 57344
            ifmax = 0x40EC000000000000ull;
        }
        else
        {
            if constexpr(is_fnuz)
            { // 240
                ifmax = 0x406E000000000000ull;
            }
            else
            { // 448
                ifmax = 0x407C000000000000ull;
            }
        }
    }
    else if constexpr(sizeof(T) == 4)
    {
        if constexpr(we == 5)
        {
            ifmax = 0x47600000;
        }
        else
        {
            if constexpr(is_fnuz)
            {
                ifmax = 0x43700000;
            }
            else
            {
                ifmax = 0x43E00000;
            }
        }
    }
    else
    {
        if constexpr(we == 5)
        {
            ifmax = 0x7B00;
        }
        else
        {
            if constexpr(is_fnuz)
            {
                ifmax = 0x5B80;
            }
            else
            {
                ifmax = 0x5F00;
            }
        }
    }
    // Deal with inf and NaNs
    if((x & fInf) == fInf)
    {
        if constexpr(is_fnuz)
            return signed_inf;

        return mantissa != 0 ? nan : signed_inf;
    }

    if((x & mask) > ifmax)
    {
        return signed_inf;
    }

    if(x == 0)
    {
        return 0;
    }

    // First, check whether the value is normal or denormal, since they differ in
    // the implicit 1. Then adjust the exponent to align with the f8 exponent and
    // shift the mantissa accordingly. For stochastic rounding, add rng to the
    // mantissa and truncate; for RNE there is nothing to add. Finally, check
    // whether rounding carried, and adjust the exponent and mantissa again if so.

    // For IEEE bias mode, the bias is 2^(k-1) - 1 where k is the width of the
    // exponent bits
    const int f8_bias                  = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
    const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
    // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
    // f8_exponent is the converted f8 exponent with bias encoding
    // exponent_diff is the diff between the fp32/fp16 exponent and the f8 exponent;
    // the difference needs to be adjusted and the mantissa shifted
    int act_exponent, f8_exponent, exponent_diff;

    if(exponent == 0)
    { // fp32/fp16 is in denormal.
        /* fp32 denormals are below 2^-127, so they are usually not a concern here;
        we are mostly concerned with fp16. In that case f8 is usually also denormal,
        but there are exceptions. fp16 denormals have exponent bias 15, while bf8
        with NANOO has exponent bias 16. This means some fp16 denormals are bf8
        (NANOO) normals - the smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
        where exponent==0 (actual exponent -14) and the highest mantissa bit is 1
        are bf8 (NANOO) normal. In that case the fp16 mantissa should be shifted
        left by 1 */
        act_exponent  = exponent - bias + 1;
        exponent_diff = f8_denormal_act_exponent -
                        act_exponent; // actual exponent is exponent-bias+1 as it is denormal
    }
    else
    { // fp32/fp16 is normal with implicit 1
        act_exponent = exponent - bias;
        if(act_exponent <= f8_denormal_act_exponent)
        {
            /* This is the case where fp32/fp16 is normal but it falls in the f8
            denormal range. For example, in fp8 nanoo mode the denormal exponent is
            -7, but if the fp32/fp16 actual exponent is -7, the value is actually
            larger due to the implicit 1. Therefore it needs to be adjusted to -6
            and the mantissa shifted right by 1. So for fp32/fp16, exponent -8 is
            the cut point for converting to fp8 nanoo */
            exponent_diff = f8_denormal_act_exponent - act_exponent;
        }
        else
        { // both fp32/fp16 and f8 are in the normal range
            exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
                               // in this case; act_exponent could be larger, it just
                               // does not require a mantissa shift
        }
        mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa
    }

    bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
                    (1ull << (mfmt - wm + exponent_diff - 1));
    /* This part is a bit tricky. Whether the value is a tie must be judged before
    the right shift, because shifting right could rip off some residual bits and
    make something that is not a midpoint look like a midpoint. For example, the
    fp16 number 0x1002 (0 00100 0000000010) is larger than the midpoint, but after
    a right shift by 4 bits it would look like one.
    */

    if(exponent_diff > 0)
        mantissa >>= exponent_diff;
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;
    bool implicit_one = mantissa & (1ull << mfmt);
    // if there is no implicit 1, the f8 is denormal and the exponent needs to be
    // adjusted to the denormal exponent
    f8_exponent =
        (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);

    // Now we have the exponent and mantissa adjusted
    unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
    bool odd =
        mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1
    mantissa +=
        (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;

    // Now we deal with overflow
    if(f8_exponent == 0)
    {
        if((1ull << mfmt) & mantissa)
        {
            f8_exponent = 1; // denormal overflows to become normal, promote exponent
        }
    }
    else
    {
        if((1ull << (mfmt + 1)) & mantissa)
        {
            mantissa >>= 1;
            f8_exponent++;
        }
    }

    mantissa >>= (mfmt - wm);

    // above range: quantize to maximum possible float of the same sign
    const int max_exp = (1 << we) - 1;
    if(f8_exponent > max_exp)
    {
        if constexpr(clip)
        {
            mantissa    = (1 << wm) - 1;
            f8_exponent = max_exp;
        }
        else
        {
            return signed_inf;
        }
    }

    if(f8_exponent == 0 && mantissa == 0)
        return is_fnuz ? 0 : (sign << 7);
    mantissa &= (1 << wm) - 1;
    return (sign << 7) | (f8_exponent << wm) | mantissa;
}
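
// Worked example of the RNE path above (illustrative): casting 17.0f to OCP
// e4m3 (wm=3, we=4). 17 sits exactly midway between the representable values
// 16 and 18, so `midpoint` is true; the kept mantissa bits are 000 (even), so
// the tie rounds down:
//
//   cast_to_f8<float, 3, 4, /*is_fnuz=*/false>(17.0f); // == 0x58, i.e. 16.0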

template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat      = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
{
    __is_interpret_supported(interp);
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
#if defined(__gfx950__)
        // use the HW clock, multiplied by the incremented thread id, as the stochastic input
        rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                        (get_thread_global_1d_id() + 1));
#else
        constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
    }
    return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
        f, rng);
#else
#if CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
{
#else
__host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
{
#endif
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
#if defined(__gfx950__)
        // use the HW clock, multiplied by the incremented thread id, as the stochastic input
        rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                        (get_thread_global_1d_id() + 1));
#else
        constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
    }

    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
    {
        return cast_to_f8<float,
                          3,
                          4,
                          true,
                          sat == ck_saturation_t::CK_SATFINITE,
                          stochastic_rounding>(f, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_FNUZ)
    {
        return cast_to_f8<float,
                          2,
                          5,
                          true,
                          sat == ck_saturation_t::CK_SATFINITE,
                          stochastic_rounding>(f, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return cast_to_f8<float,
                          3,
                          4,
                          false,
                          sat == ck_saturation_t::CK_SATFINITE,
                          stochastic_rounding>(f, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return cast_to_f8<float,
                          2,
                          5,
                          false,
                          sat == ck_saturation_t::CK_SATFINITE,
                          stochastic_rounding>(f, rng);
    }
    else
    {
        __hip_assert(false && "FP8 type is not supported by current target device");
        return 0;
    }
#endif // CK_FP8_CVT_FAST_PATH
}
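
// Usage sketch (illustrative): saturating round-to-nearest conversion of a
// float to an OCP e4m3 byte; with CK_SATFINITE, values beyond +-448 clamp to
// the finite maximum instead of becoming NaN.
//
//   fp8_storage_t b = cvt_float_to_fp8<ck_fp8_interpretation_t::CK_E4M3_OCP,
//                                      ck_saturation_t::CK_SATFINITE>(1000.0f);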

template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat      = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH
__device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
{
    __is_interpret_supported(interp);
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
#if defined(__gfx950__)
        // use the HW clock, multiplied by the incremented thread id, as the stochastic input
        rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                        (get_thread_global_1d_id() + 1));
#else
        constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
#else
        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f[0]);
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
    }
    return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
        f, rng);
#else
#if CK_USE_OCP_FP8
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
{
#else
__host__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
{
#endif // CK_USE_OCP_FP8
    return fp8x2_storage_t{cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[0]),
                           cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[1])};
#endif // CK_FP8_CVT_FAST_PATH
}

template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat      = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
#else
__host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
#endif
{
    {
        __is_interpret_supported(interp);
        uint32_t rng = 0;
        if constexpr(stochastic_rounding)
        {
#if defined(__gfx950__)
            // use the HW clock, multiplied by the incremented thread id, as the stochastic input
            rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                            (get_thread_global_1d_id() + 1));
#else
            constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
            rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
            rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
        }
#if defined(__gfx950__)
        return cast_to_f8_from_f16<interp,
                                   sat == ck_saturation_t::CK_SATFINITE,
                                   stochastic_rounding>(x, rng);
#else
        ignore = rng;
        return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
            static_cast<float>(x));
#endif // defined(__gfx950__)
    }
}

template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat      = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
#else
__host__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
#endif
{
    {
        __is_interpret_supported(interp);
        uint32_t rng = 0;
        if constexpr(stochastic_rounding)
        {
#if defined(__gfx950__)
            // use the HW clock, multiplied by the incremented thread id, as the stochastic input
            rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                            (get_thread_global_1d_id() + 1));
#else
            constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
            rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#else
            rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
        }
#if defined(__gfx950__)
        return cast_to_f8_from_f16<interp,
                                   sat == ck_saturation_t::CK_SATFINITE,
                                   stochastic_rounding>(x, rng);
#else
        ignore = rng;
        return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
            float2_t{static_cast<float>(x[0]), static_cast<float>(x[1])});
#endif // defined(__gfx950__)
    }
}

template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat      = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
#else
__host__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
#endif
{
    {
        __is_interpret_supported(interp);
        uint32_t rng = 0;
        if constexpr(stochastic_rounding)
        {
#if defined(__gfx950__)
            // use the HW clock, multiplied by the incremented thread id, as the stochastic input
            rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                            (get_thread_global_1d_id() + 1));
#else
            constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
            rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
                                               static_cast<float>(x));
#else
            rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), static_cast<float>(x));
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
        }
#if defined(__gfx950__)
        return cast_to_f8_from_bf16<interp,
                                    sat == ck_saturation_t::CK_SATFINITE,
                                    stochastic_rounding>(x, rng);
#else
        ignore = rng;
        return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
            bit_cast<float>(uint32_t{x} << 16)); // convert value to float
#endif // defined(__gfx950__)
    }
}
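
// The non-gfx950 fallback above widens bf16 to float with a 16-bit shift,
// since bf16 is the upper half of an IEEE binary32. For example (illustrative):
//
//   ushort one_bf16 = 0x3F80;                            // bf16 1.0
//   float f = bit_cast<float>(uint32_t{one_bf16} << 16); // 1.0f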

template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat      = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
#else
__host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
#endif
{
#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
    return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
        float2_t{bit_cast<float>(uint32_t{x[0]} << 16),
                 bit_cast<float>(uint32_t{x[1]} << 16)}); // convert values to float
#else // CK_WORKAROUND_BF16_TO_FP8_CONVERSION
    {
        __is_interpret_supported(interp);
        uint32_t rng = 0;
        if constexpr(stochastic_rounding)
        {
#if defined(__gfx950__)
            // use the HW clock, multiplied by the incremented thread id, as the stochastic input
            rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                            (get_thread_global_1d_id() + 1));
#else
            constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
            rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
                                               static_cast<float>(x[0]));
#else
            rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x),
                                               static_cast<float>(x[0]));
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
        }
#if defined(__gfx950__)
        return cast_to_f8_from_bf16<interp,
                                    sat == ck_saturation_t::CK_SATFINITE,
                                    stochastic_rounding>(x, rng);
#else
        ignore = rng;
        return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
            float2_t{bit_cast<float>(uint32_t{x[0]} << 16),
                     bit_cast<float>(uint32_t{x[1]} << 16)}); // convert values to float
#endif // defined(__gfx950__)
    }
#endif // CK_WORKAROUND_BF16_TO_FP8_CONVERSION
}

} // namespace fp8_impl

#if CK_USE_OCP_FP8
using f8_t  = f8_ocp_t;
using bf8_t = bf8_ocp_t;
#define CK_FP8_TYPE_FNUZ 0
#define CK_FP8_TYPE_OCP 1
#else
using f8_t  = f8_fnuz_t;
using bf8_t = bf8_fnuz_t;
#define CK_FP8_TYPE_FNUZ 1
#define CK_FP8_TYPE_OCP 0
#endif

} // namespace ck