/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/shuffle_tile.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/shuffle_tile.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/shuffle_tile.hpp Source File
shuffle_tile.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
19 
20 namespace ck_tile {
21 namespace detail {
22 
// shuffle_tile_impl_in_thread: copies the elements held in in_tensor's per-thread buffer into
// out_tensor's per-thread buffer, re-ordering them from in_tensor's Y-dimension order to
// out_tensor's Y-dimension order.
//
// How it works (all index math is compile-time):
//  - Each output Y dim is matched to the input Y dim that has the same
//    (ys_to_rhs_major_, ys_to_rhs_minor_) entry in its distribution encoding.
//  - The input vector dim is the last input Y dim. The output vector dim is the input Y dim
//    matched to the last output Y dim. Both are expressed in input Y order.
//  - For every step of a space-filling curve over the input Y lengths, it loads
//    num_vec_in (= vec_length_out) input vectors of vec_length_in elements. It then transposes
//    them with transpose_vectors and stores num_vec_out (= vec_length_in) output vectors of
//    vec_length_out elements.
//
// @param out_tensor destination distributed tensor; its thread buffer is written.
// @param in_tensor  source distributed tensor; its thread buffer is only read.
// Compile-time requirement: every input/output offset accessed must be a multiple of the
// corresponding vector length (see the static_asserts below).
// NOTE(review): DataType is taken from InTensor only and is used for both the InVec and OutVec
// views, so this assumes both tensors share the same element type. That is presumably why
// shuffle_tile type-converts its input first — confirm.
23 template <typename OutTensor, typename InTensor>
24 CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InTensor& in_tensor)
25 {
26  constexpr auto I0 = number<0>{};
27 
28  using DataType = typename InTensor::DataType;
29 
30  constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
31  constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
32 
33  // y_dim_out_to_in
// Builds a map (rh_major, rh_minor) -> Y index for a distributed tensor. Only the argument's
// type (its tile distribution encoding) is used, never its data.
34  constexpr auto get_rh_major_minor_to_y = [](auto dstr_tensor) {
35  using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
36 
37  map<array<index_t, 2>, index_t> rh_major_minor_to_y_;
38 
// NOTE(review): listing line 39 is missing from this rendering. It is presumably the
// static_for header over the encoding's Y dims that defines `i` and opens the lambda closed
// by "});" below.
40  constexpr index_t rh_major = DstrEncode::ys_to_rhs_major_[i];
41  constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
42 
43  rh_major_minor_to_y_({rh_major, rh_minor}) = i;
44  });
45 
46  return rh_major_minor_to_y_;
47  };
48 
// Default-constructed tensors are passed only so the lambda can deduce their distribution types.
49  constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
50  constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});
51 
// y_dim_out_to_in[y_out] = the input Y dim with the same (rh_major, rh_minor) as output Y dim
// y_out.
52  constexpr auto y_dim_out_to_in = [&] {
53  map<index_t, index_t> y_dim_out_to_in_;
54 
55  for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
56  {
57  y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
58  }
59 
60  return y_dim_out_to_in_;
61  }();
62 
63  //
64  constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
65 
// Y lengths in input Y order; all per-dimension lookups below use input Y indices.
66  constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
67 
68  // input and output vector dim in the order of input Y dims
69  constexpr index_t y_dim_vec_in = NDimY - 1;
70  constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
71 
72  // vector lengths
73  constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];
74  constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];
75 
76  // # of vectors
// One access covers a vec_length_out x vec_length_in block, so the vector counts are swapped.
77  constexpr index_t num_vec_in = vec_length_out;
78  constexpr index_t num_vec_out = vec_length_in;
79 
80  using InVec = array<DataType, vec_length_in>;
81  using OutVec = array<DataType, vec_length_out>;
82 
// NOTE(review): the two commented-out aliases below look like dead code; consider removing.
83  // using InVec = typename InVec::type;
84  // using OutVec = typename OutVec::type;
85 
86  // SFC
// Each access spans the full length of both vector dims and one element of every other Y dim.
87  constexpr auto scalars_per_access_arr = generate_array(
88  [&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
89  number<NDimY>{});
90 
91  constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
92 
// NOTE(review): listing line 94 (the middle space_filling_curve template argument, presumably
// the dimension access order) is missing from this rendering.
93  using SFC_Y = space_filling_curve<decltype(y_lengths),
95  decltype(scalars_per_access)>;
96 
97  constexpr index_t num_access = SFC_Y::get_num_of_access();
98 
99  static_assert(num_access > 0, "wrong! num_access should be larger than 0");
100 
101  // in/out vectors to be transposed
// NOTE(review): listing lines 102-103, which declare in_vectors and out_vectors (used below),
// are missing from this rendering.
104 
105  // loop over SFC and do transpose
106  static_for<0, num_access, 1>{}([&](auto iAccess) {
107  // data index [y0, y1, ...] in the order of input tensor
108  constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
109 
110  // get input vectors
// Input vector i starts i elements further along y_dim_vec_out and is read as a whole InVec.
111  static_for<0, num_vec_in, 1>{}([&](auto i) {
112  constexpr auto idx_y_in = generate_tuple(
113  [&](auto ii) {
114  return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
115  },
116  number<NDimY>{});
117 
118  constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
// The thread buffer is read in units of InVec, so the offset must be vector-aligned.
119  static_assert(in_offset % vec_length_in == 0);
120 
121  in_vectors(i).template get_as<InVec>()(I0) =
122  in_tensor.get_thread_buffer()
123  .template get_as<InVec>()[number<in_offset / vec_length_in>{}];
124  });
125 
126  // transpose
// transpose_vectors (defined elsewhere) turns the num_vec_in loaded vectors into num_vec_out
// output vectors.
127  transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
128 
129  // set output vectors
// Output vector i starts i elements further along y_dim_vec_in. The index is built in input
// Y order, then permuted to output Y order via y_dim_out_to_in before computing its offset.
130  static_for<0, num_vec_out, 1>{}([&](auto i) {
131  constexpr auto idx_y_out_tmp = generate_array(
132  [&](auto ii) {
133  return ii == y_dim_vec_in ? static_cast<index_t>(idx_y_start[ii]) + i
134  : static_cast<index_t>(idx_y_start[ii]);
135  },
136  number<NDimY>{});
137 
138  constexpr auto idx_y_out =
139  container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
140 
141  constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
// The thread buffer is written in units of OutVec, so the offset must be vector-aligned.
142  static_assert(out_offset % vec_length_out == 0);
143 
// NOTE(review): listing line 145 is missing from this rendering. It is presumably the
// destination vector index argument, e.g. number<out_offset / vec_length_out>{}.
144  out_tensor.get_thread_buffer().template set_as<OutVec>(
146  out_vectors[i].template get_as<OutVec>()[I0]);
147  });
148  });
149 }
150 
151 } // namespace detail
152 
// shuffle_tile: writes the contents of `in` into `out`, converting the element type and
// re-arranging the data from InTensor's tile distribution to OutTensor's.
//
// Steps visible in this listing:
//  1. Element-wise convert `in` from InTensor::DataType to OutTensor::DataType
//     (tile_elementwise_in + type_convert).
//  2. If both distribution encodings agree on rs_lengths_, hs_lengthss_, ps_to_rhss_major_,
//     ps_to_rhss_minor_ and NDimY, perform the shuffle.
//  3. Otherwise, fail compilation via static_assert.
//
// @param out destination distributed tensor.
// @param in  source distributed tensor; taken by const reference and not modified here.
153 template <typename OutTensor, typename InTensor>
154 CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
155 {
156  using InDataType = typename InTensor::DataType;
157  using OutDataType = typename OutTensor::DataType;
158 
159  using InDstrEncode = typename InTensor::StaticTileDistribution::DstrEncode;
160  using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;
161 
162  // type convert
// Conversion happens before the shuffle, so the shuffle itself only sees OutDataType.
163  const auto in_tmp = tile_elementwise_in(type_convert<OutDataType, InDataType>, in);
164 
165  // shuffle
// The encodings may differ only in how their Y dims map to R/H. Presumably this means each
// thread already holds the same elements and only their order changes — confirm.
166  if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
167  InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ &&
168  InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
169  InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
170  InDstrEncode::NDimY == OutDstrEncode::NDimY)
171  {
// NOTE(review): listing line 172 is missing from this rendering. It presumably calls
// detail::shuffle_tile_impl_in_thread(out, in_tmp) — confirm against the header.
173  }
174  else
175  {
// NOTE(review): static_assert(false, ...) in a discarded `if constexpr` branch is only
// guaranteed valid since P2593R1 (C++23; applied retroactively by GCC 13 / Clang 17). Older
// compilers may reject this template even when the branch is never taken. A type-dependent
// condition such as sizeof(OutTensor) == 0 is the portable form. Also note that the assert
// fires when the encodings are NOT compatible with the in-thread shuffle; the message wording
// does not say that.
176  static_assert(false, "The shuffle should always happen!");
177  }
178 }
179 
180 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor &out_tensor, const InTensor &in_tensor)
Definition: shuffle_tile.hpp:24
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition: tile_elementwise.hpp:40
constexpr CK_TILE_HOST_DEVICE auto container_reorder_given_new2old(const array< TData, NSize > &old_array, sequence< IRs... >)
Definition: container_helper.hpp:39
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1112
int32_t index_t
Definition: integer.hpp:9
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition: shuffle_tile.hpp:154
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1052
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:299
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: map.hpp:16
Definition: space_filling_curve.hpp:20
Definition: functional.hpp:43
Definition: debug.hpp:67
Definition: transpose_vectors.hpp:20
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10