/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/functional3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/functional3.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/functional3.hpp Source File
functional3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/ck.hpp"
11 
12 namespace ck {
13 
14 namespace detail {
15 
16 // RemainLengths: Sequence<...>
17 // Orders: Sequence<...>
18 template <class RemainLengths, class Orders>
20 {
21  __host__ __device__ constexpr static_ford_impl()
22  {
23  static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
24  }
25 
26  // F signature: F(Sequence<...>)
27  // CurrentOrderedId: Sequence<...>
28  template <class F, class CurrentOrderedId>
29  __host__ __device__ constexpr void operator()(F f, CurrentOrderedId) const
30  {
31  static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
32  static_ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
33  f, CurrentOrderedId::PushBack(I));
34  });
35  }
36 };
37 
38 template <class Orders>
39 struct static_ford_impl<Sequence<>, Orders>
40 {
41  // F signature: F(Sequence<...>)
42  // OrderedId: Sequence<...>
43  template <class F, class OrderedId>
44  __host__ __device__ constexpr void operator()(F f, OrderedId) const
45  {
46  // retrive unordered Id
47  f(OrderedId::ReorderGivenOld2New(Orders{}));
48  }
49 };
50 
51 // RemainLengths: Sequence<...>
52 // Orders: Sequence<...>
53 template <class RemainLengths, class Orders>
54 struct ford_impl
55 {
56  __host__ __device__ constexpr ford_impl()
57  {
58  static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
59  }
60 
61  // F signature: F(Array<...> multi_id)
62  // CurrentOrderdId: Array<...>
63  template <class F, class CurrentOrderedId>
64  __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
65  {
66  for(index_t i = 0; i < RemainLengths::Front(); ++i)
67  {
68  ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
69  f, container_push_back(current_ordered_id, i));
70  }
71  }
72 };
73 
74 template <class Orders>
75 struct ford_impl<Sequence<>, Orders>
76 {
77  // F signature: F(Array<...> multi_id)
78  // CurrentOrderdId: Array<...>
79  template <class F, class CurrentOrderedId>
80  __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
81  {
82  // retrive unordered Id
83  f(container_reorder_given_old2new(current_ordered_id, Orders{}));
84  }
85 };
86 
87 } // namespace detail
88 
89 // Lengths is Sequence<...>, it is the length of each dimension for
90 // N-dimensional loop
91 // Orders is Sequence<...>, it is the order of dimension in which static_ford
92 // will loop over each
93 // dimension
94 template <class Lengths,
95  class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
97 {
98  __host__ __device__ constexpr static_ford()
99  {
100  static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
101  static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
102  }
103 
104  // F signature: F(Sequence<...> multi_id)
105  // multi_id is the unordered multi-index
106  template <class F>
107  __host__ __device__ constexpr void operator()(F f) const
108  {
109  constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
110  detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
111  }
112 };
113 
114 // Lengths is Sequence<...>, it is the length of each dimension for
115 // N-dimensional loop
116 // Orders is Sequence<...>, it is the order of dimension in which ford will loop
117 // over each
118 // dimension
119 template <class Lengths,
120  class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
121 struct ford
122 {
123  __host__ __device__ constexpr ford()
124  {
125  static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
126  static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
127  }
128 
129  // F signature: F(Array<...> multi_id)
130  // multi_id is the unordered multi-index
131  template <class F>
132  __host__ __device__ constexpr void operator()(F f) const
133  {
134  constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
135 
136  for(index_t i = 0; i < ordered_lengths.Front(); ++i)
137  {
138  detail::ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f,
139  make_multi_index(i));
140  }
141  }
142 };
143 
144 } // namespace ck
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto container_push_back(const Array< TData, NSize > &a, const TData &x)
Definition: container_helper.hpp:18
int32_t index_t
Definition: ck.hpp:298
__host__ constexpr __device__ auto container_reorder_given_old2new(const Array< TData, NSize > &old_array, Sequence< IRs... > old2new)
Definition: container_helper.hpp:54
Definition: sequence.hpp:43
__host__ constexpr __device__ void operator()(F f, CurrentOrderedId current_ordered_id) const
Definition: functional3.hpp:80
Definition: functional3.hpp:55
__host__ constexpr __device__ void operator()(F f, CurrentOrderedId current_ordered_id) const
Definition: functional3.hpp:64
__host__ constexpr __device__ ford_impl()
Definition: functional3.hpp:56
__host__ constexpr __device__ void operator()(F f, OrderedId) const
Definition: functional3.hpp:44
Definition: functional3.hpp:20
__host__ constexpr __device__ static_ford_impl()
Definition: functional3.hpp:21
__host__ constexpr __device__ void operator()(F f, CurrentOrderedId) const
Definition: functional3.hpp:29
Definition: functional3.hpp:122
__host__ constexpr __device__ void operator()(F f) const
Definition: functional3.hpp:132
__host__ constexpr __device__ ford()
Definition: functional3.hpp:123
Definition: functional2.hpp:33
Definition: functional3.hpp:97
__host__ constexpr __device__ void operator()(F f) const
Definition: functional3.hpp:107
__host__ constexpr __device__ static_ford()
Definition: functional3.hpp:98