Composable Kernel: include/ck/utility/mxf8_utils.hpp Source File

mxf8_utils.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
#define CK_MX_FP8_CVT_FAST_PATH 1
#else
#define CK_MX_FP8_CVT_FAST_PATH 0
#endif

namespace ck {

namespace fp8_impl {
#if CK_MX_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v)
{
    union
    {
        unsigned int i32val;
        unsigned char i8val[4];
    } val;
    val.i8val[0] = v;

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
    }
}

template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32_from_f8_scaled(float scale, fp8x2_storage_t v)
{
    const auto i16val = bit_cast<uint16_t>(v);

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(i16val, scale, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(i16val, scale, 0);
    }
}
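
Both overloads above apply the scale during decode, returning scale * decode(v) in a single hardware instruction. A minimal device-side usage sketch (hypothetical function name, written as if inside namespace ck::fp8_impl on a gfx950 target):

// Hypothetical illustration, not part of this header: decode a packed pair of
// OCP e4m3 values that share one scale.
__device__ float2_t decode_e4m3_pair_example(fp8x2_storage_t packed, float scale)
{
    // Yields {scale * decode(packed[0]), scale * decode(packed[1])}.
    return cast_to_f32_from_f8_scaled<ck_fp8_interpretation_t::CK_E4M3_OCP>(scale, packed);
}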

template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8_storage_t cast_to_f8_from_f32_scaled(float v,
                                                           unsigned int rng = 0,
                                                           float scale = 1.0f)
{
    fp8_storage_t i8data;
    union
    {
        float fval;
        unsigned int i32val;
    } val;

    union
    {
        uint32_t ival;
        vector_type<int16_t, 2>::type v2i16;
        fp8_storage_t v4i8[4];
    } ret{};

    val.fval = v;

    if constexpr(stochastic_rounding)
    {
        ret.ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
                       ? __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, val.fval, rng, scale, 0)
                       : __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, val.fval, rng, scale, 0);

        i8data = ret.v4i8[0];
    }
    else
    {
        // RNE CVT
        // llvm.amdgcn.cvt.scalef32.pk.fp8.f32
        // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            // If fval / scale > max fp8, returns NaN
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
                                                                 val.fval,
                                                                 val.fval,
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }
        else
        {
            // If fval / scale > max bf8, returns Inf
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
                                                                 val.fval,
                                                                 val.fval,
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }

        i8data = ret.v4i8[0];
    }
    return i8data;
}

template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8x2_storage_t cast_to_f8_from_f32_scaled(float2_t v,
                                                             unsigned int rng = 0,
                                                             float scale = 1.0f)
{
    union
    {
        uint32_t ival;
        vector_type<int16_t, 2>::type v2i16;
        StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
    } ret{};

    if constexpr(stochastic_rounding)
    {
        fp8x2_storage_t f8x2;
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
            f8x2[0]  = ret.v2f8x2(Number<0>{})[0];
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
            f8x2[1]  = ret.v2f8x2(Number<0>{})[0];
        }
        else
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
            f8x2[0]  = ret.v2f8x2(Number<0>{})[0];
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
            f8x2[1]  = ret.v2f8x2(Number<0>{})[0];
        }
        return f8x2;
    }
    else
    {
        // RNE CVT
        // llvm.amdgcn.cvt.scalef32.pk.fp8.f32
        // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            // If fval / scale > max fp8, returns NaN
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
                                                                 v[0],
                                                                 v[1],
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }
        else
        {
            // If fval / scale > max bf8, returns Inf
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
                                                                 v[0],
                                                                 v[1],
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }

        return ret.v2f8x2(Number<0>{});
    }
}

#endif // CK_MX_FP8_CVT_FAST_PATH

#if CK_MX_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
    __is_interpret_supported(interp);
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        // seed the stochastic input from the HW clock multiplied by an incremented thread id
        rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                        (get_thread_global_1d_id() + 1));
    }
    return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
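
A usage sketch for the wrapper above (hypothetical function name, written as if inside namespace ck::fp8_impl): the stochastic_rounding flag selects the SR builtins, and the wrapper derives the random input from the hardware clock internally, so callers pass only the value and the scale.

// Hypothetical illustration, not part of this header: encode one float as OCP
// e5m2 (bf8) with stochastic rounding on the gfx950 fast path.
__device__ fp8_storage_t encode_e5m2_sr_example(float f, float scale)
{
    return cvt_float_to_fp8_scaled<ck_fp8_interpretation_t::CK_E5M2_OCP,
                                   /*stochastic_rounding=*/true>(f, scale);
}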

template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
                                                                          float scale)
{
    __is_interpret_supported(interp);
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        // seed the stochastic input from the HW clock multiplied by an incremented thread id
        rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
                                        (get_thread_global_1d_id() + 1));
    }
    return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}

#else

template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
    }

    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f / scale, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f / scale, rng);
    }
    else
    {
        __hip_assert(false && "FP8 type is not supported by current target device");
        return 0;
    }
}
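
To make the scale semantics concrete: both paths encode round(f / scale), so a power-of-two scale repositions the representable range without introducing extra rounding error of its own. A small worked sketch (448 is the OCP e4m3 maximum finite magnitude, a standard constant rather than something defined in this file; the saturation remark assumes the `true` template argument of cast_to_f8 enables clipping):

// With scale = 16, inputs with magnitude up to 448 * 16 = 7168 land inside the
// e4m3 range after division by the scale; larger magnitudes overflow, which
// saturates on this fallback path and produces NaN on the gfx950 fast path.
constexpr float e4m3_max = 448.0f;
constexpr float mx_scale = 16.0f;
static_assert(7168.0f / mx_scale == e4m3_max, "largest exactly-covered input");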

template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
                                                                          float scale)
{
    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        // note: the same rng value is reused for both lanes below
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
    }

    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[0] / scale, rng),
                cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[1] / scale, rng)};
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[0] / scale, rng),
                cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[1] / scale, rng)};
    }
    else
    {
        __hip_assert(false && "FP8 type is not supported by current target device");
        return 0;
    }
}

#endif // CK_MX_FP8_CVT_FAST_PATH

} // namespace fp8_impl

// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_sr(X x, float scale);

// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_rne(X x, float scale);

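The remainder of the file specializes these two templates for f8/bf8 results at vector widths 1, 2, 16, and 32. A minimal usage sketch of the public API (hypothetical function name, written as if inside namespace ck; it relies on the scalar specializations defined below):

// Hypothetical illustration, not part of this header: quantize one float with
// round-to-nearest-even, and another with stochastic rounding.
__host__ __device__ inline void convert_example(float x, float scale)
{
    f8_ocp_t  y_rne = mxf8_convert_rne<f8_ocp_t>(x, scale);
    bf8_ocp_t y_sr  = mxf8_convert_sr<bf8_ocp_t>(x, scale);
    (void)y_rne;
    (void)y_sr;
}
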
// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_rne<f8_ocp_t, float>(float x, float scale)
{
    return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_rne<bf8_ocp_t, float>(float x, float scale)
{
    return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32x2 to fp8x2 with rounding to nearest even
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_rne<f8x2_ocp_t, float2_t>(float2_t x,
                                                                             float scale)
{
    return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32x2 to bf8x2 with rounding to nearest even
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_rne<bf8x2_ocp_t, float2_t>(float2_t x,
                                                                               float scale)
{
    return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32x16 to fp8x16 with rounding to nearest even
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_rne<f8x16_ocp_t, float16_t>(float16_t x,
                                                                                float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};

    union
    {
        f8x16_ocp_t fp8_1x16;
        f8x2_ocp_t fp8_2x8[8];
    } out{};

    static_for<0, 8, 1>{}(
        [&](auto i) { out.fp8_2x8[i] = mxf8_convert_rne<f8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.fp8_1x16;
}

// convert fp32x16 to bf8x16 with rounding to nearest even
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_rne<bf8x16_ocp_t, float16_t>(float16_t x,
                                                                                  float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};

    union
    {
        bf8x16_ocp_t bf8_1x16;
        bf8x2_ocp_t bf8_2x8[8];
    } out{};

    static_for<0, 8, 1>{}(
        [&](auto i) { out.bf8_2x8[i] = mxf8_convert_rne<bf8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.bf8_1x16;
}

// convert fp32x32 to fp8x32 with rounding to nearest even
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_rne<f8x32_ocp_t, float32_t>(float32_t x,
                                                                                float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};

    union
    {
        f8x32_ocp_t fp8_1x32;
        f8x16_ocp_t fp8_16x2[2];
    } out{};

    static_for<0, 2, 1>{}(
        [&](auto i) { out.fp8_16x2[i] = mxf8_convert_rne<f8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.fp8_1x32;
}

// convert fp32x32 to bf8x32 with rounding to nearest even
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_rne<bf8x32_ocp_t, float32_t>(float32_t x,
                                                                                  float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};

    union
    {
        bf8x32_ocp_t bf8_1x32;
        bf8x16_ocp_t bf8_16x2[2];
    } out{};

    static_for<0, 2, 1>{}(
        [&](auto i) { out.bf8_16x2[i] = mxf8_convert_rne<bf8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.bf8_1x32;
}
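
All x16 and x32 specializations, here and in the stochastic-rounding group below, follow one pattern: view the wide vector as an array of narrower chunks through a union, convert chunk by chunk with static_for, and read the result back through the output union. Abstracted into a hypothetical helper (not part of this header):

// Hypothetical sketch of the chunked-conversion pattern used by the wide
// specializations: WideIn/WideOut are viewed as N chunks of ChunkIn/ChunkOut.
template <typename WideOut, typename ChunkOut, typename WideIn, typename ChunkIn, int N>
__host__ __device__ WideOut convert_in_chunks(WideIn x, float scale)
{
    union
    {
        WideIn wide;
        ChunkIn chunk[N];
    } in{x};

    union
    {
        WideOut wide;
        ChunkOut chunk[N];
    } out{};

    static_for<0, N, 1>{}(
        [&](auto i) { out.chunk[i] = mxf8_convert_rne<ChunkOut>(in.chunk[i], scale); });

    return out.wide;
}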

// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_sr<f8_ocp_t, float>(float x, float scale)
{
    return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_sr<bf8_ocp_t, float>(float x, float scale)
{
    return bf8_ocp_t{
        fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32x2 to fp8x2 with stochastic rounding
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_sr<f8x2_ocp_t, float2_t>(float2_t x, float scale)
{
    return f8x2_ocp_t{
        fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32x2 to bf8x2 with stochastic rounding
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_sr<bf8x2_ocp_t, float2_t>(float2_t x,
                                                                              float scale)
{
    return bf8x2_ocp_t{
        fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32x16 to fp8x16 with stochastic rounding
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_sr<f8x16_ocp_t, float16_t>(float16_t x,
                                                                               float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};

    union
    {
        f8x16_ocp_t fp8_1x16;
        f8x2_ocp_t fp8_2x8[8];
    } out{};

    static_for<0, 8, 1>{}(
        [&](auto i) { out.fp8_2x8[i] = mxf8_convert_sr<f8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.fp8_1x16;
}

// convert fp32x16 to bf8x16 with stochastic rounding
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_sr<bf8x16_ocp_t, float16_t>(float16_t x,
                                                                                 float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};

    union
    {
        bf8x16_ocp_t bf8_1x16;
        bf8x2_ocp_t bf8_2x8[8];
    } out{};

    static_for<0, 8, 1>{}(
        [&](auto i) { out.bf8_2x8[i] = mxf8_convert_sr<bf8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.bf8_1x16;
}

// convert fp32x32 to fp8x32 with stochastic rounding
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_sr<f8x32_ocp_t, float32_t>(float32_t x,
                                                                               float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};

    union
    {
        f8x32_ocp_t fp8_1x32;
        f8x16_ocp_t fp8_16x2[2];
    } out{};

    static_for<0, 2, 1>{}(
        [&](auto i) { out.fp8_16x2[i] = mxf8_convert_sr<f8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.fp8_1x32;
}

// convert fp32x32 to bf8x32 with stochastic rounding
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_sr<bf8x32_ocp_t, float32_t>(float32_t x,
                                                                                 float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};

    union
    {
        bf8x32_ocp_t bf8_1x32;
        bf8x16_ocp_t bf8_16x2[2];
    } out{};

    static_for<0, 2, 1>{}(
        [&](auto i) { out.bf8_16x2[i] = mxf8_convert_sr<bf8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.bf8_1x32;
}

} // namespace ck
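
End-to-end, a kernel typically chooses one scale per 32-element block (the MX format block size) and converts the whole block at once. A hypothetical device-side sketch using the widest specialization:

// Hypothetical illustration, not part of this header: quantize a 32-wide f32
// register tile to OCP fp8 with a single shared scale.
__device__ ck::f8x32_ocp_t quantize_mx_block(ck::float32_t tile, float scale)
{
    return ck::mxf8_convert_rne<ck::f8x32_ocp_t>(tile, scale);
}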