/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/pk_int4.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/pk_int4.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/pk_int4.hpp Source File
pk_int4.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
11 #include <stdint.h>
12 #include <type_traits>
14 
15 #pragma once
16 
17 namespace ck_tile {
18 
19 // Packed 2xint4: two 4-bit values carried in a single int8_t.
20 struct pk_int4_t
21 {
22  using type = int8_t; // storage type; the `type data` member (listing line 23) is not reproduced in this listing
24  CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {} // default: value-initialises `data` (stores 0)
25  CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {} // non-explicit: stores the given packed byte unchanged
26 };
27 
28 // limits
29 template <class T>
30 struct numeric;
31 
32 template <>
34 { // presumably the body of the numeric<pk_int4_t> specialisation; its header (listing line 33) is not shown here
35  // minimum finite value, or minimum positive normalized value for float
36  CK_TILE_HOST_DEVICE static constexpr pk_int4_t min()
37  {
38  constexpr uint8_t val = 0b10001000; // both nibbles are 0b1000
39  return pk_int4_t(bit_cast<int8_t>(val));
40  }
41 
42  // minimum finite value
44  { // body of lowest() (signature on listing line 43, not shown); same bit pattern as min()
45  constexpr uint8_t val = 0b10001000; // both nibbles are 0b1000
46  return pk_int4_t(bit_cast<int8_t>(val));
47  }
48 
49  // maximum finite value
50  CK_TILE_HOST_DEVICE static constexpr pk_int4_t max()
51  {
52  constexpr uint8_t val = 0b01110111; // both nibbles are 0b0111
53  return pk_int4_t(bit_cast<int8_t>(val));
54  }
55 
56  // difference between 1.0 and next value representable by float
58  { // body of epsilon() (signature on listing line 57, not shown)
59  return 1; // not used
60  }
61 
63  { // body of round_error() (signature on listing line 62, not shown)
64  return 1; // not used
65  }
66 
67  // positive infinity value
69  { // body of infinity() (signature on listing line 68, not shown)
70  return 1; // not used
71  }
72 
73  // quiet NaN
75  { // body of quiet_NaN() (signature on listing line 74, not shown)
76  return 1; // not used
77  }
78 
79  // signaling NaN
81  { // body of signaling_NaN() (signature on listing line 80, not shown)
82  return 1; // not used
83  }
84 
85  // smallest positive subnormal value
87  { // body of denorm_min() (signature on listing line 86, not shown)
88  return 1; // not used
89  }
90 
91  CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; } // 0 converts through the non-explicit pk_int4_t(type) constructor
92 };
93 
94 template <>
96 { // body of a traits specialisation whose header (listing line 95) is not shown in this listing
97  static constexpr int PackedSize = 2; // presumably the number of int4 values held by one pk_int4_t — confirm against numeric.hpp
98 };
99 
100 using fp32x2_t = float __attribute__((ext_vector_type(2))); // 2-element float vector (ext_vector_type compiler extension)
101 using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); // 2-element _Float16 vector
102 using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); // 2-element bfloat16_t vector
103 using int8x2_t = int8_t __attribute__((ext_vector_type(2))); // 2-element int8_t vector
104 
106 { // body of pk_int4_t_to_fp32x2_t(const pk_int4_t& x) -> fp32x2_t; signature is on listing line 105 (not shown)
107  uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x); // reinterpret the packed value as an unsigned byte
108 
109  float x_l = ((x_u8 & 0x0f) >> 0) - 8.f; // low nibble 0..15 offset to -8..7
110  float x_h = ((x_u8 & 0xf0) >> 4) - 8.f; // high nibble 0..15 offset to -8..7
111 
112 #ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
113  fp32x2_t res = {x_h, x_l}; // shuffled layout: high nibble in element 0
114 #else // CK_TILE_USE_PK4_LAYOUT_SHUFFLE not defined
115  fp32x2_t res = {x_l, x_h}; // plain layout: low nibble in element 0
116 #endif
117  return res;
118 }
119 
121 { // body of pk_int4_t_to_fp32x2_t_signed_conversion(const pk_int4_t& x) -> fp32x2_t; signature is on listing line 120 (not shown)
122  uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x); // reinterpret the packed value as an unsigned byte
123 
124  float x_l = ((x_u8 & 0x0f) >> 0); // low nibble as 0..15
125  float x_h = ((x_u8 & 0xf0) >> 4); // high nibble as 0..15
126 
127  x_l = x_l > 7 ? x_l - 16 : x_l; // read as two's-complement int4: 8..15 become -8..-1
128  x_h = x_h > 7 ? x_h - 16 : x_h; // same sign extension, computed from the high nibble itself
129 
130 #ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
131  fp32x2_t res = {x_h, x_l}; // shuffled layout: high nibble in element 0
132 #else // CK_TILE_USE_PK4_LAYOUT_SHUFFLE not defined
133  fp32x2_t res = {x_l, x_h}; // plain layout: low nibble in element 0
134 #endif
135  return res;
136 }
137 
139 { // body of pk_int4_t_to_halfx2_t(const pk_int4_t& x) -> fp16x2_t; signature is on listing line 138 (not shown)
140  uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x); // reinterpret the packed value as an unsigned byte
141 #ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
142  uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4); // low nibble -> bits 16..19, high nibble -> bits 0..3
143 #else // CK_TILE_USE_PK4_LAYOUT_SHUFFLE not defined
144  uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf); // high nibble -> bits 16..19, low nibble -> bits 0..3
145 #endif
146  const int EX = 0x64006400; // 0x6400 in each 16-bit half; read as IEEE binary16 that is 1024.0
147  const int SUB = 0xE408E408; // 0xE408 in each 16-bit half; read as IEEE binary16 that is -1032.0 = -(1024 + 8)
148 
149  int lo = i4s | EX; // the nibble fills the lowest mantissa bits, so each half reads as 1024 + nibble
150 
151  return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB)); // (1024 + nibble) - 1032 = nibble - 8 per half
152 }
153 
155 { // body of pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x) -> bf16x2_t; signature is on listing line 154 (not shown)
156  uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x); // reinterpret the packed value as an unsigned byte
157 
158  float x_l = ((x_u8 & 0x0f) >> 0) - 8.f; // low nibble 0..15 offset to -8..7
159  float x_h = ((x_u8 & 0xf0) >> 4) - 8.f; // high nibble 0..15 offset to -8..7
160 
161 #ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
162  bf16x2_t res = {type_convert<bf16_t>(x_h), type_convert<bf16_t>(x_l)}; // shuffled layout: high nibble in element 0
163 #else // CK_TILE_USE_PK4_LAYOUT_SHUFFLE not defined
164  bf16x2_t res = {type_convert<bf16_t>(x_l), type_convert<bf16_t>(x_h)}; // plain layout: low nibble in element 0
165 #endif
166  return res;
167 }
168 
170 { // body of pk_int4_t_to_int8x2_t(const pk_int4_t& x) -> int8x2_t; signature is on listing line 169 (not shown)
171  uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x); // reinterpret the packed value as an unsigned byte
172 
173  int8_t x_l = (x_u8 & 0x0F); // low nibble, 0..15
174  int8_t x_h = (x_u8 & 0xF0) >> 4; // high nibble, 0..15
175 
176  if(x_l & 0x08) // bit 3 set: the nibble is negative when read as two's-complement int4
177  x_l |= 0xF0; // sign-extend into the upper four bits
178  if(x_h & 0x08) // same test for the high nibble
179  x_h |= 0xF0; // sign-extend into the upper four bits
180 
181 #ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
182  int8x2_t res = {x_h, x_l}; // shuffled layout: high nibble in element 0
183 #else
184  int8x2_t res = {x_l, x_h}; // plain layout: low nibble in element 0
185 #endif
186  return res;
187 }
188 
189 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
ushort bfloat16_t
Definition: bfloat16.hpp:111
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t &x)
Definition: pk_int4.hpp:105
bfloat16_t bf16x2_t
Definition: pk_fp4.hpp:24
int8_t int8_t
Definition: int8.hpp:20
float fp32x2_t
Definition: pk_fp4.hpp:22
_Float16 fp16x2_t
Definition: half.hpp:385
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_int4_t &x)
Definition: pk_int4.hpp:120
int8_t int8x2_t
Definition: pk_int4.hpp:103
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t &x, const fp16x2_t &y)
Definition: half.hpp:387
CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t &x)
Definition: pk_int4.hpp:169
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t &x)
Definition: pk_int4.hpp:138
CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t &x)
Definition: pk_int4.hpp:154
unsigned int uint32_t
Definition: stdint.h:126
unsigned char uint8_t
Definition: stdint.h:124
static constexpr CK_TILE_HOST_DEVICE pk_int4_t denorm_min()
Definition: pk_int4.hpp:86
static constexpr CK_TILE_HOST_DEVICE pk_int4_t zero()
Definition: pk_int4.hpp:91
static constexpr CK_TILE_HOST_DEVICE pk_int4_t quiet_NaN()
Definition: pk_int4.hpp:74
static constexpr CK_TILE_HOST_DEVICE pk_int4_t min()
Definition: pk_int4.hpp:36
static constexpr CK_TILE_HOST_DEVICE pk_int4_t infinity()
Definition: pk_int4.hpp:68
static constexpr CK_TILE_HOST_DEVICE pk_int4_t max()
Definition: pk_int4.hpp:50
static constexpr CK_TILE_HOST_DEVICE pk_int4_t epsilon()
Definition: pk_int4.hpp:57
static constexpr CK_TILE_HOST_DEVICE pk_int4_t lowest()
Definition: pk_int4.hpp:43
static constexpr CK_TILE_HOST_DEVICE pk_int4_t round_error()
Definition: pk_int4.hpp:62
static constexpr CK_TILE_HOST_DEVICE pk_int4_t signaling_NaN()
Definition: pk_int4.hpp:80
Definition: numeric.hpp:81
static constexpr int PackedSize
Definition: numeric.hpp:82
Definition: numeric.hpp:18
Definition: pk_int4.hpp:21
type data
Definition: pk_int4.hpp:23
constexpr CK_TILE_HOST_DEVICE pk_int4_t()
Definition: pk_int4.hpp:24
constexpr CK_TILE_HOST_DEVICE pk_int4_t(type init)
Definition: pk_int4.hpp:25
int8_t type
Definition: pk_int4.hpp:22