/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/algorithm/space_filling_curve.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/algorithm/space_filling_curve.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/algorithm/space_filling_curve.hpp Source File
space_filling_curve.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
12 
13 namespace ck_tile {
14 
15 template <typename TensorLengths,
16  typename DimAccessOrder,
17  typename ScalarsPerAccess, // # of scalars per access in each dimension
18  bool SnakeCurved = true> // when false, no dimension is ever traversed in reverse
20 {
21  static constexpr index_t TensorSize =
22  reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{});
23  static_assert(0 < TensorSize,
24  "space_filling_curve should be used to access a non-empty tensor");
25 
26  static constexpr index_t nDim = TensorLengths::size();
27 
29 
30  static constexpr index_t ScalarPerVector =
31  reduce_on_sequence(ScalarsPerAccess{}, multiplies{}, number<1>{});
32 
33  static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{};
34  static constexpr auto dim_access_order = DimAccessOrder{};
35  static constexpr auto ordered_access_lengths =
37 
42 
43  static constexpr auto I0 = number<0>{};
44  static constexpr auto I1 = number<1>{};
45 
47  {
48  static_assert(TensorLengths::size() == ScalarsPerAccess::size());
49  static_assert(TensorLengths{} % ScalarsPerAccess{} ==
50  typename uniform_sequence_gen<TensorLengths::size(), 0>::type{});
51 
52  return reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{}) / ScalarPerVector;
53  }
54 
55  template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
58  {
59  static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(),
60  "1D index out of range");
61  static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(),
62  "1D index out of range");
63 
64  constexpr auto idx_head = get_index(number<AccessIdx1dHead>{});
65  constexpr auto idx_tail = get_index(number<AccessIdx1dTail>{});
66  return idx_tail - idx_head;
67  }
68 
69  template <index_t AccessIdx1d>
71  {
    // AccessIdx1d must be a valid access index, i.e. below get_num_of_access().
72  static_assert(AccessIdx1d < get_num_of_access(), "1D index out of range");
74  }
75 
76  template <index_t AccessIdx1d>
78  {
79  static_assert(AccessIdx1d > 0, "1D index should be larger than 0");
80 
81  return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
82  }
83 
84  // Implementation detail of get_index(): do not use this function directly!
85  // TODO: can be refactored into a generic lambda in the future
86  template <index_t AccessIdx1d>
88  {
89 #if 0
90  /*
91  * \todo: tensor_adaptor::calculate_bottom_index does NOT return constexpr as expected.
92  */
93  constexpr auto ordered_access_idx = to_index_adaptor.calculate_bottom_index(make_multi_index(number<AccessIdx1d>{}));
94 #else
95 
96  constexpr auto access_strides =
98 
99  constexpr auto idx_1d = number<AccessIdx1d>{};
100  // Given tensor strides \p access_strides and the 1D index of the space-filling curve, compute the
101  // idim-th element of the multidimensional index.
102  // All constexpr variables have to be captured by VALUE.
103  constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
104  constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
105  auto res = idx_1d.value;
106  auto id = 0;
107 
108  static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
109  id = res / access_strides[kdim].value;
110  res -= id * access_strides[kdim].value;
111  });
112 
113  return id;
114  };
115 
116  constexpr auto id = compute_index_impl(idim);
117  return number<id>{};
118  };
119 
120  constexpr auto ordered_access_idx = generate_tuple(compute_index, number<nDim>{});
121 #endif
122  constexpr auto forward_sweep = [&]() {
124 
125  forward_sweep_(I0) = true;
126 
127  static_for<1, nDim, 1>{}([&](auto idim) {
128  index_t tmp = ordered_access_idx[I0];
129 
131  [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });
132 
133  forward_sweep_(idim) = tmp % 2 == 0;
134  });
135 
136  return forward_sweep_;
137  }();
138 
139  // calculate multi-dim tensor index
140  auto idx_md = [&]() {
141  Index ordered_idx;
142 
143  static_for<0, nDim, 1>{}([&](auto idim) {
144  ordered_idx(idim) =
145  !SnakeCurved || forward_sweep[idim]
146  ? ordered_access_idx[idim]
147  : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
148  });
149 
151  ScalarsPerAccess{};
152  }();
153  return idx_md;
154  }
155 
156  // FIXME: this returns a tuple of number<>, whose elements are compile-time-only values
157  template <index_t AccessIdx1d>
159  {
160  constexpr auto idx = _get_index(number<AccessIdx1d>{});
161 
162  return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
163  }
164 };
165 
166 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto container_reorder_given_old2new(const array< TData, NSize > &old_array, sequence< IRs... > old2new)
Definition: container_helper.hpp:48
constexpr CK_TILE_HOST_DEVICE auto container_reorder_given_new2old(const array< TData, NSize > &old_array, sequence< IRs... >)
Definition: container_helper.hpp:39
constexpr CK_TILE_HOST_DEVICE auto container_reverse_exclusive_scan(const array< TData, NSize > &x, Reduce f, Init init)
Definition: container_helper.hpp:240
tuple_array< T, N > statically_indexed_array
Definition: statically_indexed_array.hpp:16
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition: sequence.hpp:979
constexpr CK_TILE_HOST_DEVICE auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:359
constexpr CK_TILE_HOST_DEVICE auto make_multi_index(Xs &&... xs)
Definition: multi_index.hpp:20
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:299
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
static constexpr value_type value
Definition: integral_constant.hpp:16
Definition: math.hpp:98
Definition: sequence.hpp:49
Definition: space_filling_curve.hpp:20
static constexpr auto to_index_adaptor
Definition: space_filling_curve.hpp:38
static constexpr index_t TensorSize
Definition: space_filling_curve.hpp:21
static constexpr auto ordered_access_lengths
Definition: space_filling_curve.hpp:35
static constexpr auto I1
Definition: space_filling_curve.hpp:44
static constexpr CK_TILE_HOST_DEVICE auto get_forward_step(number< AccessIdx1d >)
Definition: space_filling_curve.hpp:70
static constexpr CK_TILE_HOST_DEVICE auto get_backward_step(number< AccessIdx1d >)
Definition: space_filling_curve.hpp:77
static constexpr CK_TILE_HOST_DEVICE auto get_index(number< AccessIdx1d >)
Definition: space_filling_curve.hpp:158
static constexpr index_t ScalarPerVector
Definition: space_filling_curve.hpp:30
static constexpr CK_TILE_HOST_DEVICE auto get_step_between(number< AccessIdx1dHead >, number< AccessIdx1dTail >)
Definition: space_filling_curve.hpp:56
static constexpr auto dim_access_order
Definition: space_filling_curve.hpp:34
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_access()
Definition: space_filling_curve.hpp:46
static constexpr CK_TILE_HOST_DEVICE Index _get_index(number< AccessIdx1d >)
Definition: space_filling_curve.hpp:87
static constexpr index_t nDim
Definition: space_filling_curve.hpp:26
static constexpr auto I0
Definition: space_filling_curve.hpp:43
static constexpr auto access_lengths
Definition: space_filling_curve.hpp:33
Definition: functional.hpp:43
Definition: sequence.hpp:311