/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/utility/functional_with_tuple.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/utility/functional_with_tuple.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/utility/functional_with_tuple.hpp Source File
functional_with_tuple.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 // This file should not be included inside tuple.hpp!
7 
15 #include <stdint.h>
16 #include <utility>
17 
18 namespace ck_tile {
19 
20 namespace detail {
21 
22 // RemainLengths: sequence<...>
23 // Orders: sequence<...>
// Recursive step of the packed N-d static loop (general case: at least one
// dimension is still left to iterate).
// NOTE(review): in this rendered listing the lines carrying the struct name and
// the constructor signature are elided; the name static_uford_impl is taken from
// the recursive use further down and from the specialization that follows.
//
// RemainLengths: sequence of lengths of the dimensions not yet iterated.
// RamainUnpacks: sequence of pack lengths for those same dimensions.
// Orders:        not used here; forwarded unchanged to the terminal specialization.
24 template <class RemainLengths, class RamainUnpacks, class Orders>
26 {
   // Constructor body: the general case must never be instantiated with an
   // empty sequence.
28  {
29  static_assert(RemainLengths::size() > 0, "wrong! should not get here");
30  static_assert(RamainUnpacks::size() > 0, "wrong! should not get here");
31  }
32 
   // f:                functor that is eventually invoked by the terminal case.
   // CurrentUnpackIds: tuple of partial multi-indices (sequences) built so far;
   //                   only its type is used, the argument value is ignored.
33  template <class F, class CurrentUnpackIds>
34  CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds) const
35  {
   // Number of consecutive indices of the front dimension handled per step.
36  constexpr index_t pack_len = RamainUnpacks::front();
   // Step through the front dimension in strides of pack_len; I is the first
   // index of the current pack.
37  static_for<0, RemainLengths::front(), pack_len>{}([=](auto I) {
   // Every existing partial index is expanded into pack_len new ones:
   // new element idx_ is existing element (idx_ / pack_len) with the
   // index I + (idx_ % pack_len) appended.
38  constexpr auto new_pack = generate_tuple(
39  [&](auto idx_) {
40  constexpr auto i_new_pack = number<I + idx_ % pack_len>{};
41  constexpr auto i_pre_pack = number<idx_ / pack_len>{};
42  return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
43  },
44  number<CurrentUnpackIds::size() * pack_len>{});
45 
   // Recurse on the remaining dimensions with the front entries removed.
46  static_uford_impl<decltype(RemainLengths::pop_front()),
47  decltype(RamainUnpacks::pop_front()),
48  Orders>{}(f, new_pack);
49  });
50  }
51 };
52 
// Terminal case of the recursion: no dimensions are left, so every sequence in
// PackedId is a complete multi-index (still in loop order).
53 template <class Orders>
54 struct static_uford_impl<sequence<>, sequence<>, Orders>
55 {
   // f:        functor to invoke.
   // PackedId: tuple of complete multi-indices; only its type is used.
56  template <class F, class PackedId>
57  CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId) const
58  {
   // Reorder each multi-index with Orders via reorder_old_to_new.
59  constexpr auto origin_packs = transform_tuples(
60  [](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
   // Hand the reordered tuple to f through unpack (presumably one argument
   // per tuple element - confirm against functional.hpp).
61  unpack(f, origin_packs);
62  }
63 };
64 
// Recursive step of the "one shot" variant: instead of looping over a dimension
// it selects a single pack position in it, derived from the compile-time value
// current_acc.
// NOTE(review): in this rendered listing the line carrying the struct name is
// elided (the name static_uford_one_shot_impl is taken from the recursive use
// below), and so are the initializers of r_lens_stride / r_upks_stride.
65 template <class RemainLengths, class RamainUnpacks, class Orders>
67 {
   // f:                functor that is eventually invoked by the terminal case.
   // CurrentUnpackIds: tuple of partial multi-indices built so far (type only).
   // current_acc:      index still to be decoded for this and the following
   //                   dimensions (presumably the flat access index at the top
   //                   level - TODO confirm against the caller).
68  template <class F, class CurrentUnpackIds, index_t current_acc>
69  CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number<current_acc>) const
70  {
   // NOTE(review): the right-hand sides of the next two declarations are not
   // visible here; presumably they are per-dimension strides computed from
   // RemainLengths and RamainUnpacks - TODO confirm in the real header.
71  constexpr auto r_lens_stride =
73  constexpr auto r_upks_stride =
75 
   // Divisor used to decode current_acc for the front dimension.
76  constexpr index_t current_stride = r_lens_stride.front() / r_upks_stride.front();
77  constexpr index_t pack_len = RamainUnpacks::front();
   // First index of the selected pack in the front dimension.
78  constexpr index_t current_idx = (current_acc / current_stride) * pack_len;
79 
   // Same expansion as in static_uford_impl: new element idx_ is existing
   // element (idx_ / pack_len) with current_idx + (idx_ % pack_len) appended.
80  constexpr auto new_pack = generate_tuple(
81  [&](auto idx_) {
82  constexpr auto i_new_pack = number<current_idx + idx_ % pack_len>{};
83  constexpr auto i_pre_pack = number<idx_ / pack_len>{};
84  return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
85  },
86  number<CurrentUnpackIds::size() * pack_len>{});
87 
   // Recurse on the remaining dimensions, passing on the remainder of
   // current_acc.
88  static_uford_one_shot_impl<decltype(RemainLengths::pop_front()),
89  decltype(RamainUnpacks::pop_front()),
90  Orders>{}(f, new_pack, number<current_acc % current_stride>{});
91  }
92 };
93 
// Terminal case of the "one shot" recursion.
// NOTE(review): the line carrying the struct name / specialization arguments is
// elided in this rendered listing; presumably it specializes
// static_uford_one_shot_impl for empty sequences - TODO confirm.
94 template <class Orders>
96 {
   // f:           functor to invoke.
   // PackedId:    tuple of complete multi-indices; only its type is used.
   // current_acc: unused in this body.
97  template <class F, class PackedId, index_t current_acc>
98  CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId, number<current_acc>) const
99  {
   // Reorder each multi-index with Orders via reorder_old_to_new, then hand
   // the resulting tuple to f through unpack.
100  constexpr auto origin_packs = transform_tuples(
101  [](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
102  unpack(f, origin_packs);
103  }
104 };
105 
106 } // namespace detail
107 
108 // TODO: we may unify static_ford/static_uford in the future
109 //
110 // Loop over an N-d index space (Lengths, a sequence), handling a pack of indices per step.
111 // The function passed in must take as many arguments as there are indices in a pack (num_packs).
112 //
113 // e.g.
114 // Lengths=seq<2, 3, 4>, Unpacks=seq<1, 1, 2>
115 // static_uford<Lengths, Unpacks>{}([&](auto i_0, auto i_1){}); // require 2 args(packs)
116 //
117 // loop #0, i_0=seq<0, 0, 0>, i_1=seq<0, 0, 1>
118 // loop #1, i_0=seq<0, 0, 2>, i_1=seq<0, 0, 3>
119 // loop #2, i_0=seq<0, 1, 0>, i_1=seq<0, 1, 1>
120 // loop #3, i_0=seq<0, 1, 2>, i_1=seq<0, 1, 3>
121 // loop #4, i_0=seq<0, 2, 0>, i_1=seq<0, 2, 1>
122 // loop #5, i_0=seq<0, 2, 2>, i_1=seq<0, 2, 3>
123 // loop #6, i_0=seq<1, 0, 0>, i_1=seq<1, 0, 1>
124 // ...
// Public entry point of the packed N-d static loop.
// NOTE(review): in this rendered listing the line carrying the struct name and
// the signatures of the constructor, of get_num_of_access() and of the one-shot
// operator() are elided; the names used in the comments below are taken from
// the call on gutter line 163.
//
// Lengths: sequence of dimension lengths.
// Unpacks: sequence of pack lengths per dimension (defaults to all 1).
// Orders:  sequence used to reorder Lengths/Unpacks before looping
//          (defaults to 0, 1, ..., Lengths::size() - 1).
125 template <class Lengths,
126  class Unpacks = typename uniform_sequence_gen<Lengths::size(), 1>::type,
127  class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
129 {
   // Product of all entries of Unpacks.
130  static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies{}, number<1>{});
131 
   // Constructor body: compile-time validation of the template arguments.
133  {
134  static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
135  static_assert(Lengths::size() == Unpacks::size(), "wrong! inconsistent size");
136  static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
   // Every length must be a multiple of its pack length.
137  static_for<0, Lengths::size(), 1>{}(
138  [&](auto i) { static_assert(Lengths{}.at(i) % Unpacks{}.at(i) == 0); });
139  }
140 
   // Body of get_num_of_access(): product over all dimensions of
   // Lengths / Unpacks.
142  {
143  using L_ = decltype(Lengths{} / Unpacks{});
144 
145  return reduce_on_sequence(L_{}, multiplies{}, number<1>{});
146  }
147 
148  // F signature: F(sequence<...> multi_id...)
149  // multi_id is the unordered multi-index
150  template <class F>
151  CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
152  {
   // Reorder lengths and pack lengths with Orders, then start the recursion
   // from a tuple holding a single empty multi-index.
153  constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
154  constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
155  detail::static_uford_impl<decltype(ordered_lengths), decltype(ordered_unpacks), Orders>{}(
156  f, make_tuple(sequence<>{}));
157  }
158 
159  // One-shot version: handles only the access selected by i_access, so accesses can be issued one at a time
160  template <class F, index_t i_access>
162  {
   // i_access must be smaller than the value returned by get_num_of_access().
163  static_assert(i_access < get_num_of_access());
164  constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
165  constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
   // NOTE(review): the argument list of this call is elided in the listing.
166  detail::static_uford_one_shot_impl<decltype(ordered_lengths),
167  decltype(ordered_unpacks),
168  Orders>{}(
170  }
171 };
172 
173 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto reverse_exclusive_scan_sequence(Seq, Reduce, number< Init >)
Definition: sequence.hpp:860
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition: sequence.hpp:979
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto unpack(F &&f, X &&x)
Definition: functional.hpp:200
constexpr CK_TILE_HOST_DEVICE auto transform_tuples(F f, const X &x)
Definition: tuple.hpp:505
Definition: integral_constant.hpp:13
constexpr CK_TILE_HOST_DEVICE void operator()(F f, PackedId) const
Definition: functional_with_tuple.hpp:57
Definition: functional_with_tuple.hpp:26
constexpr CK_TILE_HOST_DEVICE void operator()(F f, CurrentUnpackIds) const
Definition: functional_with_tuple.hpp:34
constexpr CK_TILE_HOST_DEVICE static_uford_impl()
Definition: functional_with_tuple.hpp:27
constexpr CK_TILE_HOST_DEVICE void operator()(F f, PackedId, number< current_acc >) const
Definition: functional_with_tuple.hpp:98
Definition: functional_with_tuple.hpp:67
constexpr CK_TILE_HOST_DEVICE void operator()(F f, CurrentUnpackIds, number< current_acc >) const
Definition: functional_with_tuple.hpp:69
Definition: math.hpp:98
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: functional_with_tuple.hpp:129
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_access()
Definition: functional_with_tuple.hpp:141
static constexpr index_t num_packs
Definition: functional_with_tuple.hpp:130
constexpr CK_TILE_HOST_DEVICE void operator()(F f) const
Definition: functional_with_tuple.hpp:151
constexpr CK_TILE_HOST_DEVICE void operator()(F f, number< i_access >) const
Definition: functional_with_tuple.hpp:161
constexpr CK_TILE_HOST_DEVICE static_uford()
Definition: functional_with_tuple.hpp:132