23 template <
typename OutTensor,
typename InTensor>
25 const InTensor& in_tensor)
29 static_assert(std::is_same_v<typename InTensor::DataType, typename OutTensor::DataType>,
30 "Data type for InTensor and OutTensor must be the same!");
32 using DataType =
typename InTensor::DataType;
34 constexpr
auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
35 constexpr
auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
39 constexpr
index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
41 constexpr
auto y_dim_out_to_in = [&] {
46 return y_dim_out_to_in_;
49 constexpr
auto y_lengths =
to_sequence(y_in_desc.get_lengths());
52 constexpr
index_t y_dim_vec_in = NDimY - 1;
53 constexpr
index_t y_dim_vec_out = 0;
56 constexpr
index_t vec_length_in = y_lengths[y_dim_vec_in];
57 constexpr
index_t vec_length_out = y_lengths[y_dim_vec_out];
60 constexpr
index_t num_vec_in = vec_length_out;
61 constexpr
index_t num_vec_out = vec_length_in;
66 if constexpr(vec_length_in == 1)
69 return (i == y_dim_vec_in || i == y_dim_vec_out) ? y_lengths[i] : 1;
73 constexpr
auto scalars_per_access =
TO_SEQUENCE(scalars_per_access_arr, NDimY);
77 decltype(scalars_per_access)>;
79 constexpr
index_t num_access = SFC_Y::get_num_of_access();
81 static_assert(num_access > 0,
"wrong! num_access should be larger than 0");
83 if constexpr(num_vec_in == 1 || num_vec_out == 1)
88 constexpr
auto idx_y_start = SFC_Y::get_index(iAccess);
89 constexpr
auto idx_y_in =
91 constexpr
index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
92 static_assert(in_offset % vec_length_in == 0);
93 constexpr
auto idx_y_out_tmp =
95 constexpr
auto idx_y_out =
97 constexpr
index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
98 if constexpr(vec_length_in == 1)
107 out_tensor.get_thread_buffer().template get_as<Vec>(
109 in_tensor.get_thread_buffer().template get_as<Vec>(
126 constexpr
auto idx_y_start = SFC_Y::get_index(iAccess);
132 return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
136 constexpr
index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
137 static_assert(in_offset % vec_length_in == 0);
139 in_vectors(i).template get_as<InVec>()(I0) =
140 in_tensor.get_thread_buffer()
141 .template get_as<InVec>()[
number<in_offset / vec_length_in>{}];
151 return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii];
155 constexpr
auto idx_y_out =
158 constexpr
index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
159 static_assert(out_offset % vec_length_out == 0);
161 out_tensor.get_thread_buffer().template set_as<OutVec>(
163 out_vectors[i].template get_as<OutVec>()[I0]);
171 template <
typename OutTensor,
typename InTensor>
174 using InDataType =
typename InTensor::DataType;
175 using OutDataType =
typename OutTensor::DataType;
177 using InTileDistr =
typename InTensor::StaticTileDistribution;
178 using OutTileDistr =
typename OutTensor::StaticTileDistribution;
180 using InDstrEncode =
typename InTileDistr::DstrEncode;
181 using OutDstrEncode =
typename OutTileDistr::DstrEncode;
183 using InThreadTensorDesc =
typename InTensor::ThreadTensorDesc;
184 using OutThreadTensorDesc =
typename OutTensor::ThreadTensorDesc;
187 constexpr
auto in_thread_desc_lengths = InThreadTensorDesc{}.get_lengths();
188 constexpr
auto out_thread_desc_lengths = OutThreadTensorDesc{}.get_lengths();
191 const auto in_tmp = [&]() {
192 if constexpr(std::is_same_v<OutDataType, InDataType>)
204 if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
205 InDstrEncode::hs_lengthss_ ==
tuple_reverse(OutDstrEncode::hs_lengthss_) &&
206 InDstrEncode::NDimY == OutDstrEncode::NDimY && InDstrEncode::NDimY == 2 &&
207 in_thread_desc_lengths ==
tuple_reverse(out_thread_desc_lengths))
216 static_assert(
false,
"Provided tensors could not be transposed!");
#define CK_TILE_DEVICE
Definition: config.hpp:45
CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor &out_tensor, const InTensor &in_tensor)
Definition: transpose_tile.hpp:24
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition: tile_elementwise.hpp:40
constexpr CK_TILE_HOST_DEVICE auto container_reorder_given_new2old(const array< TData, NSize > &old_array, sequence< IRs... >)
Definition: container_helper.hpp:39
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1126
int32_t index_t
Definition: integer.hpp:9
CK_TILE_DEVICE void transpose_tile2d(OutTensor &out, const InTensor &in)
Definition: transpose_tile.hpp:172
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1066
constexpr CK_TILE_HOST_DEVICE auto tuple_reverse(const tuple< Ts... > &t)
Definition: tuple.hpp:583
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:313
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: space_filling_curve.hpp:20
Definition: functional.hpp:43
Definition: transpose_vectors.hpp:20
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10