tensor_space_filling_curve.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/statically_indexed_array_multi_index.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"

namespace ck {

template <typename TensorLengths,
          typename DimAccessOrder,
          typename ScalarsPerAccess, // # of scalars per access in each dimension
          bool SnakeCurved = true>
struct SpaceFillingCurve
{
    static constexpr index_t nDim = TensorLengths::Size();

    using Index = MultiIndex<nDim>;

    static constexpr index_t ScalarPerVector =
        reduce_on_sequence(ScalarsPerAccess{}, math::multiplies{}, Number<1>{});

    static constexpr auto access_lengths   = TensorLengths{} / ScalarsPerAccess{};
    static constexpr auto dim_access_order = DimAccessOrder{};
    static constexpr auto ordered_access_lengths =
        container_reorder_given_new2old(access_lengths, dim_access_order);

    static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(ordered_access_lengths)),
        make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}),
        make_tuple(Sequence<0>{}));

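    // Example: with TensorLengths = Sequence<4, 8>, ScalarsPerAccess = Sequence<1, 4> and
    // DimAccessOrder = Sequence<0, 1>, each access covers a 1x4 vector, access_lengths is
    // [4, 2], and ordered_access_lengths lists those per-dimension access counts in the
    // order the curve traverses them.
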
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ static constexpr index_t GetNumOfAccess()
    {
        static_assert(TensorLengths::Size() == ScalarsPerAccess::Size());
        static_assert(TensorLengths{} % ScalarsPerAccess{} ==
                      typename uniform_sequence_gen<TensorLengths::Size(), 0>::type{});

        return reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}) /
               ScalarPerVector;
    }

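    // Example: GetNumOfAccess() for TensorLengths = Sequence<4, 8> and
    // ScalarsPerAccess = Sequence<1, 4> returns (4 * 8) / 4 = 8 vectorized accesses.
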
    template <index_t AccessIdx1dBegin, index_t AccessIdx1dEnd>
    static __device__ __host__ constexpr auto GetStepBetween(Number<AccessIdx1dBegin>,
                                                              Number<AccessIdx1dEnd>)
    {
        static_assert(AccessIdx1dBegin >= 0, "1D index should be non-negative");
        static_assert(AccessIdx1dBegin < GetNumOfAccess(),
                      "1D index should be smaller than the number of accesses");
        static_assert(AccessIdx1dEnd >= 0, "1D index should be non-negative");
        static_assert(AccessIdx1dEnd < GetNumOfAccess(),
                      "1D index should be smaller than the number of accesses");

        constexpr auto idx_begin = GetIndex(Number<AccessIdx1dBegin>{});
        constexpr auto idx_end   = GetIndex(Number<AccessIdx1dEnd>{});
        return idx_end - idx_begin;
    }

    template <index_t AccessIdx1d>
    static __device__ __host__ constexpr auto GetForwardStep(Number<AccessIdx1d>)
    {
        static_assert(AccessIdx1d < GetNumOfAccess(),
                      "1D index should be smaller than the number of accesses");

        return GetStepBetween(Number<AccessIdx1d>{}, Number<AccessIdx1d + 1>{});
    }

    template <index_t AccessIdx1d>
    static __device__ __host__ constexpr auto GetBackwardStep(Number<AccessIdx1d>)
    {
        static_assert(AccessIdx1d > 0, "1D index should be larger than 0");

        return GetStepBetween(Number<AccessIdx1d>{}, Number<AccessIdx1d - 1>{});
    }

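    // Note: by construction, GetForwardStep(i) equals GetIndex(i + 1) - GetIndex(i) and
    // GetBackwardStep(i) equals GetIndex(i - 1) - GetIndex(i), so a caller can advance a
    // tensor coordinate between consecutive accesses without recomputing the full index.
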
    template <index_t AccessIdx1d>
    static __device__ __host__ constexpr Index GetIndex(Number<AccessIdx1d>)
    {
#if 0
        /*
         * \todo: TensorAdaptor::CalculateBottomIndex does NOT return constexpr as expected.
         */
        constexpr auto ordered_access_idx = to_index_adaptor.CalculateBottomIndex(make_multi_index(Number<AccessIdx1d>{}));
#else

        constexpr auto access_strides = container_reverse_exclusive_scan(
            ordered_access_lengths, math::multiplies{}, Number<1>{});

        constexpr auto idx_1d = Number<AccessIdx1d>{};
        // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
        // idim-th element of multidimensional index.
        // All constexpr variables have to be captured by VALUE.
        constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
            constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
                auto res = idx_1d.value;
                auto id  = 0;

                static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
                    id = res / access_strides[kdim].value;
                    res -= id * access_strides[kdim].value;
                });

                return id;
            };

            constexpr auto id = compute_index_impl(idim);
            return Number<id>{};
        };

        constexpr auto ordered_access_idx = generate_tuple(compute_index, Number<nDim>{});
#endif
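        // Worked example: for ordered_access_lengths = [2, 4] the reverse exclusive scan
        // gives access_strides = [4, 1]; decomposing idx_1d = 5 yields 5 / 4 = 1 (remainder 1)
        // for dimension 0 and 1 / 1 = 1 for dimension 1, i.e. ordered_access_idx = [1, 1].
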
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto idim) {
                index_t tmp = ordered_access_idx[I0];

                static_for<1, idim, 1>{}(
                    [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });

                forward_sweep_(idim) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate multi-dim tensor index
        auto idx_md = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto idim) {
                ordered_idx(idim) =
                    !SnakeCurved || forward_sweep[idim]
                        ? ordered_access_idx[idim]
                        : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
            });

            return container_reorder_given_old2new(ordered_idx, dim_access_order) *
                   ScalarsPerAccess{};
        }();

        return idx_md;
    }

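    // Worked example: with TensorLengths = Sequence<2, 8>, ScalarsPerAccess = Sequence<1, 4>
    // and DimAccessOrder = Sequence<0, 1>, access_lengths is [2, 2] and the snake-curved
    // visit order is GetIndex(0) = (0, 0), GetIndex(1) = (0, 4), GetIndex(2) = (1, 4),
    // GetIndex(3) = (1, 0): the fastest dimension reverses direction on every other row.
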
    // FIXME: rename this function
    template <index_t AccessIdx1d>
    static __device__ __host__ constexpr auto GetIndexTupleOfNumber(Number<AccessIdx1d>)
    {
        constexpr auto idx = GetIndex(Number<AccessIdx1d>{});

        return generate_tuple([&](auto i) { return Number<idx[i]>{}; }, Number<nDim>{});
    }
};

} // namespace ck
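
A minimal usage sketch, assuming a HIP translation unit with the Composable Kernel include path set up; the tile shape, the access order, and the visit_tile helper below are illustrative choices for this example, not part of the header.

#include "ck/tensor_description/tensor_space_filling_curve.hpp"

// Visit a 4 x 8 tile with 4-scalar-wide vector accesses along the last dimension,
// iterating dimension 1 fastest and snaking between rows (SnakeCurved defaults to true).
using Curve = ck::SpaceFillingCurve<ck::Sequence<4, 8>,  // TensorLengths
                                    ck::Sequence<0, 1>,  // DimAccessOrder
                                    ck::Sequence<1, 4>>; // ScalarsPerAccess

// Hypothetical helper for illustration only; "visit_tile" is not part of the library.
void visit_tile()
{
    constexpr ck::index_t num_access = Curve::GetNumOfAccess(); // (4 * 8) / 4 = 8

    ck::static_for<0, num_access, 1>{}([](auto i) {
        constexpr auto idx = Curve::GetIndex(i); // multi-dim offset of the i-th access
        // ... issue a 4-scalar-wide load/store at offset "idx";
        // Curve::GetForwardStep(i) gives the offset delta to access i + 1 (for i + 1 < num_access).
        (void)idx;
    });
}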