/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/tuple_helper.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/tuple_helper.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/tuple_helper.hpp Source File
tuple_helper.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "functional4.hpp"
7 #include "tuple.hpp"
8 #ifndef CK_CODE_GEN_RTC
9 #include "is_detected.hpp"
10 #endif
11 
12 namespace ck {
13 
14 template <typename F, index_t... ids>
15 __host__ __device__ constexpr auto generate_tuple_for(F&& f, Sequence<ids...>)
16 {
17  return make_tuple(f(Number<ids>{})...);
18 }
19 
20 template <typename F, index_t N>
21 __host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
22 {
24 }
25 
26 template <typename F, index_t N>
27 __host__ __device__ constexpr auto generate_tuple(F&& f, LongNumber<N>)
28 {
29  return unpack([&f](auto&&... xs) { return make_tuple(f(xs)...); },
31 }
32 
33 template <typename F, index_t N>
34 __host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
35 {
36  return unpack([&f](auto&&... xs) { return tie(f(xs)...); },
38 }
39 
40 // tx and ty are tuple of references, return type of will tuple of referennce (not rvalue)
41 template <typename... X, typename... Y>
42 __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
43  const Tuple<Y&...>& ty)
44 {
45  return unpack2(
46  [&](auto&&... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
47  tx,
48  ty);
49 }
50 
51 template <typename... X, typename... Y>
52 __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
53 {
54  return unpack2(
55  [&](auto... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
56  tx,
57  ty);
58 }
59 
60 // Support any number of tuples to concat (also 1)
61 template <typename... X>
62 __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx)
63 {
64  return tx;
65 }
66 
67 template <typename... X, typename... Tuples>
68 __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuples&... tuples)
69 {
70  return concat_tuple(tx, concat_tuple(tuples...));
71 }
72 
73 namespace detail {
74 
75 template <typename F, typename X, index_t... Is>
76 __host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence<Is...>)
77 {
78  return make_tuple(f(x.At(Number<Is>{}))...);
79 }
80 
81 template <typename F, typename X, typename Y, index_t... Is>
82 __host__ __device__ constexpr auto
83 transform_tuples_impl(F f, const X& x, const Y& y, Sequence<Is...>)
84 {
85  return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}))...);
86 }
87 
88 template <typename F, typename X, typename Y, typename Z, index_t... Is>
89 __host__ __device__ constexpr auto
90 transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence<Is...>)
91 {
92  return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
93 }
94 
95 } // namespace detail
96 
97 template <typename F, typename X>
98 __host__ __device__ constexpr auto transform_tuples(F f, const X& x)
99 {
101  f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
102 }
103 
104 template <typename F, typename X, typename Y>
105 __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y)
106 {
108  f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
109 }
110 
111 template <typename F, typename X, typename Y, typename Z>
112 __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
113 {
115  f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
116 }
117 
118 // By default unroll to the flatten
119 template <index_t Depth = 0, index_t MaxDepth = -1>
120 __host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element)
121 {
122  return element;
123 }
124 
125 template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
126 __host__ __device__ constexpr auto UnrollNestedTuple(const T& element)
127 {
128  return make_tuple(element);
129 }
130 
131 template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
132 __host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
133 {
134  if constexpr(Depth == MaxDepth)
135  {
136  return tuple;
137  }
138  else
139  {
140  return unpack(
141  [&](auto&&... ts) {
142  return concat_tuple(UnrollNestedTuple<Depth + 1, MaxDepth>(ts)...);
143  },
144  tuple);
145  }
146 }
147 
148 template <typename... Ts>
149 __host__ __device__ constexpr auto TupleReverse(const Tuple<Ts...>& tuple)
150 {
151  return generate_tuple(
152  [&](auto i) {
153  using Idx = Number<Tuple<Ts...>::Size() - i - 1>;
154  return tuple.At(Idx{});
155  },
157 }
158 
159 // Reduce tuple values in specific range using Function
160 template <index_t Idx, index_t End, typename F, typename... Ts>
161 __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
162 {
163  static_assert(Idx < End, "Wrong parameters for TupleReduce");
164  if constexpr(Idx + 1 == End)
165  {
166  return tuple.At(Number<Idx>{});
167  }
168  else
169  {
170  return f(tuple.At(Number<Idx>{}), TupleReduce<Idx + 1, End>(f, tuple));
171  }
172 }
173 
174 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
175 template <typename T>
176 using is_tuple = decltype(ck::declval<T&>().IsTuple());
177 #endif
178 
179 template <typename... Ts>
180 __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
181 {
182 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
183  return (is_detected<is_tuple, Ts>::value || ...);
184 #endif
185 }
186 
187 template <index_t depth = 0, typename T>
188 __host__ __device__ constexpr auto TupleDepth(const T&)
189 {
190  return depth;
191 }
192 
193 template <index_t depth = 0, typename... Ts>
194 __host__ __device__ constexpr auto TupleDepth(const Tuple<Ts...>&)
195 {
196  return math::max(TupleDepth<depth + 1>(Ts{})...);
197 }
198 
199 template <index_t from, index_t to, typename... Ts>
200 __host__ __device__ constexpr auto TupleSlice(const Tuple<Ts...>& tuple)
201 {
202  return generate_tuple(
203  [&](auto i) {
204  using Idx = Number<from + i>;
205  return tuple.At(Idx{});
206  },
207  Number<to - from>{});
208 }
209 
210 } // namespace ck
__host__ constexpr __device__ auto depth(const Layout< Shape, UnrolledDescriptorType > &layout)
Get depth of the layout shape (return 0 if scalar).
Definition: layout_utils.hpp:371
__host__ constexpr __device__ auto transform_tuples_impl(F f, const X &x, Sequence< Is... >)
Definition: tuple_helper.hpp:76
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: ck.hpp:267
__host__ constexpr __device__ auto TupleReduce(F &&f, const Tuple< Ts... > &tuple)
Definition: tuple_helper.hpp:161
__host__ constexpr __device__ auto IsNestedTuple(const Tuple< Ts... > &)
Definition: tuple_helper.hpp:180
__host__ constexpr __device__ auto unpack2(F &&f, X &&x, Y &&y)
Definition: functional4.hpp:55
__host__ constexpr __device__ auto concat_tuple(const Tuple< X... > &tx, const Tuple< Y... > &ty)
Definition: tuple_helper.hpp:52
typename __make_integer_seq< impl::__integer_sequence, index_t, N >::seq_type make_index_sequence
Definition: sequence.hpp:200
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:34
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
__host__ constexpr __device__ auto TupleReverse(const Tuple< Ts... > &tuple)
Definition: tuple_helper.hpp:149
__host__ constexpr __device__ auto UnrollNestedTuple(const Tuple<> &element)
Definition: tuple_helper.hpp:120
__host__ constexpr __device__ auto transform_tuples(F f, const X &x)
Definition: tuple_helper.hpp:98
__host__ constexpr __device__ auto TupleDepth(const T &)
Definition: tuple_helper.hpp:188
__host__ constexpr __device__ auto generate_tuple_for(F &&f, Sequence< ids... >)
Definition: tuple_helper.hpp:15
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition: tuple_helper.hpp:42
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: is_detected.hpp:34
__host__ constexpr __device__ auto unpack(F &&f, X &&x)
Definition: functional4.hpp:46
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298
decltype(ck::declval< T & >().IsTuple()) is_tuple
Definition: tuple_helper.hpp:176
__host__ constexpr __device__ auto TupleSlice(const Tuple< Ts... > &tuple)
Definition: tuple_helper.hpp:200
Definition: sequence.hpp:43
Definition: tuple.hpp:186
Definition: tuple.hpp:117
Definition: sequence.hpp:256
typename conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:271
Definition: integral_constant.hpp:20