28 vy0.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I0];
29 vy0.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I0];
31 vy1.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I1];
32 vy1.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I1];
34 y0 = vy0.template AsType<half2_t>()[I0];
35 y1 = vy1.template AsType<half2_t>()[I0];
37 constexpr
int32_t m0 = 0x05040100;
38 constexpr
int32_t m1 = 0x07060302;
44 y0 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
45 y1 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
49 template <index_t NX, index_t NY>
66 static_assert((NX % 2 == 0 && NY % 2 == 0),
"wrong!");
72 const auto& x_s2_0 = vx_tuple[ix].template AsType<half2_t>()[iy / I2];
73 const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<half2_t>()[iy / I2];
76 auto& y_s2_0 = vy_tuple(iy).template AsType<half2_t>()(ix / I2);
77 auto& y_s2_1 = vy_tuple(iy + I1).template AsType<half2_t>()(ix / I2);
98 constexpr
int32_t m0 = 0x05010400;
99 constexpr
int32_t m1 = 0x05040100;
100 constexpr
int32_t m2 = 0x07060302;
101 constexpr
int32_t m3 = 0x07030602;
107 t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
108 t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
109 z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
110 z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
111 t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
112 t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
113 z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
114 z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
116 y0 = bit_cast<int8x4_t>(z0);
117 y1 = bit_cast<int8x4_t>(z1);
118 y2 = bit_cast<int8x4_t>(z2);
119 y3 = bit_cast<int8x4_t>(z3);
122 template <index_t NX, index_t NY>
141 static_assert((NX % 4 == 0 && NY % 4 == 0),
"wrong!");
147 const auto& x_s4_0 = vx_tuple[ix].template AsType<int8x4_t>()[iy / I4];
148 const auto& x_s4_1 = vx_tuple[ix + I1].template AsType<int8x4_t>()[iy / I4];
149 const auto& x_s4_2 = vx_tuple[ix + I2].template AsType<int8x4_t>()[iy / I4];
150 const auto& x_s4_3 = vx_tuple[ix + I3].template AsType<int8x4_t>()[iy / I4];
153 auto& y_s4_0 = vy_tuple(iy).template AsType<int8x4_t>()(ix / I4);
154 auto& y_s4_1 = vy_tuple(iy + I1).template AsType<int8x4_t>()(ix / I4);
155 auto& y_s4_2 = vy_tuple(iy + I2).template AsType<int8x4_t>()(ix / I4);
156 auto& y_s4_3 = vy_tuple(iy + I3).template AsType<int8x4_t>()(ix / I4);
177 constexpr
int32_t m0 = 0x05010400;
178 constexpr
int32_t m1 = 0x05040100;
179 constexpr
int32_t m2 = 0x07060302;
180 constexpr
int32_t m3 = 0x07030602;
186 t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
187 t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
188 z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
189 z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
190 t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
191 t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
192 z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
193 z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
195 y0 = bit_cast<f8x4_t>(z0);
196 y1 = bit_cast<f8x4_t>(z1);
197 y2 = bit_cast<f8x4_t>(z2);
198 y3 = bit_cast<f8x4_t>(z3);
201 template <index_t NX, index_t NY>
220 static_assert((NX % 4 == 0 && NY % 4 == 0),
"wrong!");
226 const auto& x_s4_0 = vx_tuple[ix].template AsType<f8x4_t>()[iy / I4];
227 const auto& x_s4_1 = vx_tuple[ix + I1].template AsType<f8x4_t>()[iy / I4];
228 const auto& x_s4_2 = vx_tuple[ix + I2].template AsType<f8x4_t>()[iy / I4];
229 const auto& x_s4_3 = vx_tuple[ix + I3].template AsType<f8x4_t>()[iy / I4];
232 auto& y_s4_0 = vy_tuple(iy).template AsType<f8x4_t>()(ix / I4);
233 auto& y_s4_1 = vy_tuple(iy + I1).template AsType<f8x4_t>()(ix / I4);
234 auto& y_s4_2 = vy_tuple(iy + I2).template AsType<f8x4_t>()(ix / I4);
235 auto& y_s4_3 = vy_tuple(iy + I3).template AsType<f8x4_t>()(ix / I4);
238 transpose_f8_4x4(x_s4_0, x_s4_1, x_s4_2, x_s4_3, y_s4_0, y_s4_1, y_s4_2, y_s4_3);
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1737
_Float16 half_t
Definition: data_type.hpp:30
__device__ void transpose_f8_4x4(const f8x4_t &x0, const f8x4_t &x1, const f8x4_t &x2, const f8x4_t &x3, f8x4_t &y0, f8x4_t &y1, f8x4_t &y2, f8x4_t &y3)
Definition: transpose_vectors.hpp:166
typename vector_type< half_t, 2 >::type half2_t
Definition: dtype_vector.hpp:2139
int32_t index_t
Definition: ck.hpp:298
typename vector_type< int8_t, 4 >::type int8x4_t
Definition: dtype_vector.hpp:2163
__device__ void transpose_int8_4x4(const int8x4_t &x0, const int8x4_t &x1, const int8x4_t &x2, const int8x4_t &x3, int8x4_t &y0, int8x4_t &y1, int8x4_t &y2, int8x4_t &y3)
Definition: transpose_vectors.hpp:87
__device__ void transpose_fp16_2x2(const half2_t &x0, const half2_t &x1, half2_t &y0, half2_t &y1)
Definition: transpose_vectors.hpp:19
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
f8_t S
Definition: transpose_vectors.hpp:208
__device__ void operator()(const StaticallyIndexedArray< const VX &, NX > &vx_tuple, StaticallyIndexedArray< VY &, NY > &vy_tuple)
Definition: transpose_vectors.hpp:212
half_t S
Definition: transpose_vectors.hpp:56
__device__ void operator()(const StaticallyIndexedArray< const VX &, NX > &vx_tuple, StaticallyIndexedArray< VY &, NY > &vy_tuple)
Definition: transpose_vectors.hpp:60
__device__ void operator()(const StaticallyIndexedArray< const VX &, NX > &vx_tuple, StaticallyIndexedArray< VY &, NY > &vy_tuple)
Definition: transpose_vectors.hpp:133
int8_t S
Definition: transpose_vectors.hpp:129
Definition: transpose_vectors.hpp:16
Definition: dtype_vector.hpp:10