/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/container_helper.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/container_helper.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/container_helper.hpp Source File
container_helper.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #ifndef CK_CONTAINER_HELPER_HPP
5 #define CK_CONTAINER_HELPER_HPP
6 
7 #include "sequence.hpp"
8 #include "sequence_helper.hpp"
9 #include "array.hpp"
10 #include "tuple.hpp"
11 #include "tuple_helper.hpp"
14 
15 namespace ck {
16 
17 template <typename TData, index_t NSize>
18 __host__ __device__ constexpr auto container_push_back(const Array<TData, NSize>& a, const TData& x)
19 {
21 
22  static_for<0, NSize, 1>{}([&r, &a](auto i) constexpr { r(i) = a[i]; });
23 
24  r(Number<NSize>{}) = x;
25 
26  return r;
27 }
28 
29 template <typename... Ts, typename T>
30 __host__ __device__ constexpr auto container_push_front(const Tuple<Ts...>& a, const T& x)
31 {
32  return container_concat(make_tuple(x), a);
33 }
34 
35 template <typename... Ts, typename T>
36 __host__ __device__ constexpr auto container_push_back(const Tuple<Ts...>& a, const T& x)
37 {
38  return container_concat(a, make_tuple(x));
39 }
40 
41 template <typename TData, index_t NSize, index_t... IRs>
42 __host__ __device__ constexpr auto
44 {
45  static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
46 
47  static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
48 
49  return make_array(old_array[Number<IRs>{}]...);
50 }
51 
52 template <typename TData, index_t NSize, index_t... IRs>
53 __host__ __device__ constexpr auto
55 {
57  old_array, typename sequence_map_inverse<decltype(old2new)>::type{});
58 }
59 
60 template <typename... Ts, index_t... IRs>
61 __host__ __device__ constexpr auto container_reorder_given_new2old(const Tuple<Ts...>& old_tuple,
62  Sequence<IRs...> /*new2old*/)
63 {
64  static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
65 
66  static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
67 
68  return make_tuple(old_tuple[Number<IRs>{}]...);
69 }
70 
71 template <typename... Ts, index_t... IRs>
72 __host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple<Ts...>& old_tuple,
73  Sequence<IRs...> old2new)
74 {
76  old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
77 }
78 
79 template <index_t... Is, index_t... IRs>
80 __host__ __device__ constexpr auto container_reorder_given_new2old(Sequence<Is...> /* old_seq */,
81  Sequence<IRs...> /*new2old*/)
82 {
83  static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
84 
85  static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
86 
88 }
89 
90 template <index_t... Is, index_t... IRs>
91 __host__ __device__ constexpr auto container_reorder_given_old2new(Sequence<Is...> old_seq,
92  Sequence<IRs...> /* old2new */)
93 {
94  static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
95 
96  static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
97 
98  constexpr auto new2old = typename sequence_map_inverse<Sequence<IRs...>>::type{};
99 
100  return container_reorder_given_new2old(old_seq, new2old);
101 }
102 
103 #if !CK_WORKAROUND_SWDEV_275126
104 // rocm-4.1 compiler would crash for recursive lambda
105 template <typename Container,
106  typename Reduce,
107  typename Init,
108  index_t IBegin = 0,
109  index_t IEnd = Container::Size(),
110  index_t IStep = 1>
111 __host__ __device__ constexpr auto container_reduce(const Container& x,
112  Reduce reduce,
113  Init init,
115  Number<IEnd> = Number<Container::Size()>{},
116  Number<IStep> = Number<1>{})
117 {
118  static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
119 
120  // f is recursive function, fs is a dummy of f
121  // i is index, y_old is current scan, r_old is current reduction
122  auto f = [&](auto fs, auto i, auto r_old) {
123  auto r_new = reduce(x[i], r_old);
124 
125  if constexpr(i.value < IEnd - IStep)
126  {
127  // recursively call f/fs
128  return fs(fs, i + Number<IStep>{}, r_new);
129  }
130  else
131  {
132  return r_new;
133  }
134  };
135 
136  // start recursion
137  return f(f, Number<IBegin>{}, init);
138 }
139 #else
140 // i is index, y_old is current scan, r_old is current reduction
141 template <typename Container,
142  typename Reduce,
143  typename ROld,
144  index_t I,
145  index_t IEnd,
146  index_t IStep>
147 __host__ __device__ constexpr auto container_reduce_impl(
148  const Container& x, Reduce reduce, ROld r_old, Number<I> i, Number<IEnd>, Number<IStep>)
149 {
150  auto r_new = reduce(x[i], r_old);
151 
152  if constexpr(i.value < IEnd - IStep)
153  {
154  return container_reduce_impl(
155  x, reduce, r_new, i + Number<IStep>{}, Number<IEnd>{}, Number<IStep>{});
156  }
157  else
158  {
159  return r_new;
160  }
161 }
162 
163 // rocm-4.1 compiler would crash for recursive lambda
164 // container reduce with initial value
165 template <typename Container,
166  typename Reduce,
167  typename Init,
168  index_t IBegin = 0,
169  index_t IEnd = Container::Size(),
170  index_t IStep = 1>
171 __host__ __device__ constexpr auto container_reduce(const Container& x,
172  Reduce reduce,
173  Init init,
174  Number<IBegin> = Number<0>{},
175  Number<IEnd> = Number<Container::Size()>{},
176  Number<IStep> = Number<1>{})
177 {
178  static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
179 
180  if constexpr(IEnd > IBegin)
181  {
182  return container_reduce_impl(
183  x, reduce, init, Number<IBegin>{}, Number<IEnd>{}, Number<IStep>{});
184  }
185  else
186  {
187  return init;
188  }
189 }
190 #endif
191 
192 template <typename TData, index_t NSize, typename Reduce>
193 __host__ __device__ constexpr auto
195 {
197 
198  TData r = init;
199 
200  static_for<NSize - 1, 0, -1>{}([&](auto i) {
201  r = f(r, x[i]);
202  y(i) = r;
203  });
204 
205  r = f(r, x[Number<0>{}]);
206  y(Number<0>{}) = r;
207 
208  return y;
209 }
210 
211 template <typename TData, index_t NSize, typename Reduce>
212 __host__ __device__ constexpr auto
214 {
216 
217  TData r = init;
218 
219  static_for<NSize - 1, 0, -1>{}([&](auto i) {
220  y(i) = r;
221  r = f(r, x[i]);
222  });
223 
224  y(Number<0>{}) = r;
225 
226  return y;
227 }
228 
229 template <index_t... Is, typename Reduce, index_t Init>
230 __host__ __device__ constexpr auto
232 {
234 }
235 
236 #if !CK_WORKAROUND_SWDEV_275126
237 // rocm4.1 compiler would crash with recursive lambda
238 template <typename... Xs, typename Reduce, typename Init>
239 __host__ __device__ constexpr auto
240 container_reverse_exclusive_scan(const Tuple<Xs...>& x, Reduce reduce, Init init)
241 {
242  constexpr index_t NSize = sizeof...(Xs);
243 
244  // f is recursive function, fs is a dummy of f
245  // i is index, y_old is current scan, r_old is current reduction
246  auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
247  auto r_new = reduce(x[i], r_old);
248 
249  auto y_new = container_push_front(y_old, r_new);
250 
251  if constexpr(i.value > 1)
252  {
253  // recursively call f/fs
254  return fs(fs, i - Number<1>{}, y_new, r_new);
255  }
256  else
257  {
258  return y_new;
259  }
260  };
261 
262  // start recursion
263  return f(f, Number<NSize - 1>{}, make_tuple(init), init);
264 }
265 #else
266 // i is index, y_old is current scan, r_old is current reduction
267 template <typename... Xs, typename Reduce, index_t I, typename YOld, typename ROld>
268 __host__ __device__ constexpr auto container_reverse_exclusive_scan_impl(
269  const Tuple<Xs...>& x, Reduce reduce, Number<I> i, YOld y_old, ROld r_old)
270 {
271  auto r_new = reduce(x[i], r_old);
272 
273  auto y_new = container_push_front(y_old, r_new);
274 
275  if constexpr(i.value > 1)
276  {
277  // recursively call f/fs
278  return container_reverse_exclusive_scan_impl(x, reduce, i - Number<1>{}, y_new, r_new);
279  }
280  else
281  {
282  return y_new;
283  }
284 }
285 
286 template <typename... Xs, typename Reduce, typename Init>
287 __host__ __device__ constexpr auto
288 container_reverse_exclusive_scan(const Tuple<Xs...>& x, Reduce reduce, Init init)
289 {
290  constexpr index_t NSize = sizeof...(Xs);
291 
293  x, reduce, Number<NSize - 1>{}, make_tuple(init), init);
294 }
295 #endif
296 
297 // TODO: update this, like container_reverse_exclusive_scan, to deal with Tuple of Number<>
298 template <typename... Xs, typename Reduce, typename TData>
299 __host__ __device__ constexpr auto
300 container_reverse_inclusive_scan(const Tuple<Xs...>& x, Reduce f, TData init)
301 {
302  constexpr index_t NSize = sizeof...(Xs);
303 
304  Tuple<Xs...> y;
305 
306  TData r = init;
307 
308  static_for<NSize - 1, 0, -1>{}([&](auto i) {
309  r = f(r, x[i]);
310  y(i) = r;
311  });
312 
313  r = f(r, x[Number<0>{}]);
314  y(Number<0>{}) = r;
315 
316  return y;
317 }
318 
319 template <typename X, typename... Ys>
320 __host__ __device__ constexpr auto container_concat(const X& x, const Ys&... ys)
321 {
322  return container_concat(x, container_concat(ys...));
323 }
324 
325 template <typename T, index_t NX, index_t NY>
326 __host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
327 {
328  return unpack2(
329  [&](auto&&... zs) { return make_array(ck::forward<decltype(zs)>(zs)...); }, ax, ay);
330 }
331 
332 template <typename... X, typename... Y>
333 __host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
334 {
335  return unpack2(
336  [&](auto&&... zs) { return make_tuple(ck::forward<decltype(zs)>(zs)...); }, tx, ty);
337 }
338 
339 template <typename Container>
340 __host__ __device__ constexpr auto container_concat(const Container& x)
341 {
342  return x;
343 }
344 
345 template <typename T, index_t N, index_t... Is>
346 __host__ __device__ constexpr auto get_container_subset(const Array<T, N>& arr, Sequence<Is...>)
347 {
348  static_assert(N >= sizeof...(Is), "wrong! size");
349 
350  return make_array(arr[Number<Is>{}]...);
351 }
352 
353 template <typename... Ts, index_t... Is>
354 __host__ __device__ constexpr auto get_container_subset(const Tuple<Ts...>& tup, Sequence<Is...>)
355 {
356  static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");
357 
358  return make_tuple(tup[Number<Is>{}]...);
359 }
360 
361 template <typename T, index_t N, index_t... Is>
362 __host__ __device__ constexpr void
363 set_container_subset(Array<T, N>& y, Sequence<Is...> picks, const Array<T, sizeof...(Is)>& x)
364 {
365  static_assert(N >= sizeof...(Is), "wrong! size");
366 
367  static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
368 }
369 
370 template <typename... Ys, index_t... Is, typename... Xs>
371 __host__ __device__ constexpr void
373 {
374  static_assert(sizeof...(Ys) >= sizeof...(Is) && sizeof...(Is) == sizeof...(Xs), "wrong! size");
375 
376  static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
377 }
378 
379 template <index_t... Is>
380 __host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
381 {
382  using Seq = Sequence<Is...>;
383 
384  return generate_tuple(
385  [&](auto i) {
386  constexpr index_t tmp = Seq::At(i);
387  return Number<tmp>{};
388  },
389  Seq::Size());
390 }
391 
392 } // namespace ck
393 #endif
constexpr CK_TILE_HOST_DEVICE auto container_reverse_exclusive_scan_impl(const tuple< Xs... > &x, Reduce reduce, number< I > i, YOld y_old, ROld r_old)
Definition: container_helper.hpp:311
constexpr CK_TILE_HOST_DEVICE auto container_reduce_impl(const Container &x, Reduce reduce, ROld r_old, number< I > i, number< IEnd >, number< IStep >)
Definition: container_helper.hpp:174
Definition: ck.hpp:267
__host__ constexpr __device__ auto unpack2(F &&f, X &&x, Y &&y)
Definition: functional4.hpp:55
__host__ constexpr __device__ auto container_push_back(const Array< TData, NSize > &a, const TData &x)
Definition: container_helper.hpp:18
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
__host__ constexpr __device__ auto container_push_front(const Tuple< Ts... > &a, const T &x)
Definition: container_helper.hpp:30
__host__ constexpr __device__ auto reverse_exclusive_scan_sequence(Seq, Reduce, Number< Init >)
Definition: sequence.hpp:805
__host__ constexpr __device__ auto make_array(X &&x, Xs &&... xs)
Definition: array.hpp:56
__host__ constexpr __device__ auto container_reverse_inclusive_scan(const Array< TData, NSize > &x, Reduce f, TData init)
Definition: container_helper.hpp:194
__host__ constexpr __device__ auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition: container_helper.hpp:380
__host__ constexpr __device__ auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:320
__host__ constexpr __device__ auto container_reverse_exclusive_scan(const Array< TData, NSize > &x, Reduce f, TData init)
Definition: container_helper.hpp:213
__host__ constexpr __device__ auto container_reorder_given_new2old(const Array< TData, NSize > &old_array, Sequence< IRs... >)
Definition: container_helper.hpp:43
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ void set_container_subset(Array< T, N > &y, Sequence< Is... > picks, const Array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:363
int32_t index_t
Definition: ck.hpp:298
__host__ constexpr __device__ auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition: container_helper.hpp:111
integral_constant< index_t, N > Number
Definition: number.hpp:12
__host__ constexpr __device__ auto container_reorder_given_old2new(const Array< TData, NSize > &old_array, Sequence< IRs... > old2new)
Definition: container_helper.hpp:54
__host__ constexpr __device__ auto get_container_subset(const Array< T, N > &arr, Sequence< Is... >)
Definition: container_helper.hpp:346
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
Definition: array.hpp:14
Definition: sequence.hpp:43
__host__ static constexpr __device__ index_t At(index_t I)
Definition: sequence.hpp:53
Definition: tuple.hpp:117
Definition: integral_constant.hpp:20
Definition: sequence.hpp:618
Definition: sequence.hpp:623
Definition: functional2.hpp:33