include/ck/utility/tuple_helper.hpp Source File

include/ck/utility/tuple_helper.hpp Source File#

Composable Kernel: include/ck/utility/tuple_helper.hpp Source File
tuple_helper.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "functional4.hpp"
7 #include "tuple.hpp"
8 #ifndef CK_CODE_GEN_RTC
9 #include "is_detected.hpp"
10 #endif
11 
12 namespace ck {
13 
14 template <typename F, index_t N>
15 __host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
16 {
17  return unpack([&f](auto&&... xs) { return make_tuple(f(xs)...); },
19 }
20 
21 template <typename F, index_t N>
22 __host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
23 {
24  return unpack([&f](auto&&... xs) { return tie(f(xs)...); },
26 }
27 
28 // tx and ty are tuple of references, return type of will tuple of referennce (not rvalue)
29 template <typename... X, typename... Y>
30 __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
31  const Tuple<Y&...>& ty)
32 {
33  return unpack2(
34  [&](auto&&... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
35  tx,
36  ty);
37 }
38 
39 template <typename... X, typename... Y>
40 __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
41 {
42  return unpack2(
43  [&](auto... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
44  tx,
45  ty);
46 }
47 
48 // Support any number of tuples to concat (also 1)
49 template <typename... X>
50 __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx)
51 {
52  return tx;
53 }
54 
55 template <typename... X, typename... Tuples>
56 __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuples&... tuples)
57 {
58  return concat_tuple(tx, concat_tuple(tuples...));
59 }
60 
61 namespace detail {
62 
63 template <typename F, typename X, index_t... Is>
64 __host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence<Is...>)
65 {
66  return make_tuple(f(x.At(Number<Is>{}))...);
67 }
68 
69 template <typename F, typename X, typename Y, index_t... Is>
70 __host__ __device__ constexpr auto
71 transform_tuples_impl(F f, const X& x, const Y& y, Sequence<Is...>)
72 {
73  return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}))...);
74 }
75 
76 template <typename F, typename X, typename Y, typename Z, index_t... Is>
77 __host__ __device__ constexpr auto
78 transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence<Is...>)
79 {
80  return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
81 }
82 
83 } // namespace detail
84 
85 template <typename F, typename X>
86 __host__ __device__ constexpr auto transform_tuples(F f, const X& x)
87 {
89  f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
90 }
91 
92 template <typename F, typename X, typename Y>
93 __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y)
94 {
96  f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
97 }
98 
99 template <typename F, typename X, typename Y, typename Z>
100 __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
101 {
103  f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
104 }
105 
106 // By default unroll to the flatten
107 template <index_t Depth = 0, index_t MaxDepth = -1>
108 __host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element)
109 {
110  return element;
111 }
112 
113 template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
114 __host__ __device__ constexpr auto UnrollNestedTuple(const T& element)
115 {
116  return make_tuple(element);
117 }
118 
119 template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
120 __host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
121 {
122  if constexpr(Depth == MaxDepth)
123  {
124  return tuple;
125  }
126  else
127  {
128  return unpack(
129  [&](auto&&... ts) {
130  return concat_tuple(UnrollNestedTuple<Depth + 1, MaxDepth>(ts)...);
131  },
132  tuple);
133  }
134 }
135 
136 template <typename... Ts>
137 __host__ __device__ constexpr auto TupleReverse(const Tuple<Ts...>& tuple)
138 {
139  return generate_tuple(
140  [&](auto i) {
141  using Idx = Number<Tuple<Ts...>::Size() - i - 1>;
142  return tuple.At(Idx{});
143  },
145 }
146 
147 // Reduce tuple values in specific range using Function
148 template <index_t Idx, index_t End, typename F, typename... Ts>
149 __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
150 {
151  static_assert(Idx < End, "Wrong parameters for TupleReduce");
152  if constexpr(Idx + 1 == End)
153  {
154  return tuple.At(Number<Idx>{});
155  }
156  else
157  {
158  return f(tuple.At(Number<Idx>{}), TupleReduce<Idx + 1, End>(f, tuple));
159  }
160 }
161 
162 #ifndef CK_CODE_GEN_RTC
163 template <typename T>
164 using is_tuple = decltype(ck::declval<T&>().IsTuple());
165 #endif
166 
167 template <typename... Ts>
168 __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
169 {
170 #ifndef CK_CODE_GEN_RTC
171  return (is_detected<is_tuple, Ts>::value || ...);
172 #endif
173 }
174 
175 template <index_t depth = 0, typename T>
176 __host__ __device__ constexpr auto TupleDepth(const T&)
177 {
178  return depth;
179 }
180 
181 template <index_t depth = 0, typename... Ts>
182 __host__ __device__ constexpr auto TupleDepth(const Tuple<Ts...>&)
183 {
184  return math::max(TupleDepth<depth + 1>(Ts{})...);
185 }
186 
187 template <index_t from, index_t to, typename... Ts>
188 __host__ __device__ constexpr auto TupleSlice(const Tuple<Ts...>& tuple)
189 {
190  return generate_tuple(
191  [&](auto i) {
192  using Idx = Number<from + i>;
193  return tuple.At(Idx{});
194  },
195  Number<to - from>{});
196 }
197 
198 } // namespace ck
__host__ constexpr __device__ auto depth(const Layout< Shape, UnrolledDescriptorType > &layout)
Get depth of the layout shape (return 0 if scalar).
Definition: layout_utils.hpp:371
__host__ constexpr __device__ auto transform_tuples_impl(F f, const X &x, Sequence< Is... >)
Definition: tuple_helper.hpp:64
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: ck.hpp:264
__host__ constexpr __device__ auto TupleReduce(F &&f, const Tuple< Ts... > &tuple)
Definition: tuple_helper.hpp:149
__host__ constexpr __device__ auto IsNestedTuple(const Tuple< Ts... > &)
Definition: tuple_helper.hpp:168
__host__ constexpr __device__ auto unpack2(F &&f, X &&x, Y &&y)
Definition: functional4.hpp:55
__host__ constexpr __device__ auto concat_tuple(const Tuple< X... > &tx, const Tuple< Y... > &ty)
Definition: tuple_helper.hpp:40
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:22
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
__host__ constexpr __device__ auto TupleReverse(const Tuple< Ts... > &tuple)
Definition: tuple_helper.hpp:137
__host__ constexpr __device__ auto UnrollNestedTuple(const Tuple<> &element)
Definition: tuple_helper.hpp:108
__host__ constexpr __device__ auto transform_tuples(F f, const X &x)
Definition: tuple_helper.hpp:86
__host__ constexpr __device__ auto TupleDepth(const T &)
Definition: tuple_helper.hpp:176
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition: tuple_helper.hpp:30
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: is_detected.hpp:34
__host__ constexpr __device__ auto unpack(F &&f, X &&x)
Definition: functional4.hpp:46
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:289
decltype(ck::declval< T & >().IsTuple()) is_tuple
Definition: tuple_helper.hpp:164
__host__ constexpr __device__ auto TupleSlice(const Tuple< Ts... > &tuple)
Definition: tuple_helper.hpp:188
Definition: sequence.hpp:43
Definition: tuple.hpp:186
Definition: tuple.hpp:117
Definition: sequence.hpp:241
typename conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:256
Definition: integral_constant.hpp:10