Composable Kernel: amd_ck_fp8.hpp Source File
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/ck.hpp"
8 #include "ck/utility/get_id.hpp"
11 #include "ck/utility/type.hpp"
12 
13 #ifndef CK_USE_FNUZ_FP8
14 #define CK_USE_FNUZ_FP8 0
15 #endif
16 
17 #ifndef CK_USE_OCP_FP8
18 #define CK_USE_OCP_FP8 0
19 #endif
20 
21 #if(defined(__gfx942__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && \
22  __HIP_DEVICE_COMPILE__
23 #define CK_FP8_CVT_FAST_PATH 1
24 #else
25 #define CK_FP8_CVT_FAST_PATH 0
26 #endif
27 
28 #if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
29 #define CK_OCP_FP8_CVT_FAST_PATH 1
30 #else
31 #define CK_OCP_FP8_CVT_FAST_PATH 0
32 #endif
33 
34 namespace ck {
35 
36 using f8_fnuz_t = _BitInt(8);
37 using bf8_fnuz_t = unsigned _BitInt(8);
38 
39 typedef unsigned char fp8_storage_t;
40 
41 /**
42  * \brief Describes FP8 interpretation.
43  */
44 enum class ck_fp8_interpretation_t
45 {
46  CK_E4M3_OCP = 0, // OCP E4M3
47  CK_E5M2_OCP = 1, // OCP E5M2
48  CK_E4M3_FNUZ = 2, // FNUZ E4M3 (a.k.a. FP8)
49  CK_E5M2_FNUZ = 3, // FNUZ E5M2 (a.k.a. BF8)
50 };
51 
52 /**
53  * \brief Describes saturation behavior.
54  */
55 enum class ck_saturation_t
56 {
57  CK_NOSAT = 0, // No saturation - replace with NaN or Inf
58  CK_SATFINITE = 1, // Saturate to finite
59 };
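
// Summary of the four interpretations (comment added for exposition; the values
// below are taken from the is_nan helpers and the saturation bounds used later in
// this file):
//
//   interpretation | layout   | max finite | NaN encoding          | infinity
//   CK_E4M3_OCP    | s1 e4 m3 | 448        | 0x7f, 0xff            | none
//   CK_E5M2_OCP    | s1 e5 m2 | 57344      | 0x7d..0x7f (+/- sign) | 0x7c, 0xfc
//   CK_E4M3_FNUZ   | s1 e4 m3 | 240        | 0x80                  | none
//   CK_E5M2_FNUZ   | s1 e5 m2 | 57344      | 0x80                  | none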
60 
61 namespace fp8_impl {
62 
63 typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2)));
64 typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
65 typedef ushort ushortx2_t __attribute__((ext_vector_type(2)));
66 typedef short shortx2_t __attribute__((ext_vector_type(2)));
67 typedef float float2_t __attribute__((ext_vector_type(2)));
68 
69 __host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a)
70 {
71  return static_cast<unsigned char>(a) == 0x80;
72 }
73 __host__ __device__ static inline constexpr bool fnuz_bf8_is_nan(bf8_fnuz_t a)
74 {
75  return static_cast<unsigned char>(a) == 0x80;
76 }
77 
78 __host__ __device__ static inline constexpr bool ocp_f8_is_nan(fp8_storage_t a)
79 {
80  return (a & 0x7f) == 0x7f;
81 }
82 __host__ __device__ static inline constexpr bool ocp_bf8_is_nan(fp8_storage_t a)
83 {
84  return (a & 0x7f) > 0x7c;
85 }
86 
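// Compile-time sanity checks (added for exposition, not part of the original
// header): OCP e4m3 reserves only the all-ones pattern per sign for NaN, while
// OCP e5m2 treats any nonzero mantissa under an all-ones exponent as NaN.
static_assert(ocp_f8_is_nan(0x7f) && ocp_f8_is_nan(0xff), "e4m3: 0x7f/0xff encode NaN");
static_assert(ocp_bf8_is_nan(0x7d) && !ocp_bf8_is_nan(0x7c), "e5m2: 0x7c is inf, not NaN");
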
87 // The conversion function is from rocblas
88 // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
89 // This has been modified to handle double types as well
90 template <typename T, int wm, int we, bool is_fnuz, bool clip = false>
91 __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
92 {
93  constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
94  constexpr bool is_float = __hip_internal::is_same<T, float>::value;
95  constexpr bool is_double = __hip_internal::is_same<T, double>::value;
96  static_assert(is_half || is_float || is_double, "only half, float and double are supported");
97 
98  constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
99  constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);
100 
101  T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
102  if constexpr(is_half)
103  {
104  const unsigned short int ihInf = 0x7C00;
105  const unsigned short int ihNegInf = 0xFC00;
106  const unsigned short int ihNaN = 0x7C01;
107  const unsigned short int ihNeg0 = 0x8000;
108  /* Max number in e5m2: 57344 */
109  const unsigned short int ifmax = 0x7B00;
110  const unsigned short int ifmin = 0xFB00;
111 
112  fInf = bit_cast<_Float16>(ihInf);
113  fNegInf = bit_cast<_Float16>(ihNegInf);
114  fNaN = bit_cast<_Float16>(ihNaN);
115  fNeg0 = bit_cast<_Float16>(ihNeg0);
116  fmax = bit_cast<_Float16>(ifmax);
117  fmin = bit_cast<_Float16>(ifmin);
118  }
119  else if constexpr(is_float)
120  {
121  const unsigned int ifInf = 0x7F800000;
122  const unsigned int ifNegInf = 0xFF800000;
123  const unsigned int ifNaN = 0x7F800001;
124  const unsigned int ifNeg0 = 0x80000000;
125  /* Max number in e5m2: 57344 */
126  const unsigned int ifmax = 0x47600000;
127  const unsigned int ifmin = 0xC7600000;
128 
129  fInf = bit_cast<float>(ifInf);
130  fNegInf = bit_cast<float>(ifNegInf);
131  fNaN = bit_cast<float>(ifNaN);
132  fNeg0 = bit_cast<float>(ifNeg0);
133  fmax = bit_cast<float>(ifmax);
134  fmin = bit_cast<float>(ifmin);
135  }
136  else if constexpr(is_double)
137  {
138  const unsigned long long ifInf = 0x7FF0000000000000ull;
139  const unsigned long long ifNegInf = 0xFFF0000000000000ull;
140  const unsigned long long ifNaN = 0x7FF0000000000001ull;
141  const unsigned long long ifNeg0 = 0x8000000000000000ull;
142  /* Max number in e5m2: 57344 */
143  const unsigned long long ifmax = 0x40EC000000000000ull;
144  const unsigned long long ifmin = 0xC0EC000000000000ull;
145 
146  fInf = bit_cast<double>(ifInf);
147  fNegInf = bit_cast<double>(ifNegInf);
148  fNaN = bit_cast<double>(ifNaN);
149  fNeg0 = bit_cast<double>(ifNeg0);
150  fmax = bit_cast<double>(ifmax);
151  fmin = bit_cast<double>(ifmin);
152  }
153 
154  if(x == 0)
155  {
156  return 0;
157  }
158 
159  unsigned long long sign = x >> 7;
160  unsigned long long mantissa = x & ((1 << wm) - 1);
161  int exponent = (x & 0x7F) >> wm;
162  if constexpr(is_fnuz)
163  {
164  if(x == 0x80)
165  {
166  return fNaN;
167  }
168  }
169  else
170  {
171  if(x == 0x80)
172  {
173  return fNeg0;
174  }
175  if constexpr(we == 4)
176  { // e4m3
177  if((x & 0x7F) == 0x7F)
178  {
179  return fNaN;
180  }
181  }
182  else if((x & 0x7C) == 0x7C)
183  { // e5m2
184  if((x & 0x3) == 0)
185  {
186  if constexpr(clip)
187  {
188  return sign ? fmin : fmax;
189  }
190  return sign ? fNegInf : fInf;
191  }
192  return fNaN;
193  }
194  }
195 
196  typename ck::conditional_t<
197  sizeof(T) == 2,
198  unsigned short int,
199  typename ck::conditional_t<sizeof(T) == 4, unsigned int, unsigned long long>>
200  retval;
201 
202  if constexpr(we == 5 && is_half && !is_fnuz)
203  {
204  retval = x << 8;
205  return bit_cast<T>(retval);
206  }
207 
208  const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);
209 
210  // subnormal input
211  if(exponent == 0)
212  {
213 #if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
214  // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
215  int sh = 1 + __clz(mantissa) - (32 - wm);
216 #else
217  int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
218 #endif
219  mantissa <<= sh;
220  exponent += 1 - sh;
221  mantissa &= ((1ull << wm) - 1);
222  }
223  exponent += exp_low_cutoff - 1;
224  mantissa <<= wmo - wm;
225 
226  // subnormal output (occurs when T=half, we=5, is_fnuz=true)
227  if(exponent <= 0)
228  {
229  mantissa |= 1 << wmo;
230  mantissa >>= 1 - exponent;
231  exponent = 0;
232  }
233 
234  if constexpr(sizeof(T) == 2)
235  retval = (sign << 15) | (exponent << 10) | mantissa;
236  else if constexpr(sizeof(T) == 4)
237  retval = (sign << 31) | (exponent << 23) | mantissa;
238  else
239  retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;
240 
241  return bit_cast<T>(retval);
242 }
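
// Illustrative helper (added for exposition; the name is hypothetical and not part
// of the original header): decode one OCP e4m3 byte with the generic path above.
// Template arguments are <T, wm, we, is_fnuz, clip>; e.g. 0x40 has sign 0, biased
// exponent 8 (bias 7) and mantissa 0, so it decodes to 2.0f.
__host__ static inline float example_decode_e4m3_to_float(fp8_storage_t b)
{
    return cast_from_f8<float, 3, 4, /*is_fnuz=*/false>(b);
}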
243 
244 #if CK_FP8_CVT_FAST_PATH
245 template <ck_fp8_interpretation_t interpret>
246 static __host__ __device__ float cast_to_f32_from_f8(fp8_storage_t v)
247 {
248  union
249  {
250  unsigned int i32val;
251  unsigned char i8val[4];
252  } val;
253  val.i8val[0] = v;
254 
255  static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
256  interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
257  interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
258  interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
259  "Only FNUZ and OCP interpretations are supported");
260 
261  if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
262  (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
263  {
264  return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
265  }
266  else
267  {
268  return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
269  }
270 }
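
// The scalar fast path above places the byte in lane 0 of a 32-bit register and
// lets the hardware decode element 0; the packed overload below decodes two bytes
// with a single instruction. (Comment added for exposition.)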
271 
272 template <ck_fp8_interpretation_t interpret>
273 static __device__ float2_t cast_to_f32_from_f8(fp8x2_storage_t v)
274 {
275  const auto i16val = bit_cast<uint16_t>(v);
276 
277  static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
278  interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
279  interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
280  interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
281  "Only FNUZ and OCP interpretations are supported");
282 
283  if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
284  (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
285  {
286  return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false);
287  }
288  else
289  {
290  return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
291  }
292 }
293 #endif
294 
295 } // namespace fp8_impl
296 
297 struct f8_ocp_t
298 {
299  using data_type = fp8_storage_t;
300  data_type data;
301 
302  static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
303  static constexpr ck_fp8_interpretation_t default_interpret =
304  ck_fp8_interpretation_t::CK_E4M3_OCP;
305 
306  static constexpr unsigned int we = 4; // exponent width
307  static constexpr unsigned int wm = 3; // mantissa width
308 
309  __host__ __device__ constexpr bool operator==(const f8_ocp_t& other) const
310  {
311  return (data == other.data) && (fp8_impl::ocp_f8_is_nan(data) == false); // NaN != NaN
312  }
313 
314 #if CK_USE_OCP_FP8
315  __host__ __device__ explicit operator float() const
316 #else
317  __host__ explicit operator float() const
318 #endif
319  {
320 #if CK_OCP_FP8_CVT_FAST_PATH
321  return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
322 #else
323  return fp8_impl::cast_from_f8<float, wm, we, false>(
324  this->data); // XXX: clip==false must be consistent with operator _Float16
325 #endif
326  }
327 
328 #if CK_USE_OCP_FP8
329  __host__ __device__ explicit operator _Float16() const
330 #else
331  __host__ explicit operator _Float16() const
332 #endif
333  {
334 #if CK_OCP_FP8_CVT_FAST_PATH
335  return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
336 #else
337  return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
338  this->data); // XXX: clip==false must be consistent with operator float
339 #endif
340  }
341 };
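
// Usage sketch (added for exposition): f8_ocp_t wraps a single storage byte; the
// explicit operators decode it, and operator== deliberately makes NaN compare
// unequal to itself:
//   f8_ocp_t q{0x40};                 // 2.0 in e4m3
//   float f = static_cast<float>(q);  // 2.0f
//   f8_ocp_t n{0x7f};                 // NaN
//   bool eq = (n == n);               // false: NaN != NaN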
342 
343 struct bf8_ocp_t
344 {
345  using data_type = fp8_storage_t;
346  data_type data;
347 
348  static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
349  static constexpr ck_fp8_interpretation_t default_interpret =
350  ck_fp8_interpretation_t::CK_E5M2_OCP;
351 
352  static constexpr unsigned int we = 5; // exponent width
353  static constexpr unsigned int wm = 2; // mantissa width
354 
355  __host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const
356  {
357  return (data == other.data) && (fp8_impl::ocp_bf8_is_nan(data) == false); // NaN != NaN
358  }
359 
360 #if CK_USE_OCP_FP8
361  __host__ __device__ explicit operator float() const
362 
363 #else
364  __host__ explicit operator float() const
365 #endif
366  {
367 #if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
368  return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
369 #else
370  return fp8_impl::cast_from_f8<float, wm, we, false>(
371  this->data); // XXX: clip==false must be consistent with operator _Float16
372 #endif
373  }
374 
375 #if CK_USE_OCP_FP8
376  __host__ __device__ explicit operator _Float16() const
377 #else
378  __host__ explicit operator _Float16() const
379 #endif
380  {
381 #if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
382  return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
383 #else
384  return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
385  this->data); // XXX: clip==false must be consistent with operator float
386 #endif
387  }
388 };
389 
390 template <typename T>
391 __host__ __device__ static inline constexpr bool fp8_is_nan(T);
392 
393 template <>
394 __host__ __device__ inline constexpr bool fp8_is_nan(f8_ocp_t a)
395 {
396  return fp8_impl::ocp_f8_is_nan(a.data);
397 }
398 template <>
399 __host__ __device__ inline constexpr bool fp8_is_nan(bf8_ocp_t a)
400 {
401  return fp8_impl::ocp_bf8_is_nan(a.data);
402 }
403 template <>
404 __host__ __device__ inline constexpr bool fp8_is_nan(f8_fnuz_t a)
405 {
406  return fp8_impl::fnuz_f8_is_nan(a);
407 }
408 template <>
409 __host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a)
410 {
411  return fp8_impl::fnuz_bf8_is_nan(a);
412 }
413 
414 template <typename T,
415  ck::enable_if_t<is_same_v<T, bf8_ocp_t> || is_same_v<T, f8_ocp_t> ||
416  is_same_v<T, bf8_fnuz_t> || is_same_v<T, f8_fnuz_t>,
417  bool> = true>
418 __host__ __device__ static inline constexpr bool fp8_is_inf(T)
419 {
420  return false;
421 }
422 template <>
423 __host__ __device__ inline constexpr bool fp8_is_inf(bf8_ocp_t a)
424 {
425  return (a.data & 0x7f) == 0x7c;
426 }
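
// Compile-time examples (added for exposition): only OCP e5m2 can represent
// infinity; every other supported FP8 format reports fp8_is_inf == false.
static_assert(fp8_is_inf(bf8_ocp_t{0x7c}), "e5m2: 0x7c encodes +inf");
static_assert(!fp8_is_inf(f8_ocp_t{0x7f}), "e4m3 has no infinity; 0x7f is NaN");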
427 
428 namespace fp8_impl {
429 
430 // Assertions to check for supported conversion types
431 #define __fp8_impl_assert_ocp_support(interp) \
432  { \
433  if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \
434  interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \
435  { \
436  __hip_assert(false && "type is unsupported by current target device"); \
437  } \
438  }
439 #define __fp8_impl_assert_fnuz_support(interp) \
440  { \
441  if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \
442  interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \
443  { \
444  __hip_assert(false && "type is unsupported by current target device"); \
445  } \
446  }
447 
448 __host__ __device__ static inline void
449 __is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp)
450 {
451 #if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
452 #if CK_USE_OCP_FP8
453  __fp8_impl_assert_ocp_support(interp);
454 #endif
455 #if CK_USE_FNUZ_FP8
456  __fp8_impl_assert_fnuz_support(interp);
457 #endif
458 #endif
459 }
460 
461 #if defined(__gfx950__)
462 template <ck_fp8_interpretation_t interpret,
463  bool saturate,
464  bool stochastic_rounding = false,
465  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && stochastic_rounding,
466  bool> = true>
467 static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
468 {
469  union
470  {
471  unsigned int i32val;
472  half2_t half_vec;
473  fp8_storage_t i8val[4];
474  } val;
475 
476  constexpr unsigned int i32val = 0;
477  val.half_vec[0] = v;
478 
479  if constexpr(saturate)
480  {
481  if((val.i32val & 0x7FFF) != 0x7FFF)
482  {
483  val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
484  }
485  }
486 
487  val.i32val =
488  __builtin_amdgcn_cvt_scalef32_sr_fp8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0);
489 
490  return val.i8val[0];
491 }
492 
493 template <ck_fp8_interpretation_t interpret,
494  bool saturate,
495  bool stochastic_rounding = false,
496  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && stochastic_rounding,
497  bool> = true>
498 static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
499 {
500  // there is no packed conversion with SR, so convert one element at a time
501  return fp8x2_storage_t{
502  cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
503  cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
504 }
505 
506 template <ck_fp8_interpretation_t interpret,
507  bool saturate,
508  bool stochastic_rounding = false,
509  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && stochastic_rounding,
510  bool> = true>
511 static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
512 {
513  union
514  {
515  unsigned int i32val;
516  half2_t half_vec;
517  fp8_storage_t i8val[4];
518  } val;
519 
520  constexpr unsigned int i32val = 0;
521  val.half_vec[0] = v;
522 
523  if constexpr(saturate)
524  {
525  if((val.i32val & 0x7FFF) != 0x7FFF)
526  {
527  val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
528  }
529  }
530 
531  val.i32val =
532  __builtin_amdgcn_cvt_scalef32_sr_bf8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0);
533 
534  return val.i8val[0];
535 }
536 
537 template <ck_fp8_interpretation_t interpret,
538  bool saturate,
539  bool stochastic_rounding = false,
540  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && stochastic_rounding,
541  bool> = true>
542 static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
543 {
544  // there is no packed conversion with SR, so convert one element at a time
545  return fp8x2_storage_t{
546  cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
547  cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
548 }
549 
550 template <ck_fp8_interpretation_t interpret,
551  bool saturate,
552  bool stochastic_rounding = false,
553  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && !stochastic_rounding,
554  bool> = true>
555 static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
556 {
557  ignore = rng;
558 
559  union
560  {
561  unsigned int i32val;
562  half2_t half_vec;
563  shortx2_t i16_vec;
564  fp8_storage_t i8val[4];
565  } val;
566 
567  constexpr shortx2_t i16x2val = {0, 0};
568  val.half_vec[0] = v;
569 
570  if constexpr(saturate)
571  {
572  if((val.i32val & 0x7FFF) != 0x7FFF)
573  {
574  val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
575  }
576  }
577 
578  val.i16_vec =
579  __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
580 
581  return val.i8val[0];
582 }
583 
584 template <ck_fp8_interpretation_t interpret,
585  bool saturate,
586  bool stochastic_rounding = false,
587  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && !stochastic_rounding,
588  bool> = true>
589 static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
590 {
591 #if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
592  return fp8x2_storage_t{
593  cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
594  cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
595 #else
596  ignore = rng;
597 
598  union
599  {
600  half2_t half_vec;
601  shortx2_t i16_vec;
602  fp8_storage_t i8val[4];
603  } val;
604 
605  constexpr shortx2_t i16x2val = {0, 0};
606  val.half_vec = v;
607 
608  if constexpr(saturate)
609  {
610  if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
611  {
612  val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
613  }
614  if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
615  {
616  val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 448.0, -448.0);
617  }
618  }
619 
620  val.i16_vec =
621  __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
622 
623  return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
624 #endif
625 }
626 
627 template <ck_fp8_interpretation_t interpret,
628  bool saturate,
629  bool stochastic_rounding = false,
630  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && !stochastic_rounding,
631  bool> = true>
632 static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
633 {
634  ignore = rng;
635 
636  union
637  {
638  unsigned int i32val;
639  half2_t half_vec;
640  shortx2_t i16_vec;
641  fp8_storage_t i8val[4];
642  } val;
643 
644  constexpr shortx2_t i16x2val = {0, 0};
645  val.half_vec[0] = v;
646 
647  if constexpr(saturate)
648  {
649  if((val.i32val & 0x7FFF) != 0x7FFF)
650  {
651  val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
652  }
653  }
654 
655  val.half_vec =
656  __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
657 
658  return val.i8val[0];
659 }
660 
661 template <ck_fp8_interpretation_t interpret,
662  bool saturate,
663  bool stochastic_rounding = false,
664  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && !stochastic_rounding,
665  bool> = true>
666 static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
667 {
668 #if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
669  return fp8x2_storage_t{
670  cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
671  cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
672 #else
673  ignore = rng;
674 
675  union
676  {
677  half2_t half_vec;
678  shortx2_t i16_vec;
679  fp8_storage_t i8val[4];
680  } val;
681 
682  constexpr shortx2_t i16x2val = {0, 0};
683  val.half_vec = v;
684 
685  if constexpr(saturate)
686  {
687  if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
688  {
689  val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
690  }
691  if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
692  {
693  val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 57344.0, -57344.0);
694  }
695  }
696 
697  val.i16_vec =
698  __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
699 
700  return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
701 #endif
702 }
703 
704 template <ck_fp8_interpretation_t interpret,
705  bool saturate,
706  bool stochastic_rounding = false,
707  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && stochastic_rounding,
708  bool> = true>
709 static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
710 {
711  union
712  {
713  unsigned int i32val;
714  ushortx2_t bhalf_vec;
715  fp8_storage_t i8val[4];
716  } val;
717 
718  constexpr unsigned int i32val = 0;
719  val.bhalf_vec[0] = v;
720 
721  if constexpr(saturate)
722  {
723  if((val.i32val & 0x7FFF) != 0x7FFF)
724  {
725  val.bhalf_vec[0] =
726  ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
727  bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
728  16)); // convert to float and back
729  }
730  }
731 
732  val.i32val = __builtin_amdgcn_cvt_scalef32_sr_fp8_bf16(
733  i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0);
734 
735  return val.i8val[0];
736 }
737 
738 template <ck_fp8_interpretation_t interpret,
739  bool saturate,
740  bool stochastic_rounding = false,
741  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && stochastic_rounding,
742  bool> = true>
743 static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
744 {
745  // there is no packed conversion with SR, so convert one element at a time
746  return fp8x2_storage_t{
747  cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
748  cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
749 }
750 
751 template <ck_fp8_interpretation_t interpret,
752  bool saturate,
753  bool stochastic_rounding = false,
754  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && stochastic_rounding,
755  bool> = true>
756 static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
757 {
758  union
759  {
760  unsigned int i32val;
761  ushortx2_t bhalf_vec;
762  fp8_storage_t i8val[4];
763  } val;
764 
765  constexpr unsigned int i32val = 0;
766  val.bhalf_vec[0] = v;
767 
768  if constexpr(saturate)
769  {
770  if((val.i32val & 0x7FFF) != 0x7FFF)
771  {
772  val.bhalf_vec[0] = ushort(
773  (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
774  bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
775  16)); // convert to float and back
776  }
777  }
778 
779  val.i32val = __builtin_amdgcn_cvt_scalef32_sr_bf8_bf16(
780  i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0);
781 
782  return val.i8val[0];
783 }
784 
785 template <ck_fp8_interpretation_t interpret,
786  bool saturate,
787  bool stochastic_rounding = false,
788  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && stochastic_rounding,
789  bool> = true>
790 static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
791 {
792  // there is no packed conversion with SR, so convert one element at a time
793  return fp8x2_storage_t{
794  cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
795  cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
796 }
797 
798 template <ck_fp8_interpretation_t interpret,
799  bool saturate,
800  bool stochastic_rounding = false,
801  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && !stochastic_rounding,
802  bool> = true>
803 static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
804 {
805  ignore = rng;
806 
807  union
808  {
809  unsigned int i32val;
810  ushortx2_t bhalf_vec;
811  shortx2_t i16_vec;
812  fp8_storage_t i8val[4];
813  } val;
814 
815  constexpr shortx2_t i16x2val = {0, 0};
816  val.bhalf_vec[0] = v;
817 
818  if constexpr(saturate)
819  {
820  if((val.i32val & 0x7FFF) != 0x7FFF)
821  {
822  val.bhalf_vec[0] =
823  ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
824  bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
825  16)); // convert to float and back
826  }
827  }
828 
829  val.i16_vec =
830  __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
831 
832  return val.i8val[0];
833 }
834 
835 template <ck_fp8_interpretation_t interpret,
836  bool saturate,
837  bool stochastic_rounding = false,
838  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP && !stochastic_rounding,
839  bool> = true>
840 static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
841 {
842 #if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
843  return fp8x2_storage_t{
844  cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
845  cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
846 #else
847  ignore = rng;
848 
849  union
850  {
851  ushortx2_t bhalf_vec;
852  shortx2_t i16_vec;
853  fp8_storage_t i8val[4];
854  } val;
855 
856  constexpr shortx2_t i16x2val = {0, 0};
857  val.bhalf_vec = v;
858 
859  if constexpr(saturate)
860  {
861  if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
862  {
863  val.bhalf_vec[0] =
864  ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
865  bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
866  16)); // convert to float and back
867  }
868  if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
869  {
870  val.bhalf_vec[1] =
871  ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
872  bit_cast<float>(uint32_t{val.bhalf_vec[1]} << 16), 448.0, -448.0)) >>
873  16)); // convert to float and back
874  }
875  }
876 
877  val.i16_vec =
878  __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
879 
880  return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
881 #endif
882 }
883 
884 template <ck_fp8_interpretation_t interpret,
885  bool saturate,
886  bool stochastic_rounding = false,
887  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && !stochastic_rounding,
888  bool> = true>
889 static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
890 {
891  ignore = rng;
892 
893  union
894  {
895  unsigned int i32val;
896  ushortx2_t bhalf_vec;
897  shortx2_t i16_vec;
898  fp8_storage_t i8val[4];
899  } val;
900 
901  constexpr shortx2_t i16x2val = {0, 0};
902  val.bhalf_vec[0] = v;
903 
904  if constexpr(saturate)
905  {
906  if((val.i32val & 0x7FFF) != 0x7FFF)
907  {
908  val.bhalf_vec[0] = ushort(
909  (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
910  bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
911  16)); // convert to float and back
912  }
913  }
914 
915  val.i16_vec =
916  __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
917 
918  return val.i8val[0];
919 }
920 
921 template <ck_fp8_interpretation_t interpret,
922  bool saturate,
923  bool stochastic_rounding = false,
924  ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP && !stochastic_rounding,
925  bool> = true>
926 static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
927 {
928  ignore = rng;
929 
930  union
931  {
932  ushortx2_t bhalf_vec;
933  shortx2_t i16_vec;
934  fp8_storage_t i8val[4];
935  } val;
936 
937  constexpr shortx2_t i16x2val = {0, 0};
938  val.bhalf_vec = v;
939 
940  if constexpr(saturate)
941  {
942  if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
943  {
944  val.bhalf_vec[0] = ushort(
945  (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
946  bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
947  16)); // convert to float and back
948  }
949  if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
950  {
951  val.bhalf_vec[1] = ushort(
952  (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
953  bit_cast<float>(uint32_t{val.bhalf_vec[1]} << 16), 57344.0, -57344.0)) >>
954  16)); // convert to float and back
955  }
956  }
957 
958  val.i16_vec =
959  __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
960 
961  return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
962 }
963 #endif // defined(__gfx950__)
964 
965 #if CK_FP8_CVT_FAST_PATH
966 // The conversion function is from rocblas
967 // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
968 template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
969 static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
970 {
971  fp8_storage_t i8data;
972  union
973  {
974  float fval;
975  unsigned int i32val;
976  unsigned char i8val[4]; // NOTE: not endian independent
977  } val;
978 
979  unsigned int ival = 0;
980  val.fval = v;
981 
982  if constexpr(saturate)
983  {
984  if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
985  {
986  if((val.i32val & 0x7F800000) != 0x7F800000)
987  {
988  val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
989  }
990  }
991  else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
992  { // OCP type
993  if((val.i32val & 0x7F800000) != 0x7F800000)
994  {
995  val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
996  }
997  }
998  else
999  {
1000  if((val.i32val & 0x7F800000) != 0x7F800000)
1001  {
1002  val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
1003  }
1004  }
1005  }
1006 
1007  if constexpr(stochastic_rounding)
1008  {
1009  ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
1010  (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
1011  ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
1012  : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
1013  val.i32val = ival;
1014  i8data = val.i8val[0]; // little endian
1015  }
1016  else
1017  { // RNE CVT
1018  ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
1019  (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
1020  ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
1021  : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
1022  val.fval,
1023  ival,
1024  false); // false -> WORD0
1025  val.i32val = ival;
1026  i8data = val.i8val[0];
1027  }
1028  return i8data;
1029 }
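
// Note on the saturation guard above (comment added for exposition): the test
// (i32val & 0x7F800000) == 0x7F800000 detects IEEE Inf/NaN, which must skip the
// fmed3 clamp so that NaN reaches the conversion instruction unmodified rather
// than being clamped to the finite maximum.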
1030 
1031 template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
1032 static __device__ fp8x2_storage_t cast_to_f8_from_f32(float2_t v, unsigned int rng = 0)
1033 {
1034  if constexpr(stochastic_rounding)
1035  {
1036  // there is no packed conversion with SR, so convert one element at a time
1037  return fp8x2_storage_t{
1038  cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[0], rng),
1039  cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[1], rng)};
1040  }
1041  else
1042  {
1043  union
1044  {
1045  float fval;
1046  unsigned int i32val;
1047  unsigned char i8val[4];
1048  } val0, val1;
1049 
1050  val0.fval = v[0];
1051  val1.fval = v[1];
1052 
1053  unsigned int ival = 0;
1054 
1055  if constexpr(saturate)
1056  {
1057  if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
1058  {
1059  if((val0.i32val & 0x7F800000) != 0x7F800000)
1060  {
1061  val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 240.0, -240.0);
1062  }
1063  if((val1.i32val & 0x7F800000) != 0x7F800000)
1064  {
1065  val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 240.0, -240.0);
1066  }
1067  }
1068  else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
1069  { // OCP type
1070  if((val0.i32val & 0x7F800000) != 0x7F800000)
1071  {
1072  val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 448.0, -448.0);
1073  }
1074  if((val1.i32val & 0x7F800000) != 0x7F800000)
1075  {
1076  val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 448.0, -448.0);
1077  }
1078  }
1079  else
1080  {
1081  if((val0.i32val & 0x7F800000) != 0x7F800000)
1082  {
1083  val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 57344.0, -57344.0);
1084  }
1085  if((val1.i32val & 0x7F800000) != 0x7F800000)
1086  {
1087  val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 57344.0, -57344.0);
1088  }
1089  }
1090  }
1091 
1092  // RNE CVT
1093  if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
1094  (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
1095  {
1096  ival = __builtin_amdgcn_cvt_pk_fp8_f32(val0.fval, val1.fval, ival, false);
1097  }
1098  else
1099  {
1100  ival = __builtin_amdgcn_cvt_pk_bf8_f32(val0.fval, val1.fval, ival, false);
1101  }
1102 
1103  val0.i32val = ival;
1104 
1105  return fp8x2_storage_t{val0.i8val[0], val0.i8val[1]};
1106  }
1107 }
1108 #endif // CK_FP8_CVT_FAST_PATH
1109 
1110 // The conversion function is from rocblas
1111 // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39
1112 // This has been modified to handle double types as well
1113 template <typename T, int wm, int we, bool is_fnuz, bool clip = false, bool stoch = false>
1114 __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rng = 0)
1115 {
1116  constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
1117  constexpr bool is_float = __hip_internal::is_same<T, float>::value;
1118  constexpr bool is_double = __hip_internal::is_same<T, double>::value;
1119  static_assert(is_half || is_float || is_double,
1120  "Only half, float and double can be cast to f8");
1121 
1122  constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
1123 
1124  using T_bitwise = typename ck::conditional_t<
1125  sizeof(T) == 2,
1126  unsigned short int,
1127  typename ck::conditional_t<sizeof(T) == 4, unsigned int, unsigned long long>>;
1128  T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
1129 
1130  unsigned long long x{x_bitwise};
1131 
1132  unsigned long long head, mantissa;
1133  int exponent, bias;
1134  unsigned int sign;
1135  unsigned long long fInf, mask;
1136 
1137  if constexpr(sizeof(T) == 8)
1138  {
1139  head = x & 0xFFF0000000000000ull;
1140  mantissa = x & 0xFFFFFFFFFFFFFull;
1141  exponent = (head >> 52) & 0x7FF;
1142  sign = head >> 63;
1143  bias = 1023;
1144  fInf = 0x7FF0000000000000ull;
1145  mask = 0x7FFFFFFFFFFFFFFFull;
1146  }
1147  else if constexpr(sizeof(T) == 4)
1148  {
1149  head = x & 0xFF800000;
1150  mantissa = x & 0x7FFFFF;
1151  exponent = (head >> 23) & 0xFF;
1152  sign = head >> 31;
1153  bias = 127;
1154  fInf = 0x7F800000;
1155  mask = 0x7FFFFFFF;
1156  }
1157  else
1158  {
1159  head = x & 0xFC00;
1160  mantissa = x & 0x3FF;
1161  exponent = (head >> 10) & 0x1F;
1162  sign = head >> 15;
1163  bias = 15;
1164  fInf = 0x7C00;
1165  mask = 0x7FFF;
1166  }
1167  unsigned int signed_inf = 0;
1168  unsigned int nan = 0;
1169  if constexpr(is_fnuz)
1170  {
1171  signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
1172  nan = 0x80;
1173  }
1174  else
1175  {
1176  if constexpr(we == 4)
1177  { // e4m3
1178  signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
1179  }
1180  else
1181  { // e5m2
1182  signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
1183  }
1184  nan = (sign << 7) + 0x7f;
1185  }
1186  // Max values
1187  unsigned long long ifmax = 0;
1188  if constexpr(sizeof(T) == 8)
1189  {
1190  if constexpr(we == 5)
1191  { // 57344
1192  ifmax = 0x40EC000000000000ull;
1193  }
1194  else
1195  {
1196  if constexpr(is_fnuz)
1197  { // 240
1198  ifmax = 0x406E000000000000ull;
1199  }
1200  else
1201  { // 448
1202  ifmax = 0x407C000000000000ull;
1203  }
1204  }
1205  }
1206  else if(sizeof(T) == 4)
1207  {
1208  if constexpr(we == 5)
1209  {
1210  ifmax = 0x47600000;
1211  }
1212  else
1213  {
1214  if constexpr(is_fnuz)
1215  {
1216  ifmax = 0x43700000;
1217  }
1218  else
1219  {
1220  ifmax = 0x43E00000;
1221  }
1222  }
1223  }
1224  else
1225  {
1226  if constexpr(we == 5)
1227  {
1228  ifmax = 0x7B00;
1229  }
1230  else
1231  {
1232  if constexpr(is_fnuz)
1233  {
1234  ifmax = 0x5B80;
1235  }
1236  else
1237  {
1238  ifmax = 0x5F00;
1239  }
1240  }
1241  }
1242  // Deal with inf and NaNs
1243  if((x & fInf) == fInf)
1244  {
1245  if constexpr(is_fnuz)
1246  return signed_inf;
1247 
1248  return mantissa != 0 ? nan : signed_inf;
1249  }
1250 
1251  if((x & mask) > ifmax)
1252  {
1253  return signed_inf;
1254  }
1255 
1256  if(x == 0)
1257  {
1258  return 0;
1259  }
1260 
1261  // First, check whether the value is normal or denormal, since they differ in
1262  // the implicit 1. Then adjust the exponent to align with the F8 exponent,
1263  // shifting the mantissa accordingly. For stochastic rounding, add rng to the
1264  // mantissa and truncate; for RNE, nothing is added. Finally, check whether
1265  // rounding produced a carry and, if so, adjust exponent and mantissa again.
1266 
1267  // For IEEE bias mode, the bias is 2^(k-1) - 1, where k is the width of the
1268  // exponent bits
1269  const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
1270  const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
1271  // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
1272  // f8_exponent is the converted f8 exponent with bias encoding
1273  // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
1274  // the difference needs to be adjusted and mantissa shifted
1275  int act_exponent, f8_exponent, exponent_diff;
1276 
1277  if(exponent == 0)
1278  { // fp32/fp16 is in denormal.
1279  /* fp32 denormals are below 2^-127, so they are usually not a concern here; the
1280  case that matters is fp16. An fp16 denormal usually maps to an f8 denormal, but
1281  there are exceptions: fp16 has exponent bias 15 while bf8 with NANOO has
1282  exponent bias 16, so some fp16 denormals are actually bf8 (NANOO) normals - the
1283  smallest bf8 (NANOO) normal is 2^-15. fp16 numbers where exponent==0 (actual
1284  exponent -14) and the highest mantissa bit is 1 are bf8 (NANOO) normals; in that
1285  case the fp16 mantissa must be shifted left by 1 */
1286  act_exponent = exponent - bias + 1;
1287  exponent_diff = f8_denormal_act_exponent -
1288  act_exponent; // actual exponent is exponent-bias+1 as it is denormal
1289  }
1290  else
1291  { // fp32/fp16 is normal with implicit 1
1292  act_exponent = exponent - bias;
1293  if(act_exponent <= f8_denormal_act_exponent)
1294  {
1295  /* This is the case where fp32/fp16 is normal but falls in the f8 denormal
1296  range. For example, in fp8 nanoo mode the denormal exponent is -7, but an
1297  fp32/fp16 value with actual exponent -7 is larger due to the implicit 1, so it
1298  must be adjusted to -6 with the mantissa shifted right by 1. Thus for fp32/fp16,
1299  exponent -8 is the cut-off for converting to fp8 nanoo */
1300  exponent_diff = f8_denormal_act_exponent - act_exponent;
1301  }
1302  else
1303  { // both fp32/fp16 and f8 are in normal range
1304  exponent_diff = 0; // exponent_diff == 0 does not mean the exponents match;
1305  // act_exponent may simply be large enough that the
1306  // mantissa needs no shift
1307  }
1308  mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa
1309  }
1310 
1311  bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
1312  (1ull << (mfmt - wm + exponent_diff - 1));
1313  /* This part is a bit tricky. Whether the value is a tie must be judged before
1314  the right shift, because the shift can discard residual bits and make a
1315  non-midpoint value look like a midpoint. For example, the fp16 number 0x1002
1316  (0 00100 0000000010) is larger than the midpoint, but after a right shift by
1317  4 bits it would look like one.
1318  */
1319 
1320  if(exponent_diff > 0)
1321  mantissa >>= exponent_diff;
1322  else if(exponent_diff == -1)
1323  mantissa <<= -exponent_diff;
1324  bool implicit_one = mantissa & (1ull << mfmt);
1325  // if there is no implicit 1, the f8 value is denormal and the exponent needs
1326  // to be adjusted to the denormal exponent
1327  f8_exponent =
1328  (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
1329 
1330  // Now we have the exponent and mantissa adjusted
1331  unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
1332  bool odd =
1333  mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1
1334  mantissa +=
1335  (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
1336 
1337  // Now we deal with overflow
1338  if(f8_exponent == 0)
1339  {
1340  if((1ull << mfmt) & mantissa)
1341  {
1342  f8_exponent = 1; // denormal overflow to become normal, promote exponent
1343  }
1344  }
1345  else
1346  {
1347  if((1ull << (mfmt + 1)) & mantissa)
1348  {
1349  mantissa >>= 1;
1350  f8_exponent++;
1351  }
1352  }
1353 
1354  mantissa >>= (mfmt - wm);
1355 
1356  // above range: quantize to maximum possible float of the same sign
1357  const int max_exp = (1 << we) - 1;
1358  if(f8_exponent > max_exp)
1359  {
1360  if constexpr(clip)
1361  {
1362  mantissa = (1 << wm) - 1;
1363  f8_exponent = max_exp;
1364  }
1365  else
1366  {
1367  return signed_inf;
1368  }
1369  }
1370 
1371  if(f8_exponent == 0 && mantissa == 0)
1372  return is_fnuz ? 0 : (sign << 7);
1373  mantissa &= (1 << wm) - 1;
1374  return (sign << 7) | (f8_exponent << wm) | mantissa;
1375 }
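
// Worked example (added for exposition): cast_to_f8<float, 3, 4, false>(1.0f)
// takes the OCP e4m3 path with f8_bias = 7: act_exponent = 0, exponent_diff = 0,
// the mantissa rounds to 0 and f8_exponent = 7, so the result is
// (0 << 7) | (7 << 3) | 0 = 0x38.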
1376 
1386 template <ck_fp8_interpretation_t interp,
1387  ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
1388  bool stochastic_rounding = false>
1389 #if CK_FP8_CVT_FAST_PATH
1390 __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
1391 {
1392  __is_interpret_supported(interp);
1393  uint32_t rng = 0;
1394  if constexpr(stochastic_rounding)
1395  {
1396 #if defined(__gfx950__)
1397  // use the HW clock multiplied by the incremented thread id as the stochastic input
1398  rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1399  (get_thread_global_1d_id() + 1));
1400 #else
1401  constexpr int seed = 1254739;
1402 #ifndef CK_CODE_GEN_RTC
1403  rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
1404 #else
1405  rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
1406 #endif // #ifndef CK_CODE_GEN_RTC
1407 #endif // #if defined(__gfx950__)
1408  }
1409  return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1410  f, rng);
1411 #else
1412 #if CK_USE_OCP_FP8
1413 __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
1414 {
1415 #else
1416 __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
1417 {
1418 #endif
1419  uint32_t rng = 0;
1420  if constexpr(stochastic_rounding)
1421  {
1422 #if defined(__gfx950__)
1423  // use the HW clock multiplied by the incremented thread id as the stochastic input
1424  rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1425  (get_thread_global_1d_id() + 1));
1426 #else
1427  constexpr int seed = 1254739;
1428 #ifndef CK_CODE_GEN_RTC
1429  rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
1430 #else
1431  rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
1432 #endif // #ifndef CK_CODE_GEN_RTC
1433 #endif // #if defined(__gfx950__)
1434  }
1435 
1436  if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
1437  {
1438  return cast_to_f8<float,
1439  3,
1440  4,
1441  true,
1442  sat == ck_saturation_t::CK_SATFINITE,
1443  stochastic_rounding>(f, rng);
1444  }
1445  else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_FNUZ)
1446  {
1447  return cast_to_f8<float,
1448  2,
1449  5,
1450  true,
1451  sat == ck_saturation_t::CK_SATFINITE,
1452  stochastic_rounding>(f, rng);
1453  }
1454  else if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
1455  {
1456  return cast_to_f8<float,
1457  3,
1458  4,
1459  false,
1460  sat == ck_saturation_t::CK_SATFINITE,
1461  stochastic_rounding>(f, rng);
1462  }
1463  else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
1464  {
1465  return cast_to_f8<float,
1466  2,
1467  5,
1468  false,
1469  sat == ck_saturation_t::CK_SATFINITE,
1470  stochastic_rounding>(f, rng);
1471  }
1472  else
1473  {
1474  __hip_assert(false && "FP8 type is not supported by current target device");
1475  return 0;
1476  }
1477 #endif // CK_FP8_CVT_FAST_PATH
1478 }
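
// Usage sketch (added for exposition): round-to-nearest-even, saturating
// conversion of a float to an OCP e4m3 byte:
//   fp8_storage_t b =
//       cvt_float_to_fp8<ck_fp8_interpretation_t::CK_E4M3_OCP,
//                        ck_saturation_t::CK_SATFINITE>(1.0f); // 0x38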
1479 
1489 template <ck_fp8_interpretation_t interp,
1490  ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
1491  bool stochastic_rounding = false>
1492 #if CK_FP8_CVT_FAST_PATH
1493 __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
1494 {
1495  __is_interpret_supported(interp);
1496  uint32_t rng = 0;
1497  if constexpr(stochastic_rounding)
1498  {
1499 #if defined(__gfx950__)
1500  // use the HW clock multiplied by the incremented thread id as the stochastic input
1501  rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1502  (get_thread_global_1d_id() + 1));
1503 #else
1504  constexpr int seed = 1254739;
1505 #ifndef CK_CODE_GEN_RTC
1506  rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
1507 #else
1508  rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f[0]);
1509 #endif // #ifndef CK_CODE_GEN_RTC
1510 #endif // #if defined(__gfx950__)
1511  }
1512  return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1513  f, rng);
1514 #else
1515 #if CK_USE_OCP_FP8
1516 __host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
1517 {
1518 #else
1519 __host__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
1520 {
1521 #endif // CK_USE_OCP_FP8
1522  return fp8x2_storage_t{cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[0]),
1523  cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[1])};
1524 #endif // CK_FP8_CVT_FAST_PATH
1525 }
1526 
1536 template <ck_fp8_interpretation_t interp,
1537  ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
1538  bool stochastic_rounding = false>
1539 #if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1540 __host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
1541 #else
1542 __host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
1543 #endif
1544 {
1545  {
1546  __is_interpret_supported(interp);
1547  uint32_t rng = 0;
1548  if constexpr(stochastic_rounding)
1549  {
1550 #if defined(__gfx950__)
1551  // use the HW clock multiplied by the incremented thread id as the stochastic input
1552  rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1553  (get_thread_global_1d_id() + 1));
1554 #else
1555  constexpr int seed = 1254739;
1556 #ifndef CK_CODE_GEN_RTC
1557  rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
1558 #else
1559  rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
1560 #endif // #ifndef CK_CODE_GEN_RTC
1561 #endif // #if defined(__gfx950__)
1562  }
1563 #if defined(__gfx950__)
1564  return cast_to_f8_from_f16<interp,
1565  sat == ck_saturation_t::CK_SATFINITE,
1566  stochastic_rounding>(x, rng);
1567 #else
1568  ignore = rng;
1569  return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1570  static_cast<float>(x));
1571 #endif // defined(__gfx950__)
1572  }
1573 }
1574 
1584 template <ck_fp8_interpretation_t interp,
1585  ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
1586  bool stochastic_rounding = false>
1587 #if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1588 __host__ __device__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
1589 #else
1590 __host__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
1591 #endif
1592 {
1593  {
1594  __is_interpret_supported(interp);
1595  uint32_t rng = 0;
1596  if constexpr(stochastic_rounding)
1597  {
1598 #if defined(__gfx950__)
1599  // use the HW clock multiplied by the incremented thread id as the stochastic input
1600  rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1601  (get_thread_global_1d_id() + 1));
1602 #else
1603  constexpr int seed = 1254739;
1604 #ifndef CK_CODE_GEN_RTC
1605  rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
1606 #else
1607  rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
1608 #endif // #ifndef CK_CODE_GEN_RTC
1609 #endif // #if defined(__gfx950__)
1610  }
1611 #if defined(__gfx950__)
1612  return cast_to_f8_from_f16<interp,
1613  sat == ck_saturation_t::CK_SATFINITE,
1614  stochastic_rounding>(x, rng);
1615 #else
1616  ignore = rng;
1617  return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1618  float2_t{static_cast<float>(x[0]), static_cast<float>(x[1])});
1619 #endif // defined(__gfx950__)
1620  }
1621 }
1622 
1632 template <ck_fp8_interpretation_t interp,
1633  ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
1634  bool stochastic_rounding = false>
1635 #if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1636 __host__ __device__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
1637 #else
1638 __host__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
1639 #endif
1640 {
1641  {
1642  __is_interpret_supported(interp);
1643  uint32_t rng = 0;
1644  if constexpr(stochastic_rounding)
1645  {
1646 #if defined(__gfx950__)
1647  // use the HW clock multiplied by the incremented thread id as the stochastic input
1648  rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1649  (get_thread_global_1d_id() + 1));
1650 #else
1651  constexpr int seed = 1254739;
1652 #ifndef CK_CODE_GEN_RTC
1653  rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
1654  static_cast<float>(x));
1655 #else
1656  rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), static_cast<float>(x));
1657 #endif // #ifndef CK_CODE_GEN_RTC
1658 #endif // #if defined(__gfx950__)
1659  }
1660 #if defined(__gfx950__)
1661  return cast_to_f8_from_bf16<interp,
1662  sat == ck_saturation_t::CK_SATFINITE,
1663  stochastic_rounding>(x, rng);
1664 #else
1665  ignore = rng;
1666  return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1667  bit_cast<float>(uint32_t{x} << 16)); // convert value to float
1668 #endif // defined(__gfx950__)
1669  }
1670 }
1671 
1681 template <ck_fp8_interpretation_t interp,
1682  ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
1683  bool stochastic_rounding = false>
1684 #if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1685 __host__ __device__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
1686 #else
1687 __host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
1688 #endif
1689 {
1690 #if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
1691  return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1692  float2_t{bit_cast<float>(uint32_t{x[0]} << 16),
1693  bit_cast<float>(uint32_t{x[1]} << 16)}); // convert values to float
1694 #else // CK_WORKAROUND_BF16_TO_FP8_CONVERSION
1695  {
1696  __is_interpret_supported(interp);
1697  uint32_t rng = 0;
1698  if constexpr(stochastic_rounding)
1699  {
1700 #if defined(__gfx950__)
1701  // use the HW clock multiplied by the incremented thread id as the stochastic input
1702  rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1703  (get_thread_global_1d_id() + 1));
1704 #else
1705  constexpr int seed = 1254739;
1706 #ifndef CK_CODE_GEN_RTC
1707  rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
1708  static_cast<float>(x[0]));
1709 #else
1710  rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x),
1711  static_cast<float>(x[0]));
1712 #endif // #ifndef CK_CODE_GEN_RTC
1713 #endif // #if defined(__gfx950__)
1714  }
1715 #if defined(__gfx950__)
1716  return cast_to_f8_from_bf16<interp,
1717  sat == ck_saturation_t::CK_SATFINITE,
1718  stochastic_rounding>(x, rng);
1719 #else
1720  ignore = rng;
1721  return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1722  float2_t{bit_cast<float>(uint32_t{x[0]} << 16),
1723  bit_cast<float>(uint32_t{x[1]} << 16)}); // convert values to float
1724 #endif // defined(__gfx950__)
1725  }
1726 #endif // CK_WORKAROUND_BF16_TO_FP8_CONVERSION
1727 }
1728 
1729 } // namespace fp8_impl
1730 
1731 #if CK_USE_OCP_FP8
1732 using f8_t = f8_ocp_t;
1733 using bf8_t = bf8_ocp_t;
1734 #define CK_FP8_TYPE_FNUZ 0
1735 #define CK_FP8_TYPE_OCP 1
1736 #else
1737 using f8_t = f8_fnuz_t;
1738 using bf8_t = bf8_fnuz_t;
1739 #define CK_FP8_TYPE_FNUZ 1
1740 #define CK_FP8_TYPE_OCP 0
1741 #endif
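
// Example (added for exposition): generic code is written against f8_t/bf8_t and
// picks up the OCP or FNUZ flavor selected at build time; CK_FP8_TYPE_OCP and
// CK_FP8_TYPE_FNUZ allow compile-time branching, e.g.
//   #if CK_FP8_TYPE_OCP
//   static_assert(f8_ocp_t::we == 4 && f8_ocp_t::wm == 3, "f8_t is e4m3");
//   #endif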
1742 
1743 } // namespace ck