/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/pk_fp4.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/pk_fp4.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/pk_fp4.hpp Source File
pk_fp4.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <cmath>
10 
// The hardware FP4 conversion builtins (__builtin_amdgcn_cvt_scalef32_*) are used only when
// compiling for gfx950; every other target takes the software conversion paths below.
#ifndef __gfx950__
#define CK_TILE_FP4_CVT_DEVICE 0
#else
#define CK_TILE_FP4_CVT_DEVICE 1
#endif

// When nonzero, pk_fp4_t decoding uses the e2m1 lookup tables instead of the regular conversions.
#define TEST_convert_with_table 0
18 
19 namespace ck_tile {
20 
21 using fp32_t = float;
22 using fp32x2_t = float __attribute__((ext_vector_type(2)));
23 using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
24 using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
25 
26 struct pk_float4_e2m1_t;
27 CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t float_to_pk_fp4(const float& x, float scale = 1.f);
28 
// pk_float4_e2m1_t: two E2M1 (FP4) codes packed into one byte. Element 0 occupies bits [3:0]
// and element 1 occupies bits [7:4] (see _pack and pk_fp4_t::_unpack).
29 // TODO: Add stochastic method
31 {
32  // TODO: Can we merge raw_type and type?
33  using raw_type = uint8_t;
34  using type = raw_type;
36 
// Implicit construction from any integral value stores it as the raw packed byte (no FP conversion).
38  template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
39  CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast<type>(init)}
40  {
41  }
// Encodes `init` (with `scale`) through float_to_pk_fp4; both nibbles receive the same code.
42  CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init, float scale = 1.f)
43  : data{float_to_pk_fp4(init, scale)}
44  {
45  }
// Raw access to the packed byte.
46  CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
47  CK_TILE_HOST_DEVICE constexpr type& get() { return data; }
48  CK_TILE_HOST_DEVICE constexpr type get() const { return data; }
49 
// Decoders. Scalar variants decode element 0 (low nibble) only; x2 variants decode both elements.
50  CK_TILE_HOST_DEVICE constexpr float to_float(float scale = 1.f) const;
51  CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale = 1.f) const;
52  CK_TILE_HOST_DEVICE constexpr fp16_t to_fp16(float scale = 1.f) const;
53  CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const;
54  CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const;
55  CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const;
56 
// Implicit conversions decode with the default scale of 1.
57  CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); }
58  CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
59  CK_TILE_HOST_DEVICE constexpr operator fp16_t() const { return to_fp16(); }
60  CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); }
61  CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); }
62  CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); }
63 
// unpack<I>() returns nibble I wrapped back into a pk_float4_e2m1_t; pack() combines the low
// nibbles of two pk_float4_e2m1_t values (x0 -> bits [3:0], x1 -> bits [7:4]).
64  template <index_t I>
66  {
67  return _unpack(number<I>{});
68  }
70  const pk_float4_e2m1_t& x1)
71  {
72  return _pack(x0.get(), x1.get());
73  }
74 
// _pack: x0's low nibble goes to bits [3:0], x1 is shifted into bits [7:4]. Any bits of x1 above
// its low nibble are discarded when the int result is narrowed back to `type`.
75  template <index_t I>
77  CK_TILE_HOST_DEVICE constexpr static type _pack(const type x0, const type x1)
78  {
79  return (x1 << 4) | (x0 & 0b00001111);
80  }
81 
// Decode tables for the TEST_convert_with_table path, indexed by a 4-bit code.
// Codes 8-15 are the negations of codes 0-7 (bit 3 is the sign bit).
82 #if TEST_convert_with_table
83  static constexpr float e2m1_to_fp32_table[16] = {
84  0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6};
// Same values as above, spelled as IEEE half-precision bit patterns.
85  static constexpr fp16_t e2m1_to_fp16_table[16] = {
86  bit_cast<fp16_t>(static_cast<uint16_t>(0x0000)), // 0
87  bit_cast<fp16_t>(static_cast<uint16_t>(0x3800)), // 0.5
88  bit_cast<fp16_t>(static_cast<uint16_t>(0x3C00)), // 1
89  bit_cast<fp16_t>(static_cast<uint16_t>(0x3E00)), // 1.5
90  bit_cast<fp16_t>(static_cast<uint16_t>(0x4000)), // 2
91  bit_cast<fp16_t>(static_cast<uint16_t>(0x4200)), // 3
92  bit_cast<fp16_t>(static_cast<uint16_t>(0x4400)), // 4
93  bit_cast<fp16_t>(static_cast<uint16_t>(0x4600)), // 6
94  bit_cast<fp16_t>(static_cast<uint16_t>(0x8000)), // -0
95  bit_cast<fp16_t>(static_cast<uint16_t>(0xB800)), // -0.5
96  bit_cast<fp16_t>(static_cast<uint16_t>(0xBC00)), // -1
97  bit_cast<fp16_t>(static_cast<uint16_t>(0xBE00)), // -1.5
98  bit_cast<fp16_t>(static_cast<uint16_t>(0xC000)), // -2
99  bit_cast<fp16_t>(static_cast<uint16_t>(0xC200)), // -3
100  bit_cast<fp16_t>(static_cast<uint16_t>(0xC400)), // -4
101  bit_cast<fp16_t>(static_cast<uint16_t>(0xC600)) // -6
102  };
103 #endif
104 };
105 
// Raw one-byte storage type of pk_fp4_t (holds two packed FP4 codes).
107 using pk_fp4_raw_t = typename pk_fp4_t::type;
108 
// E2M1 format parameters: 2 exponent bits, 1 mantissa bit, exponent bias 1,
// and two FP4 values packed into each storage element.
109 template <>
111 {
113 
114  static constexpr int exp = 2;
115  static constexpr int mant = 1;
116  static constexpr int bias = 1;
117  static constexpr int PackedSize = 2;
118 };
119 
120 // limits
121 template <class T>
122 struct numeric;
123 
// numeric_limits-style constants for pk_fp4_t. Each binary_* constant repeats the same 4-bit
// code in both nibbles, so either packed element decodes to the value noted in its comment.
124 template <>
126 {
127  static constexpr pk_fp4_raw_t binary_min_normal = 0b00100010; // 1
128  static constexpr pk_fp4_raw_t binary_max_normal = 0b01110111; // 6
129  static constexpr pk_fp4_raw_t binary_lowest_normal = 0b11111111; // -6
130  static constexpr pk_fp4_raw_t binary_min_subnorm = 0b00010001; // 0.5
131  static constexpr pk_fp4_raw_t binary_max_subnorm = 0b00010001; // 0.5
132  static constexpr pk_fp4_raw_t binary_zero = 0b00000000; // 0
133  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t min() { return binary_min_normal; }
134  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t max() { return binary_max_normal; }
135  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t lowest() { return binary_lowest_normal; }
136  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t epsilon() { return binary_min_subnorm; }
137  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t round_error() { return binary_min_subnorm; }
138  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t zero() { return binary_zero; }
139  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t denorm_min() { return binary_min_subnorm; }
140 
// E2M1 has no infinity or NaN encodings (has_inf() is false); the accessors below
// return max() as a placeholder value.
141  CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
142  // N/A
143  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t infinity() { return max(); }
144  // N/A
145  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t quiet_NaN() { return max(); }
146  // N/A
147  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t signaling_NaN() { return max(); }
148 };
149 
150 template <index_t I>
151 CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::_unpack(number<I>) const
152 {
153  static_assert(I < 2, "Index is out of range.");
154  if constexpr(I == 1)
155  return (data >> 4);
156  else
157  return data & 0b00001111;
158 }
160 // TODO: consider replacing this macro to improve performance
161 
162 #if CK_TILE_FP4_CVT_DEVICE
163 namespace impl {
164 
165 template <typename T>
166 CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
167 {
168  if constexpr(std::is_same_v<T, fp32_t>)
169  return fp32x2_t(__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0))[0];
170  else if constexpr(std::is_same_v<T, fp32x2_t>)
171  return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
172  else if constexpr(std::is_same_v<T, fp16_t>)
173  return fp16x2_t(__builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0))[0];
174  else if constexpr(std::is_same_v<T, fp16x2_t>)
175  return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
176  else if constexpr(std::is_same_v<T, bf16_t>)
177  return bf16x2_t(__builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0))[0];
178  else if constexpr(std::is_same_v<T, bf16x2_t>)
179  return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
180  else
181  static_assert(std::false_type::value, "Unsupported type.");
182  return T{};
183 }
184 template <typename T>
185 CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
186 {
187  union
188  {
189  uint32_t u32;
190  pk_fp4_raw_t pf4[4];
191  } cvt{0};
192  if constexpr(std::is_same_v<T, fp32_t>)
193  cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src, src, scale, 0);
194  else if constexpr(std::is_same_v<T, fp32x2_t>)
195  cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src[0], src[1], scale, 0);
196  else if constexpr(std::is_same_v<T, fp16_t>)
197  cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, fp16x2_t{src, src}, scale, 0);
198  else if constexpr(std::is_same_v<T, fp16x2_t>)
199  cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, src, scale, 0);
200  else if constexpr(std::is_same_v<T, bf16_t>)
201  cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, bf16x2_t{src, src}, scale, 0);
202  else if constexpr(std::is_same_v<T, bf16x2_t>)
203  cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, src, scale, 0);
204  else
205  static_assert(std::false_type::value, "Unsupported type.");
206  return cvt.pf4[0];
207 }
208 
209 } // namespace impl
210 #endif
211 
212 CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const
213 {
214 #if CK_TILE_FP4_CVT_DEVICE
215  return impl::_from_f4<bf16_t>(data, scale);
216 #else
217  return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale))};
218 #endif
219 }
220 
222 {
// Decodes both packed elements to a bf16 pair: lane 0 from the low nibble, lane 1 from the high.
223 #if CK_TILE_FP4_CVT_DEVICE
// gfx950: one hardware packed conversion.
224  return impl::_from_f4<bf16x2_t>(data, scale);
225 #else
// Software path: decode each nibble to float, then narrow each value to bf16.
226  return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale)),
227  type_convert<bf16_t>(convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale))};
228 #endif
229 }
230 
231 // TODO: make it generic so that we can convert from directrly.
232 CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_mxfp4(float x, float scale)
233 {
234 #if CK_TILE_FP4_CVT_DEVICE
235  return impl::_to_f4(x, scale);
236 #else
237  return convert_to_type<pk_fp4_t>(x, scale);
238 #endif
239 }
240 CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale)
241 {
242 #if CK_TILE_FP4_CVT_DEVICE
243  return impl::_to_f4(x, scale);
244 #else
245  auto res = convert_to_type<pk_fp4_t>(x, scale);
246  return pk_fp4_t::_pack(res, res);
247 #endif
248 }
249 CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale)
250 {
251 #if CK_TILE_FP4_CVT_DEVICE
252  return impl::_to_f4(x, scale);
253 #else
254  auto res = float_to_mxfp4(type_convert<float>(x), scale);
255  return pk_fp4_t::_pack(res, res);
256 #endif
257 }
258 CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale)
259 {
260 #if CK_TILE_FP4_CVT_DEVICE
261  return impl::_to_f4(x, scale);
262 #else
263  auto res = float_to_mxfp4(type_convert<float>(x), scale);
264  return pk_fp4_t::_pack(res, res);
265 #endif
266 }
267 CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale)
268 {
269 #if CK_TILE_FP4_CVT_DEVICE
270  return impl::_to_f4(x, scale);
271 #else
272  return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
273 #endif
274 }
275 CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
276 {
277 #if CK_TILE_FP4_CVT_DEVICE
278  return impl::_to_f4(x, scale);
279 #else
280  return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
281 #endif
282 }
283 CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
284 {
285 #if CK_TILE_FP4_CVT_DEVICE
286  return impl::_to_f4(x, scale);
287 #else
288  return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
289 #endif
290 }
291 
292 CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale)
293 {
294  return x.to_fp32x2(scale);
295 }
296 CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale)
297 {
298  return x.to_fp16x2(scale);
299 }
300 CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale)
301 {
302  return x.to_bf16x2(scale);
303 }
304 CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale)
305 {
306  return x.to_float(scale);
307 }
308 CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale)
309 {
310  return x.to_fp16(scale);
311 }
312 CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale)
313 {
314  return x.to_bf16(scale);
315 }
316 
317 #if TEST_convert_with_table == 0
318 CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
319 {
320 #if CK_TILE_FP4_CVT_DEVICE
321  return impl::_from_f4<fp32_t>(data, scale);
322 #else
323  return convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale);
324 #endif
325 }
327 {
// Decodes both packed elements to float: lane 0 from the low nibble, lane 1 from the high.
328 #if CK_TILE_FP4_CVT_DEVICE
// gfx950: one hardware packed conversion.
329  return impl::_from_f4<fp32x2_t>(data, scale);
330 #else
// Software path: decode each nibble separately.
331  return fp32x2_t{convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale),
332  convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale)};
333 #endif
334 }
335 
336 CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
337 {
338 #if CK_TILE_FP4_CVT_DEVICE
339  return impl::_from_f4<fp16_t>(data, scale);
340 #else
341  return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale))};
342 #endif
343 }
345 {
// Decodes both packed elements to an fp16 pair: lane 0 from the low nibble, lane 1 from the high.
346 #if CK_TILE_FP4_CVT_DEVICE
// gfx950: one hardware packed conversion.
347  return impl::_from_f4<fp16x2_t>(data, scale);
348 #else
// Software path: decode each nibble to float, then narrow each value to fp16.
349  return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale)),
350  type_convert<fp16_t>(convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale))};
351 #endif
352 }
353 #else
354 CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
355 {
356  return e2m1_to_fp32_table[_unpack(number<0>{})] * scale;
357 }
358 CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
359 {
360  return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, e2m1_to_fp32_table[_unpack(number<1>{}] * scale};
361 }
362 CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
363 {
364  return type_convert<float>(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale;
365 }
366 CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
367 {
368  return fp16x2_t{
369  type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale),
370  type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[_unpack(number<1>{})]) *
371  scale)};
372 }
373 #endif
374 
375 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
typename pk_fp4_t::type pk_fp4_raw_t
Definition: pk_fp4.hpp:107
ushort bfloat16_t
Definition: bfloat16.hpp:111
constexpr CK_TILE_HOST_DEVICE pk_fp4_t fp16_to_pk_fp4(const fp16_t &x, float scale)
Definition: pk_fp4.hpp:249
bfloat16_t __attribute__((ext_vector_type(2))) bf16x2_t
Definition: pk_fp4.hpp:24
_Float16 fp16_t
Definition: half.hpp:110
float __attribute__((ext_vector_type(2))) fp32x2_t
Definition: pk_fp4.hpp:22
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
constexpr CK_TILE_HOST_DEVICE pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t &x, float scale)
Definition: pk_fp4.hpp:283
pk_float4_e2m1_t pk_fp4_t
Definition: pk_fp4.hpp:106
float fp32_t
Definition: pk_fp4.hpp:21
_Float16 fp16x2_t
Definition: half.hpp:385
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t float_to_pk_fp4(const float &x, float scale=1.f)
Definition: pk_fp4.hpp:240
constexpr CK_TILE_HOST_DEVICE float pk_fp4_to_float(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:304
constexpr CK_TILE_HOST_DEVICE fp16_t pk_fp4_to_fp16(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:308
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:411
constexpr CK_TILE_HOST_DEVICE pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t &x, float scale)
Definition: pk_fp4.hpp:275
constexpr CK_TILE_HOST_DEVICE pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t &x, float scale)
Definition: pk_fp4.hpp:267
constexpr CK_TILE_HOST_DEVICE fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:292
constexpr CK_TILE_HOST_DEVICE pk_fp4_raw_t float_to_mxfp4(float x, float scale)
Definition: pk_fp4.hpp:232
constexpr CK_TILE_HOST_DEVICE pk_fp4_t bf16_to_pk_fp4(const bf16_t &x, float scale)
Definition: pk_fp4.hpp:258
constexpr CK_TILE_HOST_DEVICE fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:296
constexpr CK_TILE_HOST_DEVICE bf16_t pk_fp4_to_bf16(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:312
constexpr CK_TILE_HOST_DEVICE bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:300
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
unsigned short uint16_t
Definition: stdint.h:125
unsigned int uint32_t
Definition: stdint.h:126
unsigned char uint8_t
Definition: stdint.h:124
Definition: integral_constant.hpp:13
static constexpr CK_TILE_HOST_DEVICE bool has_inf()
Definition: pk_fp4.hpp:141
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t min()
Definition: pk_fp4.hpp:133
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t denorm_min()
Definition: pk_fp4.hpp:139
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t infinity()
Definition: pk_fp4.hpp:143
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t round_error()
Definition: pk_fp4.hpp:137
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t epsilon()
Definition: pk_fp4.hpp:136
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t zero()
Definition: pk_fp4.hpp:138
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t quiet_NaN()
Definition: pk_fp4.hpp:145
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t signaling_NaN()
Definition: pk_fp4.hpp:147
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t lowest()
Definition: pk_fp4.hpp:135
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t max()
Definition: pk_fp4.hpp:134
pk_fp4_raw_t bitwise_type
Definition: pk_fp4.hpp:112
Definition: numeric.hpp:81
static constexpr int PackedSize
Definition: numeric.hpp:82
Definition: numeric.hpp:18
static constexpr CK_TILE_HOST_DEVICE T max()
Definition: numeric.hpp:26
Definition: pk_fp4.hpp:31
constexpr CK_TILE_HOST_DEVICE bf16x2_t to_bf16x2(float scale=1.f) const
Definition: pk_fp4.hpp:221
constexpr CK_TILE_HOST_DEVICE fp16x2_t to_fp16x2(float scale=1.f) const
Definition: pk_fp4.hpp:344
constexpr CK_TILE_HOST_DEVICE fp16_t to_fp16(float scale=1.f) const
Definition: pk_fp4.hpp:336
constexpr CK_TILE_HOST_DEVICE float to_float(float scale=1.f) const
Definition: pk_fp4.hpp:318
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t()
Definition: pk_fp4.hpp:37
uint8_t raw_type
Definition: pk_fp4.hpp:33
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t(float init, float scale=1.f)
Definition: pk_fp4.hpp:42
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t unpack(number< I >) const
Definition: pk_fp4.hpp:65
constexpr CK_TILE_HOST_DEVICE type & get()
Definition: pk_fp4.hpp:47
constexpr CK_TILE_HOST_DEVICE type _unpack(number< I >) const
constexpr CK_TILE_HOST_DEVICE type get() const
Definition: pk_fp4.hpp:48
constexpr CK_TILE_HOST_DEVICE fp32x2_t to_fp32x2(float scale=1.f) const
Definition: pk_fp4.hpp:326
constexpr CK_TILE_HOST_DEVICE bf16_t to_bf16(float scale=1.f) const
Definition: pk_fp4.hpp:212
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t(T init)
Definition: pk_fp4.hpp:39
type data
Definition: pk_fp4.hpp:35
raw_type type
Definition: pk_fp4.hpp:34
constexpr static CK_TILE_HOST_DEVICE pk_float4_e2m1_t pack(const pk_float4_e2m1_t &x0, const pk_float4_e2m1_t &x1)
Definition: pk_fp4.hpp:69
constexpr static CK_TILE_HOST_DEVICE type _pack(const type x0, const type x1)
Definition: pk_fp4.hpp:77
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition: numeric.hpp:106