23 template <
typename OutTensor,
typename InTensor>
28 using DataType =
typename InTensor::DataType;
30 constexpr
auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
31 constexpr
auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
34 constexpr
auto get_rh_major_minor_to_y = [](
auto dstr_tensor) {
35 using DstrEncode =
typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
40 constexpr
index_t rh_major = DstrEncode::ys_to_rhs_major_[i];
41 constexpr
index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
43 rh_major_minor_to_y_({rh_major, rh_minor}) = i;
46 return rh_major_minor_to_y_;
49 constexpr
auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
50 constexpr
auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});
52 constexpr
auto y_dim_out_to_in = [&] {
55 for(
const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
57 y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
60 return y_dim_out_to_in_;
64 constexpr
index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
66 constexpr
auto y_lengths =
to_sequence(y_in_desc.get_lengths());
69 constexpr
index_t y_dim_vec_in = NDimY - 1;
70 constexpr
index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
73 constexpr
index_t vec_length_in = y_lengths[y_dim_vec_in];
74 constexpr
index_t vec_length_out = y_lengths[y_dim_vec_out];
77 constexpr
index_t num_vec_in = vec_length_out;
78 constexpr
index_t num_vec_out = vec_length_in;
88 [&](
auto i) {
return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
91 constexpr
auto scalars_per_access =
TO_SEQUENCE(scalars_per_access_arr, NDimY);
95 decltype(scalars_per_access)>;
97 constexpr
index_t num_access = SFC_Y::get_num_of_access();
99 static_assert(num_access > 0,
"wrong! num_access should be larger than 0");
108 constexpr
auto idx_y_start = SFC_Y::get_index(iAccess);
114 return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
118 constexpr
index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
119 static_assert(in_offset % vec_length_in == 0);
121 in_vectors(i).template get_as<InVec>()(I0) =
122 in_tensor.get_thread_buffer()
123 .template get_as<InVec>()[
number<in_offset / vec_length_in>{}];
133 return ii == y_dim_vec_in ?
static_cast<index_t>(idx_y_start[ii]) + i
134 :
static_cast<index_t>(idx_y_start[ii]);
138 constexpr
auto idx_y_out =
141 constexpr
index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
142 static_assert(out_offset % vec_length_out == 0);
144 out_tensor.get_thread_buffer().template set_as<OutVec>(
146 out_vectors[i].template get_as<OutVec>()[I0]);
153 template <
typename OutTensor,
typename InTensor>
156 using InDataType =
typename InTensor::DataType;
157 using OutDataType =
typename OutTensor::DataType;
159 using InDstrEncode =
typename InTensor::StaticTileDistribution::DstrEncode;
160 using OutDstrEncode =
typename OutTensor::StaticTileDistribution::DstrEncode;
166 if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
167 InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ &&
168 InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
169 InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
170 InDstrEncode::NDimY == OutDstrEncode::NDimY)
176 static_assert(
false,
"The shuffle should always happen!");
#define CK_TILE_DEVICE
Definition: config.hpp:41
CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor &out_tensor, const InTensor &in_tensor)
Definition: shuffle_tile.hpp:24
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition: tile_elementwise.hpp:40
constexpr CK_TILE_HOST_DEVICE auto container_reorder_given_new2old(const array< TData, NSize > &old_array, sequence< IRs... >)
Definition: container_helper.hpp:39
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1112
int32_t index_t
Definition: integer.hpp:9
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition: shuffle_tile.hpp:154
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1052
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:299
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: space_filling_curve.hpp:20
Definition: functional.hpp:43
Definition: transpose_vectors.hpp:20
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10