include/ck/utility/mxf8_utils.hpp Source File

include/ck/utility/mxf8_utils.hpp Source File#

Composable Kernel: include/ck/utility/mxf8_utils.hpp Source File
mxf8_utils.hpp
Go to the documentation of this file.
3 
// Use the gfx950 hardware scaled-conversion instructions only when compiling
// device code for that target; all other targets use the emulated path below.
#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
#define CK_MX_FP8_CVT_FAST_PATH 1
#else
#define CK_MX_FP8_CVT_FAST_PATH 0
#endif

namespace ck {

namespace fp8_impl {
#if CK_MX_FP8_CVT_FAST_PATH
/// Convert one FP8/BF8 byte to float, applying @p scale via the gfx950
/// scalef32 hardware converter.
///
/// @tparam interpret FP8 encoding; only OCP E4M3 / E5M2 are supported.
/// @param scale multiplicative scale applied by the hardware instruction
/// @param v     the 8-bit source value
/// @return the scaled float result
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v)
{
    union
    {
        unsigned int i32val;
        unsigned char i8val[4];
    } val;
    // The builtin reads the source byte from lane 0 (last operand selects it).
    val.i8val[0] = v;

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
    }
}
37 
/// Convert two packed FP8/BF8 bytes to two floats, applying @p scale via the
/// gfx950 packed scalef32 hardware converter.
///
/// @tparam interpret FP8 encoding; only OCP E4M3 / E5M2 are supported.
/// @param scale multiplicative scale applied by the hardware instruction
/// @param v     two packed 8-bit source values
/// @return the two scaled float results
template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v)
{
    const auto i16val = bit_cast<uint16_t>(v);

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(i16val, scale, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(i16val, scale, 0);
    }
}
56 
/// Convert one float to an FP8/BF8 byte with hardware scaling.
///
/// @tparam interpret           OCP FP8 encoding (E4M3 or E5M2)
/// @tparam stochastic_rounding use SR instead of round-to-nearest-even
/// @param v     source float
/// @param rng   random bits consumed only on the SR path
/// @param scale scale applied by the hardware converter
/// @return the converted 8-bit value
template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8_storage_t cast_to_f8_from_f32_scaled(float v,
                                                           unsigned int rng = 0,
                                                           float scale = 1.0f)
{
    fp8_storage_t i8data;
    union
    {
        float fval;
        unsigned int i32val;
    } val;

    union
    {
        uint32_t ival;
        vector_type<int16_t, 2>::type v2i16;
        fp8_storage_t v4i8[4];
    } ret{};

    val.fval = v;

    if constexpr(stochastic_rounding)
    {
        // NOTE(review): the branch selector was lost in extraction; E4M3 maps
        // to the fp8 builtin and E5M2 to bf8, matching every other dispatch in
        // this file — confirm against the upstream header.
        ret.ival =
            (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
                ? __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, val.fval, rng, scale, 0)
                : __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, val.fval, rng, scale, 0);

        i8data = ret.v4i8[0];
    }
    else
    {
        // RNE CVT
        // llvm.amdgcn.cvt.scalef32.pk.fp8.f32
        // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            // If fval / scale > max fp8, returns Nan
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
                                                                 val.fval,
                                                                 val.fval,
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }
        else
        {
            // If fval / scale > max bf8, returns Inf
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
                                                                 val.fval,
                                                                 val.fval,
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }

        i8data = ret.v4i8[0];
    }
    return i8data;
}
116 
/// Convert a pair of floats to two packed FP8/BF8 bytes with hardware scaling.
///
/// @tparam interpret           OCP FP8 encoding (E4M3 or E5M2)
/// @tparam stochastic_rounding use SR instead of round-to-nearest-even
/// @param v     two source floats
/// @param rng   random bits consumed only on the SR path
/// @param scale scale applied by the hardware converter
/// @return two packed converted 8-bit values
template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8x2_storage_t cast_to_f8_from_f32_scaled(float2_t v,
                                                             unsigned int rng = 0,
                                                             float scale = 1.0f)
{

    union
    {
        uint32_t ival;
        vector_type<int16_t, 2>::type v2i16;
        StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
    } ret{};

    if constexpr(stochastic_rounding)
    {
        // The SR builtin converts one float at a time into byte 0 of the
        // destination register, so each lane is converted separately and the
        // low byte is extracted after every call.
        fp8x2_storage_t f8x2;
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
            f8x2[0]  = ret.v2f8x2(Number<0>{})[0];
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
            f8x2[1]  = ret.v2f8x2(Number<0>{})[0];
        }
        else
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
            f8x2[0]  = ret.v2f8x2(Number<0>{})[0];
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
            f8x2[1]  = ret.v2f8x2(Number<0>{})[0];
        }
        return f8x2;
    }
    else
    {
        // RNE CVT
        // llvm.amdgcn.cvt.scalef32.pk.fp8.f32
        // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            // If fval / scale > max fp8, returns Nan
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
                                                                 v[0],
                                                                 v[1],
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }
        else
        {
            // If fval / scale > max bf8, returns Inf
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
                                                                 v[0],
                                                                 v[1],
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }

        return ret.v2f8x2(Number<0>{});
    }
}

