Composable Kernel: include/ck_tile/host/reference/reference_reduce.hpp Source File

reference_reduce.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>

namespace ck_tile {

// Reference reduction of an M x N tensor along its second dimension:
// y(m) = reduce_op over n of x(m, n)
template <typename XDataType, typename ComputeDataType, typename YDataType, typename ReduceOp>
CK_TILE_HOST void
reference_reduce(const HostTensor<XDataType>& x_m_n, HostTensor<YDataType>& y_m, ReduceOp reduce_op)
{
    // Reduce one row per invocation
    auto f = [&](auto m) {
        const int N = x_m_n.mDesc.get_lengths()[1];

        // Start from the operation's identity element (e.g. 0 for add)
        ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();

        for(int n = 0; n < N; ++n)
        {
            const ComputeDataType v_a = type_convert<ComputeDataType>(x_m_n(m, n));

            v_acc = reduce_op(v_acc, v_a);
        }

        y_m(m) = ck_tile::type_convert<YDataType>(v_acc);
    };

    // Run one row per work item across the available hardware threads
    make_ParallelTensorFunctor(f, y_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}

// Generic reference reduce for arbitrary dimensions
template <typename XDataType,
          typename ComputeDataType,
          typename YDataType,
          typename ReduceOp,
          typename KeptDim,    // Expected type: ck_tile::sequence<...> containing dimension indices to keep
          typename ReduceDims> // Expected type: ck_tile::sequence<...> containing dimension indices to reduce
CK_TILE_HOST void reference_reduce(const HostTensor<XDataType>& x_tensor,
                                   HostTensor<YDataType>& y_tensor,
                                   ReduceOp reduce_op,
                                   KeptDim kept_dim,
                                   ReduceDims reduce_dims)
{
    const auto& x_lengths = x_tensor.mDesc.get_lengths();

    // Calculate total kept elements (product of all kept dimension lengths)
    index_t total_kept_elements = 1;
    static_for<0, kept_dim.size(), 1>{}(
        [&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });

    // Calculate total reduce elements (product of all reduce dimension lengths)
    index_t total_reduce_elements = 1;
    static_for<0, reduce_dims.size(), 1>{}(
        [&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });

    auto f = [&](auto linear_kept_idx) {
        ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();

        // Convert linear kept index to multi-dimensional kept indices
        std::vector<index_t> kept_indices(kept_dim.size());
        index_t temp_kept = linear_kept_idx;
        static_for<0, kept_dim.size(), 1>{}([&](auto i) {
            constexpr auto dim_idx = kept_dim.size() - 1 - i;
            constexpr auto dim     = kept_dim.at(dim_idx);
            const auto len         = x_lengths[dim];
            kept_indices[dim_idx]  = temp_kept % len;
            temp_kept /= len;
        });

        for(index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
        {
            // Convert linear reduce index to multi-dimensional reduce indices
            std::vector<index_t> reduce_indices(reduce_dims.size());
            index_t temp_reduce = reduce_idx;
            static_for<0, reduce_dims.size(), 1>{}([&](auto i) {
                constexpr auto dim_idx  = reduce_dims.size() - 1 - i;
                constexpr auto dim      = reduce_dims.at(dim_idx);
                const auto len          = x_lengths[dim];
                reduce_indices[dim_idx] = temp_reduce % len;
                temp_reduce /= len;
            });

            // Build full input tensor indices by combining kept and reduce indices
            std::vector<std::size_t> full_indices(x_lengths.size(), 0);
            static_for<0, kept_dim.size(), 1>{}(
                [&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
            static_for<0, reduce_dims.size(), 1>{}(
                [&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });

            // Access input tensor element
            const auto v_a = type_convert<ComputeDataType>(x_tensor(full_indices));

            v_acc = reduce_op(v_acc, v_a);
        }

        // Calculate output tensor index using kept indices
        // The output tensor has the same structure as the kept dimensions
        std::vector<std::size_t> y_indices(kept_dim.size());
        static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; });

        y_tensor(y_indices) = type_convert<YDataType>(v_acc);
    };

    make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency());
}
} // namespace ck_tile
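
As an illustrative sketch (not part of the header above), a host-side row sum over an M x N tensor could be driven as follows. The SumOp functor and the fill values are hypothetical and exist only to show the interface this reference expects from a reduce op, namely a GetIdentityValue<T>() member template and a binary operator(); the library's own reduce-op functors can be used instead.

#include <cstddef>
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"

// Minimal sum functor for this sketch only; it mirrors the calls made above:
// reduce_op.template GetIdentityValue<ComputeDataType>() and reduce_op(acc, x).
struct SumOp
{
    template <typename T>
    T GetIdentityValue() const
    {
        return T{0};
    }

    template <typename T>
    T operator()(T acc, T x) const
    {
        return acc + x;
    }
};

void run_row_sum()
{
    const std::size_t M = 4, N = 8;

    ck_tile::HostTensor<float> x({M, N}); // input: M x N
    ck_tile::HostTensor<float> y({M});    // output: one value per row

    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
            x(m, n) = 1.f;

    // Explicit template arguments: XDataType, ComputeDataType, YDataType (ReduceOp is deduced).
    ck_tile::reference_reduce<float, float, float>(x, y, SumOp{});
    // Each y(m) now holds the sum of row m (here, 8.f).
}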
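
The generic overload takes the kept and reduced dimensions as ck_tile::sequence<> index lists, as the template parameter comments note. A similar sketch, reusing the hypothetical SumOp from above and assuming a 3D input reduced over its last two dimensions:

void run_3d_reduce()
{
    const std::size_t D0 = 2, D1 = 3, D2 = 4;

    ck_tile::HostTensor<float> x({D0, D1, D2}); // input: D0 x D1 x D2
    ck_tile::HostTensor<float> y({D0});         // output: one value per kept index

    for(std::size_t i = 0; i < D0; ++i)
        for(std::size_t j = 0; j < D1; ++j)
            for(std::size_t k = 0; k < D2; ++k)
                x(i, j, k) = 1.f;

    // Keep dimension 0, reduce over dimensions 1 and 2.
    ck_tile::reference_reduce<float, float, float>(
        x, y, SumOp{}, ck_tile::sequence<0>{}, ck_tile::sequence<1, 2>{});
    // Each y(i) now holds the sum over a D1 x D2 slice (here, 12.f).
}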