Composable Kernel: include/ck/utility/transpose_vectors.hpp Source File

transpose_vectors.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "data_type.hpp"

namespace ck {

// primary template; specialized below for half_t, int8_t and f8_t
template <typename S,
          index_t NX,
          index_t NY,
          typename enable_if<is_scalar_type<S>::value, bool>::type = false>
struct transpose_vectors;
// transpose fp16 2x2
__device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t& y0, half2_t& y1)
{
#if 0
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    const vector_type<half_t, 2> vx0{x0}, vx1{x1};
    vector_type<half_t, 2> vy0, vy1;

    vy0.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I0];
    vy0.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I0];

    vy1.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I1];
    vy1.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I1];

    y0 = vy0.template AsType<half2_t>()[I0];
    y1 = vy1.template AsType<half2_t>()[I0];
#else
    constexpr int32_t m0 = 0x05040100;
    constexpr int32_t m1 = 0x07060302;

    // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
    //     byte index:    7  6  5  4      3  2  1  0   selector 05 01 04 00 picks 33 77 44 88
    // index is reversed because of little endianness (least significant bits first)
    y0 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
    y1 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
#endif
}

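For reference, here is a minimal, hypothetical host-side sketch (not part of this header) that emulates the byte permute used above for selector byte values 0-7. It reproduces the 0x33774488 example from the comment and shows that the selectors 0x05040100 and 0x07060302 gather the two low and the two high fp16 halves, which is exactly the 2x2 transpose. The helper name perm_b32 and the sample values are illustrative only.

#include <cstdint>
#include <cstdio>

// Host-side stand-in for v_perm_b32 (__builtin_amdgcn_perm), valid for selector
// byte values 0-7 only: bytes 0-3 of the combined 8-byte value come from src1,
// bytes 4-7 from src0; each selector byte picks one combined byte.
static std::uint32_t perm_b32(std::uint32_t src0, std::uint32_t src1, std::uint32_t sel)
{
    std::uint8_t bytes[8];
    for(int i = 0; i < 4; ++i)
    {
        bytes[i]     = static_cast<std::uint8_t>(src1 >> (8 * i)); // combined indices 0-3
        bytes[i + 4] = static_cast<std::uint8_t>(src0 >> (8 * i)); // combined indices 4-7
    }
    std::uint32_t result = 0;
    for(int i = 0; i < 4; ++i)
        result |= static_cast<std::uint32_t>(bytes[(sel >> (8 * i)) & 0x7]) << (8 * i);
    return result;
}

int main()
{
    // Example from the comment above: selector 0x05010400 -> 0x33774488.
    std::printf("%08x\n", perm_b32(0x11223344u, 0x55667788u, 0x05010400u));

    // transpose_fp16_2x2: with x0 = {a0, a1} and x1 = {b0, b1} packed little-endian,
    // selector 0x05040100 gathers the low halves {a0, b0} and 0x07060302 the high
    // halves {a1, b1}. Using a0=0xA0A0, a1=0xA1A1, b0=0xB0B0, b1=0xB1B1:
    std::printf("%08x\n", perm_b32(0xB1B1B0B0u, 0xA1A1A0A0u, 0x05040100u)); // b0b0a0a0
    std::printf("%08x\n", perm_b32(0xB1B1B0B0u, 0xA1A1A0A0u, 0x07060302u)); // b1b1a1a1
    return 0;
}
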
template <index_t NX, index_t NY>
struct transpose_vectors<half_t, NX, NY>
{
    // there are [NY * NX] scalars of type S to be transposed
    static constexpr index_t s_per_x = NY;
    static constexpr index_t s_per_y = NX;

    using S  = half_t;
    using VX = vector_type<half_t, s_per_x>;
    using VY = vector_type<half_t, s_per_y>;

    __device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
                               StaticallyIndexedArray<VY&, NY>& vy_tuple)
    {
        static constexpr auto I1 = Number<1>{};
        static constexpr auto I2 = Number<2>{};

        static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");

        // loop over 2x2 tiles and transpose data from vx_tuple into vy_tuple
        static_for<0, NY, 2>{}([&](auto iy) {
            static_for<0, NX, 2>{}([&](auto ix) {
                // references to 2 half2_t values from vx_tuple
                const auto& x_s2_0 = vx_tuple[ix].template AsType<half2_t>()[iy / I2];
                const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<half2_t>()[iy / I2];

                // references to 2 half2_t values from vy_tuple
                auto& y_s2_0 = vy_tuple(iy).template AsType<half2_t>()(ix / I2);
                auto& y_s2_1 = vy_tuple(iy + I1).template AsType<half2_t>()(ix / I2);

                // transpose
                transpose_fp16_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1);
            });
        });
    }
};

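Stripped of the vector_type / StaticallyIndexedArray plumbing, the specialization above performs an ordinary NX x NY matrix transpose: input vector ix contributes its element iy to output vector iy at position ix. A minimal scalar sketch on plain arrays (hypothetical standalone code, not part of this header):

#include <cstddef>

// Reference semantics of transpose_vectors<S, NX, NY>: x holds NX rows of NY
// scalars, y receives NY rows of NX scalars, with y[iy][ix] = x[ix][iy].
template <typename T, std::size_t NX, std::size_t NY>
void transpose_reference(const T (&x)[NX][NY], T (&y)[NY][NX])
{
    for(std::size_t iy = 0; iy < NY; ++iy)
        for(std::size_t ix = 0; ix < NX; ++ix)
            y[iy][ix] = x[ix][iy];
}

The device code instead walks 2x2 tiles of half2_t, so each pair of byte permutes transposes four fp16 values while all operands stay in 32-bit registers.
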
// transpose int8 4x4
__device__ void transpose_int8_4x4(const int8x4_t& x0,
                                   const int8x4_t& x1,
                                   const int8x4_t& x2,
                                   const int8x4_t& x3,
                                   int8x4_t& y0,
                                   int8x4_t& y1,
                                   int8x4_t& y2,
                                   int8x4_t& y3)
{
    int32_t t0, t1;
    int32_t z0, z1, z2, z3;
    constexpr int32_t m0 = 0x05010400;
    constexpr int32_t m1 = 0x05040100;
    constexpr int32_t m2 = 0x07060302;
    constexpr int32_t m3 = 0x07030602;

    // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
    //     byte index:    7  6  5  4      3  2  1  0   selector 05 01 04 00 picks 33 77 44 88
    // index is reversed because of little endianness (least significant bits first)
    t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
    t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
    z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
    z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
    t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
    t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
    z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
    z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);

    y0 = bit_cast<int8x4_t>(z0);
    y1 = bit_cast<int8x4_t>(z1);
    y2 = bit_cast<int8x4_t>(z2);
    y3 = bit_cast<int8x4_t>(z3);
}

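The selector sequence above can be checked off-device. The following hypothetical host-side sketch (not part of this header, reusing the perm_b32 stand-in from the earlier sketch) packs a 4x4 byte matrix row by row into little-endian dwords, replays the m0..m3 sequence, and prints four dwords that hold the columns of the input, i.e. a full 4x4 byte transpose.

#include <cstdint>
#include <cstdio>

// Same stand-in for v_perm_b32 as in the earlier sketch (selector values 0-7 only).
static std::uint32_t perm_b32(std::uint32_t src0, std::uint32_t src1, std::uint32_t sel)
{
    std::uint8_t b[8];
    for(int i = 0; i < 4; ++i)
    {
        b[i]     = static_cast<std::uint8_t>(src1 >> (8 * i));
        b[i + 4] = static_cast<std::uint8_t>(src0 >> (8 * i));
    }
    std::uint32_t r = 0;
    for(int i = 0; i < 4; ++i)
        r |= static_cast<std::uint32_t>(b[(sel >> (8 * i)) & 0x7]) << (8 * i);
    return r;
}

int main()
{
    // Row i of the input matrix; byte j of row i is x[i][j].
    const std::uint8_t x[4][4] = {{0x00, 0x01, 0x02, 0x03},
                                  {0x10, 0x11, 0x12, 0x13},
                                  {0x20, 0x21, 0x22, 0x23},
                                  {0x30, 0x31, 0x32, 0x33}};

    std::uint32_t xr[4];
    for(int i = 0; i < 4; ++i)
        xr[i] = x[i][0] | (x[i][1] << 8) | (x[i][2] << 16)
              | (static_cast<std::uint32_t>(x[i][3]) << 24);

    const std::uint32_t m0 = 0x05010400, m1 = 0x05040100, m2 = 0x07060302, m3 = 0x07030602;

    std::uint32_t t0       = perm_b32(xr[1], xr[0], m0);
    std::uint32_t t1       = perm_b32(xr[3], xr[2], m0);
    const std::uint32_t y0 = perm_b32(t1, t0, m1); // column 0 -> 0x30201000
    const std::uint32_t y1 = perm_b32(t1, t0, m2); // column 1 -> 0x31211101
    t0                     = perm_b32(xr[1], xr[0], m3);
    t1                     = perm_b32(xr[3], xr[2], m3);
    const std::uint32_t y2 = perm_b32(t1, t0, m1); // column 2 -> 0x32221202
    const std::uint32_t y3 = perm_b32(t1, t0, m2); // column 3 -> 0x33231303

    std::printf("%08x %08x %08x %08x\n", y0, y1, y2, y3);
    return 0;
}
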
template <index_t NX, index_t NY>
struct transpose_vectors<int8_t, NX, NY>
{
    // there are [NY * NX] scalars of type S to be transposed
    static constexpr index_t s_per_x = NY;
    static constexpr index_t s_per_y = NX;

    using S  = int8_t;
    using VX = vector_type<int8_t, s_per_x>;
    using VY = vector_type<int8_t, s_per_y>;

    __device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
                               StaticallyIndexedArray<VY&, NY>& vy_tuple)
    {
        static constexpr auto I1 = Number<1>{};
        static constexpr auto I2 = Number<2>{};
        static constexpr auto I3 = Number<3>{};
        static constexpr auto I4 = Number<4>{};

        static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");

        // loop over 4x4 tiles and transpose data from vx_tuple into vy_tuple
        static_for<0, NY, 4>{}([&](auto iy) {
            static_for<0, NX, 4>{}([&](auto ix) {
                // references to 4 int8x4_t values from vx_tuple
                const auto& x_s4_0 = vx_tuple[ix].template AsType<int8x4_t>()[iy / I4];
                const auto& x_s4_1 = vx_tuple[ix + I1].template AsType<int8x4_t>()[iy / I4];
                const auto& x_s4_2 = vx_tuple[ix + I2].template AsType<int8x4_t>()[iy / I4];
                const auto& x_s4_3 = vx_tuple[ix + I3].template AsType<int8x4_t>()[iy / I4];

                // references to 4 int8x4_t values from vy_tuple
                auto& y_s4_0 = vy_tuple(iy).template AsType<int8x4_t>()(ix / I4);
                auto& y_s4_1 = vy_tuple(iy + I1).template AsType<int8x4_t>()(ix / I4);
                auto& y_s4_2 = vy_tuple(iy + I2).template AsType<int8x4_t>()(ix / I4);
                auto& y_s4_3 = vy_tuple(iy + I3).template AsType<int8x4_t>()(ix / I4);

                // transpose
                transpose_int8_4x4(x_s4_0, x_s4_1, x_s4_2, x_s4_3, y_s4_0, y_s4_1, y_s4_2, y_s4_3);
            });
        });
    }
};

// transpose f8 4x4
__device__ void transpose_f8_4x4(const f8x4_t& x0,
                                 const f8x4_t& x1,
                                 const f8x4_t& x2,
                                 const f8x4_t& x3,
                                 f8x4_t& y0,
                                 f8x4_t& y1,
                                 f8x4_t& y2,
                                 f8x4_t& y3)
{
    int32_t t0, t1;
    int32_t z0, z1, z2, z3;
    constexpr int32_t m0 = 0x05010400;
    constexpr int32_t m1 = 0x05040100;
    constexpr int32_t m2 = 0x07060302;
    constexpr int32_t m3 = 0x07030602;

    // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
    //     byte index:    7  6  5  4      3  2  1  0   selector 05 01 04 00 picks 33 77 44 88
    // index is reversed because of little endianness (least significant bits first)
    t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
    t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
    z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
    z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
    t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
    t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
    z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
    z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);

    y0 = bit_cast<f8x4_t>(z0);
    y1 = bit_cast<f8x4_t>(z1);
    y2 = bit_cast<f8x4_t>(z2);
    y3 = bit_cast<f8x4_t>(z3);
}

template <index_t NX, index_t NY>
struct transpose_vectors<f8_t, NX, NY>
{
    // there are [NY * NX] scalars of type S to be transposed
    static constexpr index_t s_per_x = NY;
    static constexpr index_t s_per_y = NX;

    using S  = f8_t;
    using VX = vector_type<f8_t, s_per_x>;
    using VY = vector_type<f8_t, s_per_y>;

    __device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
                               StaticallyIndexedArray<VY&, NY>& vy_tuple)
    {
        static constexpr auto I1 = Number<1>{};
        static constexpr auto I2 = Number<2>{};
        static constexpr auto I3 = Number<3>{};
        static constexpr auto I4 = Number<4>{};

        static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");

        // loop over 4x4 tiles and transpose data from vx_tuple into vy_tuple
        static_for<0, NY, 4>{}([&](auto iy) {
            static_for<0, NX, 4>{}([&](auto ix) {
                // references to 4 f8x4_t values from vx_tuple
                const auto& x_s4_0 = vx_tuple[ix].template AsType<f8x4_t>()[iy / I4];
                const auto& x_s4_1 = vx_tuple[ix + I1].template AsType<f8x4_t>()[iy / I4];
                const auto& x_s4_2 = vx_tuple[ix + I2].template AsType<f8x4_t>()[iy / I4];
                const auto& x_s4_3 = vx_tuple[ix + I3].template AsType<f8x4_t>()[iy / I4];

                // references to 4 f8x4_t values from vy_tuple
                auto& y_s4_0 = vy_tuple(iy).template AsType<f8x4_t>()(ix / I4);
                auto& y_s4_1 = vy_tuple(iy + I1).template AsType<f8x4_t>()(ix / I4);
                auto& y_s4_2 = vy_tuple(iy + I2).template AsType<f8x4_t>()(ix / I4);
                auto& y_s4_3 = vy_tuple(iy + I3).template AsType<f8x4_t>()(ix / I4);

                // transpose
                transpose_f8_4x4(x_s4_0, x_s4_1, x_s4_2, x_s4_3, y_s4_0, y_s4_1, y_s4_2, y_s4_3);
            });
        });
    }
};

} // namespace ck