#endif // CK_MX_FP8_CVT_FAST_PATH
178 
#if CK_MX_FP8_CVT_FAST_PATH
/// Scale-and-convert one float to FP8/BF8 on the gfx950 fast path.
///
/// @tparam interp              OCP FP8 encoding (validated by
///                             __is_interpret_supported)
/// @tparam stochastic_rounding generate per-value random bits and use SR
/// @param f     source float
/// @param scale scale applied by the hardware converter
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
    __is_interpret_supported(interp);
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        // Seed the PRNG with the value's address and payload so repeated
        // conversions of different values decorrelate.
        constexpr int seed = 1254739;
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
    }
    return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
202 
/// Scale-and-convert two floats to packed FP8/BF8 on the gfx950 fast path.
///
/// @tparam interp              OCP FP8 encoding (validated by
///                             __is_interpret_supported)
/// @tparam stochastic_rounding generate random bits and use SR
/// @param f     two source floats
/// @param scale scale applied by the hardware converter
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
                                                                          float scale)
{
    __is_interpret_supported(interp);
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        // NOTE(review): one rng value is derived from lane 0 and reused for
        // both lanes — confirm this matches the intended SR behavior.
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
    }
    return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
226 
#else

/// Emulated fallback: divide by @p scale in fp32, then run the software
/// cast_to_f8 conversion (E4M3 uses 3 mantissa / 4 exponent bits, E5M2 uses
/// 2 / 5).
///
/// @tparam interp              OCP FP8 encoding (E4M3 or E5M2)
/// @tparam stochastic_rounding generate random bits and use SR
/// @param f     source float
/// @param scale divisor applied before conversion
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{

    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
    }

    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f / scale, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f / scale, rng);
    }
    else
    {
        // Unreachable: the static_assert above restricts interp to OCP types.
        __hip_assert(false && "FP8 type is not supported by current target device");
        return 0;
    }
}
268 
/// Emulated fallback for a pair of floats: each lane is divided by @p scale
/// and converted independently with the software cast_to_f8.
///
/// @tparam interp              OCP FP8 encoding (E4M3 or E5M2)
/// @tparam stochastic_rounding generate random bits and use SR
/// @param f     two source floats
/// @param scale divisor applied before conversion
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
                                                                          float scale)
{

    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        // NOTE(review): one rng value derived from lane 0 is reused for both
        // lanes — confirm this matches the intended SR behavior.
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
    }

    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[0] / scale, rng),
                cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[1] / scale, rng)};
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[0] / scale, rng),
                cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[1] / scale, rng)};
    }
    else
    {
        // Unreachable: the static_assert above restricts interp to OCP types.
        __hip_assert(false && "FP8 type is not supported by current target device");
        return 0;
    }
}

#endif // CK_MX_FP8_CVT_FAST_PATH

} // namespace fp8_impl
315 
// Declare a template function for fp8 conversion using stochastic rounding (SR)
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_sr(X x, float scale);

// Declare a template function for fp8 conversion using round-to-nearest-even (RNE)
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_rne(X x, float scale);
323 
// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_rne<f8_ocp_t, float>(float x, float scale)
{
    return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}
330 
// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_rne<bf8_ocp_t, float>(float x, float scale)
{
    return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}
337 
// convert fp32x2 to fp8x2 with rounding to nearest even
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_rne<f8x2_ocp_t, float2_t>(float2_t x,
                                                                             float scale)
{
    return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}
