/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/e8m0.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/e8m0.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/e8m0.hpp Source File
e8m0.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/utility/type.hpp"
7 
8 namespace ck {
9 
25 {
26  using type = uint8_t;
28 
29  constexpr static type bias = 127;
30  constexpr static type nan_mask = 0xFF;
31 
32  __host__ __device__ constexpr e8m0_bexp_t() : data{type{}} {}
33  __host__ __device__ constexpr e8m0_bexp_t(type init) : data{init} {}
34  __host__ __device__ constexpr e8m0_bexp_t(int init) : data{static_cast<type>(init & nan_mask)}
35  {
36  }
37  __host__ __device__ explicit constexpr e8m0_bexp_t(float scale)
38  : data{static_cast<type>((bit_cast<uint32_t>(scale) & (nan_mask << 23)) >> 23)}
39  {
40  }
41 
42  __host__ __device__ explicit constexpr operator float() const
43  {
44  if(data == nan_mask || data == 0)
45  {
46  uint32_t bits = data << 1;
47  bits |= 1;
48  bits <<= 22;
49  return bit_cast<float>(bits);
50  }
51  else
52  {
53  uint32_t bits = data << 23;
54  return bit_cast<float>(bits);
55  }
56  }
57 
58  __host__ __device__ constexpr bool operator==(const e8m0_bexp_t& other) const
59  {
60  // strict IEEE compliance for NaN
61  return data == other.data && data != nan_mask;
62  }
63 
64  __host__ __device__ constexpr bool is_nan() const { return data == nan_mask; }
65 };
66 
67 namespace utils {
68 
69 template <typename T>
70 __host__ __device__ inline constexpr int32_t get_exponent_value(T x);
71 
72 template <>
73 __host__ __device__ inline constexpr int32_t get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
74 {
75  return x.data;
76 }
77 
78 } // namespace utils
79 
80 } // namespace ck
__host__ constexpr __device__ int32_t get_exponent_value< e8m0_bexp_t >(e8m0_bexp_t x)
Definition: e8m0.hpp:73
__host__ constexpr __device__ int32_t get_exponent_value(T x)
Definition: mxfp_utils.hpp:35
Definition: ck.hpp:267
__host__ constexpr __device__ Y bit_cast(const X &x)
Definition: type.hpp:306
unsigned int uint32_t
Definition: stdint.h:126
signed int int32_t
Definition: stdint.h:123
unsigned char uint8_t
Definition: stdint.h:124
Unsigned representation of a conventional biased Float32 exponent.
Definition: e8m0.hpp:25
__host__ constexpr __device__ bool is_nan() const
Definition: e8m0.hpp:64
constexpr static type bias
Definition: e8m0.hpp:29
type data
Definition: e8m0.hpp:27
uint8_t type
Definition: e8m0.hpp:26
__host__ constexpr __device__ e8m0_bexp_t(type init)
Definition: e8m0.hpp:33
__host__ constexpr __device__ e8m0_bexp_t(float scale)
Definition: e8m0.hpp:37
constexpr static type nan_mask
Definition: e8m0.hpp:30
__host__ constexpr __device__ e8m0_bexp_t()
Definition: e8m0.hpp:32
__host__ constexpr __device__ e8m0_bexp_t(int init)
Definition: e8m0.hpp:34
__host__ constexpr __device__ bool operator==(const e8m0_bexp_t &other) const
Definition: e8m0.hpp:58