/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/sweep_tile.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/sweep_tile.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/sweep_tile.hpp Source File
sweep_tile.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
13 
14 namespace ck_tile {
15 
16 // sweep over a span of a distribted tile and apply lambda function F
17 template <typename TileDistributedSpan_, // tile_distributed_span<...>
18  typename F // signature: F(tile_distributed_index<...>)
19  >
20 CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
21 {
22  using DstrSpan = remove_cvref_t<TileDistributedSpan_>;
23 
24  static_ford<typename DstrSpan::Impl>{}([&](auto dstr_idx_impl) {
25  constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl);
26 
27  f(dstr_idx);
28  });
29 }
30 
31 // unpacked span, this version support span with unpack(multi-arg) functor
32 //
33 template <
34  typename TileDistributedSpan_, // tile_distributed_span<...>
35  typename F, // signature: F(tile_distributed_index<...>)
36  typename Unpacks = typename uniform_sequence_gen<TileDistributedSpan_::Impl::size(), 1>::type>
37 CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F& f, Unpacks = {})
38 {
39  using DstrSpan = remove_cvref_t<TileDistributedSpan_>;
40 
41  static_uford<typename DstrSpan::Impl, Unpacks>{}(
42  [&](auto... dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)...); });
43 }
44 
45 namespace impl {
46 
// Primary template: declared only. The third parameter is matched as a sequence<...> of
// X-dim ids by the partial specializations.
template <typename, typename, typename>
struct sweep_tile_impl;
49 
50 template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
51 struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
52 {
53  CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
54  {
55  constexpr auto spans = DistributedTensor::get_distributed_spans();
56  constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
57  constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
58  constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
59  return y_unpacks;
60  }
62  {
63  constexpr auto spans = DistributedTensor::get_distributed_spans();
64  constexpr auto u =
65  static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
66  return u.get_num_of_access() *
67  sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
68  .get_num_of_access();
69  }
70  template <typename F, typename SpanIdx>
71  CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
72  {
73  constexpr auto spans = DistributedTensor::get_distributed_spans();
74 
76  spans[number<I>{}],
77  [&](auto... i_idx) {
78  const auto next_span_idx = embed_tuples(
79  [&](auto si) { return make_tuple(concat_tuple(si, make_tuple(i_idx))...); },
80  span_idx);
81  sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
82  f, next_span_idx);
83  },
84  get_y_unpacks());
85  }
86  template <typename F, typename SpanIdx, index_t i_access>
87  CK_TILE_HOST_DEVICE constexpr void
88  operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
89  {
90  constexpr auto spans = DistributedTensor::get_distributed_spans();
91  constexpr auto u =
92  static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
93  constexpr auto access_stride =
94  sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
95  .get_num_of_access();
96  constexpr auto curr_i_access = number<i_access / access_stride>{};
97  constexpr auto next_i_access = number<i_access % access_stride>{};
98  u(
99  [&](auto... i_idx) {
100  const auto next_span_idx = embed_tuples(
101  [&](auto si) {
102  return make_tuple(concat_tuple(
104  },
105  span_idx);
106  sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
107  f, next_span_idx, next_i_access);
108  },
109  curr_i_access);
110  }
111 };
112 
113 template <typename DistributedTensor, typename UnpacksPerXDim>
114 struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<>>
115 {
116  CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const { return 1; }
117  template <typename F, typename SpanIdx>
118  CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
119  {
120  unpack(f, span_idx);
121  }
122  template <typename F, typename SpanIdx, index_t i_access>
123  CK_TILE_HOST_DEVICE constexpr void
124  operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
125  {
126  unpack(f, span_idx);
127  }
128 };
129 
// Primary template of the sweep entry point: declared only, defined through its
// sequence<I, Is...> partial specialization.
template <typename, typename, typename>
struct sweep_tile_impl_0;
132 
133 // TODO: support empty tuple to remove this "entry-point" like function
134 template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
135 struct sweep_tile_impl_0<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
136 {
137  CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
138  {
139  constexpr auto spans = DistributedTensor::get_distributed_spans();
140  constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
141  constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
142  constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
143  return y_unpacks;
144  }
146  {
147  constexpr auto spans = DistributedTensor::get_distributed_spans();
148  constexpr auto u =
149  static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
150  return u.get_num_of_access() *
151  sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
152  .get_num_of_access();
153  }
154  template <typename F>
155  CK_TILE_HOST_DEVICE constexpr void operator()(const F& f) const
156  {
157  constexpr auto spans = DistributedTensor::get_distributed_spans();
159  spans[number<I>{}],
160  [&](auto... i_idx) {
161  constexpr auto next_span_idx = make_tuple(make_tuple(i_idx)...);
162  sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
163  f, next_span_idx);
164  },
165  get_y_unpacks());
166  }
167  template <typename F, index_t i_access>
168  CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, number<i_access>) const
169  {
170  constexpr auto spans = DistributedTensor::get_distributed_spans();
171  constexpr auto u =
172  static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
173  constexpr auto access_stride =
174  sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
175  .get_num_of_access();
176  constexpr auto curr_i_access = number<i_access / access_stride>{};
177  constexpr auto next_i_access = number<i_access % access_stride>{};
178  u(
179  [&](auto... i_idx) {
180  constexpr auto next_span_idx =
182  sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
183  f, next_span_idx, next_i_access);
184  },
185  curr_i_access);
186  }
187 };
188 
189 } // namespace impl
190 
191 /*
192  * Enhanced sweep-tile utility, can control unpacks along each X-dim
193  * the lambda function argument is the distributed-idx, which can directly
194  * plugged into the distributed tensor as setter/getter
195  *
196  * e.g. below function, y with the type DistributedTensor, r is row scale
197  *
198  * // sweep tile 1 by 1
199  * sweep_tile<DistributedTensor>([&](auto idx) {
200  * constexpr auto row_id = make_tuple(idx[number<0>{}]);
201  * y(idx) = y(idx) * r(row_id);
202  * });
203  *
204  * // sweep tile with 2 pixel from last dim each function call
205  * sweep_tile<DistributedTensor>(
206  * [&](auto idx_0, auto idx_1) {
207  * constexpr auto row_id = make_tuple(idx_0[number<0>{}]);
208  * y(idx_0) = y(idx_0) * r(row_id);
209  * y(idx_1) = y(idx_1) * r(row_id);
210  * },
211  * sequence<1, 2>{});
212  *
213  * // sweep tile with 2x2 pixel each function call
214  * sweep_tile<DistributedTensor>(
215  * [&](auto idx_00, auto idx_01, auto idx_10, auto idx_11) {
216  * constexpr auto row_id0 = make_tuple(idx_00[number<0>{}]);
217  * constexpr auto row_id1 = make_tuple(idx_10[number<0>{}]);
218  * y(idx_00) = y(idx_00) * r(row_id0);
219  * y(idx_01) = y(idx_01) * r(row_id0);
220  * y(idx_10) = y(idx_10) * r(row_id1);
221  * y(idx_11) = y(idx_11) * r(row_id1);
222  * },
223  * sequence<2, 2>{});
224  *
225  * TODO: do we need constexpr? lambda function could be non-constexpr
226  */
227 template <typename DistributedTensor,
228  typename F,
229  typename UnpacksPerXDim =
230  typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
231 CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F& f, UnpacksPerXDim = {})
232 {
233  constexpr auto spans = DistributedTensor::get_distributed_spans();
234 
235  impl::sweep_tile_impl_0<DistributedTensor,
236  UnpacksPerXDim,
237  typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(f);
238 }
239 
240 template <typename DistributedTensor,
241  typename F,
242  typename UnpacksPerXDim =
243  typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
244 CK_TILE_HOST_DEVICE constexpr void
245 sweep_tile(const DistributedTensor&, const F& f, UnpacksPerXDim = {})
246 {
247  sweep_tile<DistributedTensor, F, UnpacksPerXDim>(f, UnpacksPerXDim{});
248 }
249 
250 /*
251  * construct a sweep tile instance, which support issue the lambda one by one
252  * Note that this struct will hold the lambda functor, but will not hold the distributed tensor
253  * the functionality is the same as sweep_tile()
254  */
255 template <typename DistributedTensor_,
256  typename F_,
257  typename UnpacksPerXDim_ =
258  typename uniform_sequence_gen<DistributedTensor_::get_num_of_dimension(), 1>::type>
260 {
264 
267  : f(f_)
268  {
269  }
271  {
272  constexpr auto spans = DistributedTensor::get_distributed_spans();
273  constexpr auto tmp =
276  typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{};
277  return tmp.get_num_of_access();
278  }
279 
281  {
282  sweep_tile<DistributedTensor>(f, UnpacksPerXDim{});
283  }
284 
285  template <index_t i_access>
287  {
288  constexpr auto spans = DistributedTensor::get_distributed_spans();
289 
292  typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(
293  f, number<i_access>{});
294  }
295  F f;
296 };
297 
298 // partial deduction is not allowed
299 // template <typename T, typename F, typename U>
300 // CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;
301 
302 // deduction guide
303 template <typename T,
304  typename F,
305  typename U = typename uniform_sequence_gen<T::get_num_of_dimension(), 1>::type>
306 CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {}) -> tile_sweeper<T, F, U>;
307 
308 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE_EXTERN
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
constexpr CK_TILE_HOST_DEVICE auto make_tile_distributed_index(sequence< Is... >)
Definition: tile_distribution.hpp:59
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F &f, Unpacks={})
Definition: sweep_tile.hpp:37
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto embed_tuples(F f, const X &x)
Definition: tuple.hpp:546
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T &, const F &, U={}) -> tile_sweeper< T, F, U >
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto get_y_unpacks_from_x_unpacks(YLengths, number< XUnpacks >)
Definition: static_distributed_tensor.hpp:197
constexpr CK_TILE_HOST_DEVICE auto unpack(F &&f, X &&x)
Definition: functional.hpp:200
constexpr CK_TILE_HOST_DEVICE auto concat_tuple(const tuple< X... > &tx, const tuple< Y... > &ty)
Definition: tuple.hpp:453
constexpr CK_TILE_HOST_DEVICE void sweep_tile(const F &f, UnpacksPerXDim={})
Definition: sweep_tile.hpp:231
Definition: sequence.hpp:284
Definition: integral_constant.hpp:13
constexpr CK_TILE_HOST_DEVICE index_t get_num_of_access() const
Definition: sweep_tile.hpp:116
constexpr CK_TILE_HOST_DEVICE void operator()(const F &f, const SpanIdx &span_idx) const
Definition: sweep_tile.hpp:118
constexpr CK_TILE_HOST_DEVICE void operator()(const F &f, const SpanIdx &span_idx, number< i_access >) const
Definition: sweep_tile.hpp:124
constexpr CK_TILE_HOST_DEVICE void operator()(const F &f, const SpanIdx &span_idx, number< i_access >) const
Definition: sweep_tile.hpp:88
constexpr CK_TILE_HOST_DEVICE index_t get_num_of_access() const
Definition: sweep_tile.hpp:61
constexpr CK_TILE_HOST_DEVICE auto get_y_unpacks() const
Definition: sweep_tile.hpp:53
constexpr CK_TILE_HOST_DEVICE void operator()(const F &f, const SpanIdx &span_idx) const
Definition: sweep_tile.hpp:71
constexpr CK_TILE_HOST_DEVICE auto get_y_unpacks() const
Definition: sweep_tile.hpp:137
constexpr CK_TILE_HOST_DEVICE index_t get_num_of_access() const
Definition: sweep_tile.hpp:145
constexpr CK_TILE_HOST_DEVICE void operator()(const F &f, number< i_access >) const
Definition: sweep_tile.hpp:168
constexpr CK_TILE_HOST_DEVICE void operator()(const F &f) const
Definition: sweep_tile.hpp:155
Definition: sweep_tile.hpp:131
Definition: sweep_tile.hpp:48
Definition: sequence.hpp:49
Definition: functional.hpp:141
Definition: functional_with_tuple.hpp:129
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_access()
Definition: functional_with_tuple.hpp:141
Definition: sweep_tile.hpp:260
remove_cvref_t< DistributedTensor_ > DistributedTensor
Definition: sweep_tile.hpp:261
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_access()
Definition: sweep_tile.hpp:270
remove_cvref_t< F_ > F
Definition: sweep_tile.hpp:262
remove_cvref_t< UnpacksPerXDim_ > UnpacksPerXDim
Definition: sweep_tile.hpp:263
CK_TILE_HOST_DEVICE void operator()(number< i_access >) const
Definition: sweep_tile.hpp:286
CK_TILE_HOST_DEVICE tile_sweeper(const DistributedTensor &, const F &f_, UnpacksPerXDim={})
Definition: sweep_tile.hpp:266
F f
Definition: sweep_tile.hpp:295
CK_TILE_HOST_DEVICE void operator()() const
Definition: sweep_tile.hpp:280
CK_TILE_HOST_DEVICE tile_sweeper(const F &f_, UnpacksPerXDim={})
Definition: sweep_tile.hpp:265
Definition: sequence.hpp:311