/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/transpose_tile.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/transpose_tile.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/transpose_tile.hpp Source File
transpose_tile.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
19 
20 namespace ck_tile {
21 namespace detail {
22 
// Helper in namespace detail that transposes the thread-local data of a distributed tile.
// It reads only in_tensor.get_thread_buffer() and writes only out_tensor.get_thread_buffer();
// no data exchange between threads appears in this function. Each element at input Y index
// idx is stored at the output Y index obtained by reordering idx through y_dim_out_to_in,
// a map built by matching the rh_minor indices of the two tile distribution encodings.
// Requires InTensor::DataType == OutTensor::DataType (enforced by the static_assert below).
// NOTE(review): this rendered listing omits source lines 24, 45, 99, 131, 133, 143-144 and
// 185 (see the gaps in the line numbering). Per the definition index, line 24 opens the
// signature `CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,`.
// Consult transpose_tile.hpp itself before editing this function.
23 template <typename OutTensor, typename InTensor>
25  const InTensor& in_tensor)
26 {
27  constexpr auto I0 = number<0>{};
28 
29  static_assert(std::is_same_v<typename InTensor::DataType, typename OutTensor::DataType>,
30  "Data type for InTensor and OutTensor must be the same!");
31 
32  using DataType = typename InTensor::DataType;
33 
// Descriptors mapping a Y index to an offset in each tensor's thread buffer.
34  constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
35  constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
36 
37  // y_dim_out_to_in
38  // For swapped Hs tile case I need only get_rh_minor_to_y
39  // since rh_major are already swapped due to swapped Hs.
// Builds, for one tensor, a map rh_minor index -> Y dimension index from the
// encoding's ys_to_rhs_minor_ table.
40  constexpr auto get_rh_minor_to_y = [](auto dstr_tensor) {
41  using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
42 
43  map<index_t, index_t> rh_minor_to_y_;
44 
// NOTE(review): the loop header over Y dims `i` (presumably a static_for) is on
// omitted line 45; its closing `});` is on line 49.
46  constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
47 
48  rh_minor_to_y_(rh_minor) = i;
49  });
50 
51  return rh_minor_to_y_;
52  };
53 
54  // In swapped Hs case <Y,X> -> <X,Y> tile
55  // we have same rh_major, but reversed rh_minor!
56  constexpr auto rh_minor_to_y_in = get_rh_minor_to_y(InTensor{});
57  constexpr auto rh_minor_to_y_out = get_rh_minor_to_y(OutTensor{});
58 
// For every output Y dim, the input Y dim that has the same rh_minor index.
59  // Is this really needed?? Should we have simple reverse here??
60  constexpr auto y_dim_out_to_in = [&] {
61  map<index_t, index_t> y_dim_out_to_in_;
62 
63  for(const auto& [rh_minor, y_out] : rh_minor_to_y_out)
64  {
65  y_dim_out_to_in_(y_out) = rh_minor_to_y_in[rh_minor];
66  }
67 
68  return y_dim_out_to_in_;
69  }();
70 
71  constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
72  constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
73 
74  // input and output vector dim in the order of input Y dims
// Input vectors run along the last input Y dim; output vectors run along the
// input Y dim that maps to the last output Y dim.
75  constexpr index_t y_dim_vec_in = NDimY - 1;
76  constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
77 
78  // vector lengths
79  constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];
80  constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];
81 
82  // # of vectors
// One access reads vec_length_out input vectors and writes vec_length_in output vectors.
83  constexpr index_t num_vec_in = vec_length_out;
84  constexpr index_t num_vec_out = vec_length_in;
85 
86  // SFC
// Step sizes of the space-filling curve: the full extent of both vector dims per
// access, or a single element in every dim when vec_length_in == 1.
87  constexpr auto scalars_per_access_arr = generate_array(
88  [&](auto i) {
89  if constexpr(vec_length_in == 1)
90  return 1;
91  else
92  return (i == y_dim_vec_in || i == y_dim_vec_out) ? y_lengths[i] : 1;
93  },
94  number<NDimY>{});
95 
96  constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
97 
// NOTE(review): the middle template argument of space_filling_curve is on omitted line 99.
98  using SFC_Y = space_filling_curve<decltype(y_lengths),
100  decltype(scalars_per_access)>;
101 
102  constexpr index_t num_access = SFC_Y::get_num_of_access();
103 
104  static_assert(num_access > 0, "wrong! num_access should be larger than 0");
105 
// Direct-copy path: when either side has a single vector per access, elements are
// copied (one scalar, or one vector of vec_length_in elements) without transpose_vectors.
106  if constexpr(num_vec_in == 1 || num_vec_out == 1)
107  {
108  // loop over SFC
109  static_for<0, num_access, 1>{}([&](auto iAccess) {
110  // data index [y0, y1, ...] in the order of input tensor
111  constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
112  constexpr auto idx_y_in =
113  generate_tuple([&](auto ii) { return idx_y_start[ii].value; }, number<NDimY>{});
114  constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
// Vector reads must start on a multiple of vec_length_in.
115  static_assert(in_offset % vec_length_in == 0);
116  constexpr auto idx_y_out_tmp =
117  generate_array([&](auto ii) { return idx_y_start[ii].value; }, number<NDimY>{});
// Reorder the input-ordered Y index into output Y order.
118  constexpr auto idx_y_out =
119  container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
120  constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
121  if constexpr(vec_length_in == 1)
122  {
123 
124  out_tensor.get_thread_buffer()[number<out_offset>{}] =
125  in_tensor.get_thread_buffer()[number<in_offset>{}];
126  }
127  else
128  {
129  using Vec = array<DataType, vec_length_in>;
// NOTE(review): the vector index arguments of both get_as calls are on omitted
// lines 131 and 133.
130  out_tensor.get_thread_buffer().template get_as<Vec>(
132  in_tensor.get_thread_buffer().template get_as<Vec>(
134  }
135  });
136  }
// General path: per access, gather num_vec_in input vectors, transpose them with
// transpose_vectors, then scatter num_vec_out output vectors.
137  else
138  {
139  using InVec = array<DataType, vec_length_in>;
140  using OutVec = array<DataType, vec_length_out>;
141 
142  // in/out vectors to be transposed
// NOTE(review): the declarations of in_vectors and out_vectors are on omitted lines 143-144.
145 
146  // loop over SFC and do transpose
147  static_for<0, num_access, 1>{}([&](auto iAccess) {
148  // data index [y0, y1, ...] in the order of input tensor
149  constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
150 
151  // get input vectors
// Input vector i is taken at offset i along the output vector dim.
152  static_for<0, num_vec_in, 1>{}([&](auto i) {
153  constexpr auto idx_y_in = generate_tuple(
154  [&](auto ii) {
155  return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
156  },
157  number<NDimY>{});
158 
159  constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
160  static_assert(in_offset % vec_length_in == 0);
161 
162  in_vectors(i).template get_as<InVec>()(I0) =
163  in_tensor.get_thread_buffer()
164  .template get_as<InVec>()[number<in_offset / vec_length_in>{}];
165  });
166 
167  // transpose
168  transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
169 
170  // set output vectors
// Output vector i is stored at offset i along the input vector dim, then the
// index is reordered into output Y order.
171  static_for<0, num_vec_out, 1>{}([&](auto i) {
172  constexpr auto idx_y_out_tmp = generate_array(
173  [&](auto ii) {
174  return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii];
175  },
176  number<NDimY>{});
177 
178  constexpr auto idx_y_out =
179  container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
180 
181  constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
182  static_assert(out_offset % vec_length_out == 0);
183 
// NOTE(review): the destination index argument of set_as is on omitted line 185.
184  out_tensor.get_thread_buffer().template set_as<OutVec>(
186  out_vectors[i].template get_as<OutVec>()[I0]);
187  });
188  });
189  }
190 }
191 
192 } // namespace detail
193 
194 template <typename OutTensor, typename InTensor>
195 CK_TILE_DEVICE void transpose_tile2d(OutTensor& out, const InTensor& in)
196 {
197  using InDataType = typename InTensor::DataType;
198  using OutDataType = typename OutTensor::DataType;
199 
200  using InTileDistr = typename InTensor::StaticTileDistribution;
201  using OutTileDistr = typename OutTensor::StaticTileDistribution;
202 
203  using InDstrEncode = typename InTileDistr::DstrEncode;
204  using OutDstrEncode = typename OutTileDistr::DstrEncode;
205 
206  using InThreadTensorDesc = typename InTensor::ThreadTensorDesc;
207  using OutThreadTensorDesc = typename OutTensor::ThreadTensorDesc;
208 
209  // Ys:
210  constexpr auto in_thread_desc_lengths = InThreadTensorDesc{}.get_lengths();
211  constexpr auto out_thread_desc_lengths = OutThreadTensorDesc{}.get_lengths();
212 
213  // type convert
214  const auto in_tmp = [&]() {
215  if constexpr(std::is_same_v<OutDataType, InDataType>)
216  {
217  return in;
218  }
219  else
220  {
221  return tile_elementwise_in(type_convert<OutDataType, InDataType>, in);
222  }
223  }();
224 
225  // Scenario where we switch from tile <Y, X> -> <X, Y> - only 2D tiles!
226  // we preserve Ps but swap Ys: <Y1, Y0> -> <Y0, Y1>
227  if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
228  InDstrEncode::hs_lengthss_ == tuple_reverse(OutDstrEncode::hs_lengthss_) &&
229  InDstrEncode::NDimY == OutDstrEncode::NDimY && InDstrEncode::NDimY == 2 &&
230  in_thread_desc_lengths == tuple_reverse(out_thread_desc_lengths))
231  // Any condition on Ps ??
232  // InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
233  // InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
234  {
236  }
237  else
238  {
239  static_assert(false, "Provided tensors could not be transposed!");
240  }
241 }
242 
243 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor &out_tensor, const InTensor &in_tensor)
Definition: transpose_tile.hpp:24
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition: tile_elementwise.hpp:40
constexpr CK_TILE_HOST_DEVICE auto container_reorder_given_new2old(const array< TData, NSize > &old_array, sequence< IRs... >)
Definition: container_helper.hpp:39
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1112
int32_t index_t
Definition: integer.hpp:9
CK_TILE_DEVICE void transpose_tile2d(OutTensor &out, const InTensor &in)
Definition: transpose_tile.hpp:195
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1052
constexpr CK_TILE_HOST_DEVICE auto tuple_reverse(const tuple< Ts... > &t)
Definition: tuple.hpp:583
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:299
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: map.hpp:16
Definition: space_filling_curve.hpp:20
Definition: functional.hpp:43
Definition: debug.hpp:67
Definition: transpose_vectors.hpp:20
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10