345 
// convert fp32x2 to bf8x2 with rounding to nearest even
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_rne<bf8x2_ocp_t, float2_t>(float2_t x,
                                                                               float scale)
{
    return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}
353 
// convert fp32x16 to fp8x16 with rounding to nearest even
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_rne<f8x16_ocp_t, float16_t>(float16_t x,
                                                                                float scale)
{
    // View the 16-wide float vector as 8 float2 chunks so each chunk can use
    // the packed 2-element conversion.
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};

    union
    {
        f8x16_ocp_t fp8_1x16;
        f8x2_ocp_t fp8_2x8[8];
    } out{};

    static_for<0, 8, 1>{}(
        [&](auto i) { out.fp8_2x8[i] = mxf8_convert_rne<f8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.fp8_1x16;
}
376 
// convert fp32x16 to bf8x16 with rounding to nearest even
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_rne<bf8x16_ocp_t, float16_t>(float16_t x,
                                                                                  float scale)
{
    // View the 16-wide float vector as 8 float2 chunks for packed conversion.
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};

    union
    {
        bf8x16_ocp_t bf8_1x16;
        bf8x2_ocp_t bf8_2x8[8];
    } out{};

    static_for<0, 8, 1>{}(
        [&](auto i) { out.bf8_2x8[i] = mxf8_convert_rne<bf8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.bf8_1x16;
}
399 
// convert fp32x32 to fp8x32 with rounding to nearest even
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_rne<f8x32_ocp_t, float32_t>(float32_t x,
                                                                                float scale)
{
    // Split the 32-wide vector into two 16-wide halves and reuse the x16
    // conversion for each half.
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};

    union
    {
        f8x32_ocp_t fp8_1x32;
        f8x16_ocp_t fp8_16x2[2];
    } out{};

    static_for<0, 2, 1>{}(
        [&](auto i) { out.fp8_16x2[i] = mxf8_convert_rne<f8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.fp8_1x32;
}
422 
// convert fp32x32 to bf8x32 with rounding to nearest even
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_rne<bf8x32_ocp_t, float32_t>(float32_t x,
                                                                                  float scale)
{
    // Split the 32-wide vector into two 16-wide halves and reuse the x16
    // conversion for each half.
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};

    union
    {
        bf8x32_ocp_t bf8_1x32;
        bf8x16_ocp_t bf8_16x2[2];
    } out{};

    static_for<0, 2, 1>{}(
        [&](auto i) { out.bf8_16x2[i] = mxf8_convert_rne<bf8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.bf8_1x32;
}
445 
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_sr<f8_ocp_t, float>(float x, float scale)
{
    return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}
452 
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_sr<bf8_ocp_t, float>(float x, float scale)
{
    return bf8_ocp_t{
        fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}
460 
// convert fp32x2 to fp8x2 with stochastic rounding
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_sr<f8x2_ocp_t, float2_t>(float2_t x, float scale)
{
    return f8x2_ocp_t{
        fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}
468 
// convert fp32x2 to bf8x2 with stochastic rounding
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_sr<bf8x2_ocp_t, float2_t>(float2_t x,
                                                                              float scale)
{
    return bf8x2_ocp_t{
        fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}
477 
// convert fp32x16 to fp8x16 with stochastic rounding
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_sr<f8x16_ocp_t, float16_t>(float16_t x,
                                                                               float scale)
{
    // View the 16-wide float vector as 8 float2 chunks for packed conversion.
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};

    union
    {
        f8x16_ocp_t fp8_1x16;
        f8x2_ocp_t fp8_2x8[8];
    } out{};

    static_for<0, 8, 1>{}(
        [&](auto i) { out.fp8_2x8[i] = mxf8_convert_sr<f8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.fp8_1x16;
}
500 
// convert fp32x16 to bf8x16 with stochastic rounding
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_sr<bf8x16_ocp_t, float16_t>(float16_t x,
                                                                                 float scale)
{
    // View the 16-wide float vector as 8 float2 chunks for packed conversion.
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};

    union
    {
        bf8x16_ocp_t bf8_1x16;
        bf8x2_ocp_t bf8_2x8[8];
    } out{};

    static_for<0, 8, 1>{}(
        [&](auto i) { out.bf8_2x8[i] = mxf8_convert_sr<bf8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.bf8_1x16;
}
523 
// convert fp32x32 to fp8x32 with stochastic rounding
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_sr<f8x32_ocp_t, float32_t>(float32_t x,
                                                                               float scale)
{
    // Split the 32-wide vector into two 16-wide halves and reuse the x16
    // conversion for each half.
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};

    union
    {
        f8x32_ocp_t fp8_1x32;
        f8x16_ocp_t fp8_16x2[2];
    } out{};

    static_for<0, 2, 1>{}(
        [&](auto i) { out.fp8_16x2[i] = mxf8_convert_sr<f8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.fp8_1x32;
}
546 
// convert fp32x32 to bf8x32 with stochastic rounding
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_sr<bf8x32_ocp_t, float32_t>(float32_t x,
                                                                                 float scale)
{
    // Split the 32-wide vector into two 16-wide halves and reuse the x16
    // conversion for each half.
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};

    union
    {
        bf8x32_ocp_t bf8_1x32;
        bf8x16_ocp_t bf8_16x2[2];
    } out{};

    static_for<0, 2, 1>{}(
        [&](auto i) { out.bf8_16x2[i] = mxf8_convert_sr<bf8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.bf8_1x32;
}

} // namespace ck
float float2_t
Definition: amd_ck_fp8.hpp:67
fp8_storage_t fp8x2_storage_t
Definition: amd_ck_fp8.hpp:66
Definition: ck.hpp:264
__host__ __device__ f8x16_ocp_t mxf8_convert_sr< f8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:480
__host__ __device__ f8x2_ocp_t mxf8_convert_rne< f8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:340
typename vector_type< float, 16 >::type float16_t
Definition: data_type.hpp:2484
__host__ constexpr __device__ Y mxf8_convert_rne(X x, float scale)
__host__ __device__ f8_ocp_t mxf8_convert_rne< f8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:326
typename vector_type< bf8_ocp_t, 32 >::type bf8x32_ocp_t
Definition: data_type.hpp:2549
__host__ __device__ bf8_ocp_t mxf8_convert_sr< bf8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:455
__host__ __device__ f8_ocp_t mxf8_convert_sr< f8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:448
__host__ __device__ bf8x32_ocp_t mxf8_convert_rne< bf8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:425
typename vector_type< float, 2 >::type float2_t
Definition: data_type.hpp:2481
typename vector_type< f8_ocp_t, 2 >::type f8x2_ocp_t
Definition: data_type.hpp:2537
__host__ __device__ bf8_ocp_t mxf8_convert_rne< bf8_ocp_t, float >(float x, float scale)
Definition: mxf8_utils.hpp:333
typename vector_type< bf8_ocp_t, 2 >::type bf8x2_ocp_t
Definition: data_type.hpp:2545
__host__ constexpr __device__ Y mxf8_convert_sr(X x, float scale)
__host__ __device__ bf8x16_ocp_t mxf8_convert_sr< bf8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:503
typename vector_type< f8_ocp_t, 32 >::type f8x32_ocp_t
Definition: data_type.hpp:2541
__host__ __device__ bf8x16_ocp_t mxf8_convert_rne< bf8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:379
typename vector_type< f8_ocp_t, 16 >::type f8x16_ocp_t
Definition: data_type.hpp:2540
__host__ __device__ f8x16_ocp_t mxf8_convert_rne< f8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition: mxf8_utils.hpp:356
__host__ __device__ bf8x32_ocp_t mxf8_convert_sr< bf8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:549
__host__ __device__ f8x32_ocp_t mxf8_convert_sr< f8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:526
__host__ __device__ bf8x2_ocp_t mxf8_convert_sr< bf8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:471
typename vector_type< bf8_ocp_t, 16 >::type bf8x16_ocp_t
Definition: data_type.hpp:2548
__host__ __device__ f8x32_ocp_t mxf8_convert_rne< f8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition: mxf8_utils.hpp:402
typename vector_type< float, 32 >::type float32_t
Definition: data_type.hpp:2485
__host__ __device__ bf8x2_ocp_t mxf8_convert_rne< bf8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:348
__host__ __device__ f8x2_ocp_t mxf8_convert_sr< f8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition: mxf8_utils.hpp:463
unsigned char fp8_storage_t
Definition: amd_ck_fp8.hpp:42
Definition: amd_ck_fp8.hpp:344
Definition: amd_ck_fp8.hpp:298
Definition: functional2.hpp:31