Composable Kernel: include/ck_tile/host/reference/reference_reduce.hpp Source File

reference_reduce.hpp
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <algorithm>
#include <thread>
#include <vector>

namespace ck_tile {

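// Host reference reduction of a 2D tensor x[m, n] along n into y[m]: the accumulator starts from
// reduce_op's identity value, every element of the row is folded in with reduce_op, and rows are
// processed in parallel across host threads.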
template <typename XDataType, typename ComputeDataType, typename YDataType, typename ReduceOp>
CK_TILE_HOST void
reference_reduce(const HostTensor<XDataType>& x_m_n, HostTensor<YDataType>& y_m, ReduceOp reduce_op)
{
    auto f = [&](auto m) {
        const int N = x_m_n.mDesc.get_lengths()[1];

        ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();

        for(int n = 0; n < N; ++n)
        {
            const ComputeDataType v_a = type_convert<ComputeDataType>(x_m_n(m, n));

            v_acc = reduce_op(v_acc, v_a);
        }

        y_m(m) = ck_tile::type_convert<YDataType>(v_acc);
    };

    make_ParallelTensorFunctor(f, y_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
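
// Illustrative usage (a sketch, not part of the original header). "Sum" is a hypothetical reduce
// functor; it only has to provide GetIdentityValue<T>() and operator()(acc, v) as used above. The
// functor names and HostTensor constructors shipped with ck_tile may differ.
//
//   struct Sum
//   {
//       template <typename T>
//       T GetIdentityValue() const { return T{0}; }
//       template <typename T>
//       T operator()(T acc, T v) const { return acc + v; }
//   };
//
//   const std::size_t M = 4, N = 16;
//   ck_tile::HostTensor<float> x({M, N}); // filled by the caller
//   ck_tile::HostTensor<float> y({M});
//   ck_tile::reference_reduce<float, float, float>(x, y, Sum{}); // y(m) = sum over n of x(m, n)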

// Generic reference reduce for arbitrary dimensions
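// kept_dim lists the dimension indices that are preserved in the output and reduce_dims lists the
// dimension indices that are folded away; y_tensor is expected to provide one dimension per entry
// of kept_dim, in the same order.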
template <
    typename XDataType,
    typename ComputeDataType,
    typename YDataType,
    typename ReduceOp,
    typename KeptDim,    // Expected type: ck_tile::sequence<...> containing dimension indices to keep
    typename ReduceDims> // Expected type: ck_tile::sequence<...> containing dimension indices to reduce
CK_TILE_HOST void reference_reduce(const HostTensor<XDataType>& x_tensor,
                                   HostTensor<YDataType>& y_tensor,
                                   ReduceOp reduce_op,
                                   KeptDim kept_dim,
                                   ReduceDims reduce_dims)
{
    const auto& x_lengths = x_tensor.mDesc.get_lengths();

    // Calculate total kept elements (product of all kept dimension lengths)
    index_t total_kept_elements = 1;
    static_for<0, kept_dim.size(), 1>{}(
        [&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });

    // Calculate total reduce elements (product of all reduce dimension lengths)
    index_t total_reduce_elements = 1;
    static_for<0, reduce_dims.size(), 1>{}(
        [&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });

    auto f = [&](auto linear_kept_idx) {
        ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();

        // Convert linear kept index to multi-dimensional kept indices
        std::vector<index_t> kept_indices(kept_dim.size());
        index_t temp_kept = linear_kept_idx;
        static_for<0, kept_dim.size(), 1>{}([&](auto i) {
            constexpr auto dim_idx = kept_dim.size() - 1 - i;
            constexpr auto dim = kept_dim.at(dim_idx);
            const auto len = x_lengths[dim];
            kept_indices[dim_idx] = temp_kept % len;
            temp_kept /= len;
        });

        for(index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
        {
            // Convert linear reduce index to multi-dimensional reduce indices
            std::vector<index_t> reduce_indices(reduce_dims.size());
            index_t temp_reduce = reduce_idx;
            static_for<0, reduce_dims.size(), 1>{}([&](auto i) {
                constexpr auto dim_idx = reduce_dims.size() - 1 - i;
                constexpr auto dim = reduce_dims.at(dim_idx);
                const auto len = x_lengths[dim];
                reduce_indices[dim_idx] = temp_reduce % len;
                temp_reduce /= len;
            });

            // Build full input tensor indices by combining kept and reduce indices
            std::vector<std::size_t> full_indices(x_lengths.size(), 0);
            static_for<0, kept_dim.size(), 1>{}(
                [&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
            static_for<0, reduce_dims.size(), 1>{}(
                [&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });

            // Access input tensor element
            const auto v_a = type_convert<ComputeDataType>(x_tensor(full_indices));

            v_acc = reduce_op(v_acc, v_a);
        }

        // Calculate output tensor index using kept indices
        // The output tensor has the same structure as the kept dimensions
        std::vector<std::size_t> y_indices(kept_dim.size());
        static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; });

        y_tensor(y_indices) = type_convert<YDataType>(v_acc);
    };

    make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency());
}
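
// Illustrative usage of the generic overload (a sketch, not part of the original header),
// reducing a 3D tensor over its last dimension while keeping dimensions 0 and 1. "Sum" is the
// hypothetical reduce functor sketched above.
//
//   const std::size_t D0 = 2, D1 = 3, D2 = 8;
//   ck_tile::HostTensor<float> x3d({D0, D1, D2}); // filled by the caller
//   ck_tile::HostTensor<float> y2d({D0, D1});
//   ck_tile::reference_reduce<float, float, float>(
//       x3d, y2d, Sum{}, ck_tile::sequence<0, 1>{}, ck_tile::sequence<2>{});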
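
// Same reduction pattern as above, but several reductions are carried out in one pass over the
// input: reduce_ops, elementwise_ops, and accumulator_ops are tuples with one entry per reduction,
// and y_tensor_tuple holds one output tensor per reduction. Each input element is passed through
// its element-wise op, folded into the matching accumulator, and the accumulator is post-processed
// by its accumulator op before being written out.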
template <typename XDataType,
          typename ComputeDataType,
          typename YDataType,
          typename YRefTuple,
          typename ReduceOps,  // Expected type: ck_tile::tuple<...> containing reduce operations
          typename KeptDim,    // Expected type: ck_tile::sequence<...> containing dimension indices to keep
          typename ReduceDims, // Expected type: ck_tile::sequence<...> containing dimension indices to reduce
          typename ElementWiseOps,
          typename AccElementWiseOps>
CK_TILE_HOST void reference_multiple_reduce(const HostTensor<XDataType>& x_tensor,
                                            YRefTuple& y_tensor_tuple,
                                            ReduceOps reduce_ops,
                                            KeptDim kept_dim,
                                            ReduceDims reduce_dims,
                                            ElementWiseOps elementwise_ops,
                                            AccElementWiseOps accumulator_ops)
{
    const auto& x_lengths = x_tensor.mDesc.get_lengths();

    // Calculate total kept elements (product of all kept dimension lengths)
    index_t total_kept_elements = 1;
    static_for<0, kept_dim.size(), 1>{}(
        [&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });

    // Calculate total reduce elements (product of all reduce dimension lengths)
    index_t total_reduce_elements = 1;
    static_for<0, reduce_dims.size(), 1>{}(
        [&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });

    auto f = [&](auto linear_kept_idx) {
        // Initialize accumulators for each reduction operation
        auto v_acc_tuple = ck_tile::generate_tuple(
            [&](auto i) {
                return reduce_ops.template at<i>().template GetIdentityValue<ComputeDataType>();
            },
            number<reduce_ops.size()>{});

        // Convert linear kept index to multi-dimensional kept indices
        std::vector<index_t> kept_indices(kept_dim.size());
        index_t temp_kept = linear_kept_idx;
        static_for<0, kept_dim.size(), 1>{}([&](auto i) {
            constexpr auto dim_idx = kept_dim.size() - 1 - i;
            constexpr auto dim = kept_dim.at(dim_idx);
            const auto len = x_lengths[dim];
            kept_indices[dim_idx] = temp_kept % len;
            temp_kept /= len;
        });

        for(index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
        {
            // Convert linear reduce index to multi-dimensional reduce indices
            std::vector<index_t> reduce_indices(reduce_dims.size());
            index_t temp_reduce = reduce_idx;
            static_for<0, reduce_dims.size(), 1>{}([&](auto i) {
                constexpr auto dim_idx = reduce_dims.size() - 1 - i;
                constexpr auto dim = reduce_dims.at(dim_idx);
                const auto len = x_lengths[dim];
                reduce_indices[dim_idx] = temp_reduce % len;
                temp_reduce /= len;
            });

            // Build full input tensor indices by combining kept and reduce indices
            std::vector<std::size_t> full_indices(x_lengths.size(), 0);
            static_for<0, kept_dim.size(), 1>{}(
                [&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
            static_for<0, reduce_dims.size(), 1>{}(
                [&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });

            // Access input tensor element
            const auto v_a_in = type_convert<ComputeDataType>(x_tensor(full_indices));

            // Apply each reduction operation independently: each op gets its own copy of the
            // input element so one op's element-wise transform cannot leak into the next
            // (matching the multi-block variant below)
            static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
                auto v_a = v_a_in;
                // Apply element-wise operation before reduction
                elementwise_ops.at(i)(v_a, v_a);

                v_acc_tuple.template at<i>() =
                    reduce_ops.template at<i>()(v_acc_tuple.template at<i>(), v_a);
            });
        }

        static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
            // Apply accumulator element-wise operation after reduction
            accumulator_ops.at(i)(v_acc_tuple.template at<i>(), v_acc_tuple.template at<i>());
        });

        // Calculate output tensor index using kept indices
        // The output tensor has the same structure as the kept dimensions
        std::vector<std::size_t> y_indices(kept_dim.size());
        static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; });

        // Store results for each reduction operation in the output tensor
        static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
            y_tensor_tuple.template at<i>()(y_indices) =
                type_convert<YDataType>(v_acc_tuple.template at<i>());
        });
    };

    make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency());
}
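
// Variant of reference_multiple_reduce that mimics a multi-block (split-K style) reduction: the
// reduced elements are partitioned into num_blocks chunks, each chunk is reduced on its own with
// reduce_ops/elementwise_ops/accumulator_ops, and the per-chunk partial results are then combined
// into the output tensors with inter_block_reduce_ops. The output tensors are first initialized to
// the identity value of their inter-block reduce op.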
template <typename XDataType,
          typename ComputeDataType,
          typename YDataType,
          typename YRefTuple,
          typename ReduceOps,  // Expected type: ck_tile::tuple<...> containing reduce operations
          typename KeptDim,    // Expected type: ck_tile::sequence<...> containing dimension indices to keep
          typename ReduceDims, // Expected type: ck_tile::sequence<...> containing dimension indices to reduce
          typename ElementWiseOps,
          typename AccElementWiseOps,
          typename InterBlockReduceOps>
CK_TILE_HOST void reference_multiple_reduce_multiblock(const HostTensor<XDataType>& x_tensor,
                                                       YRefTuple& y_tensor_tuple,
                                                       ReduceOps reduce_ops,
                                                       KeptDim kept_dim,
                                                       ReduceDims reduce_dims,
                                                       ElementWiseOps elementwise_ops,
                                                       AccElementWiseOps accumulator_ops,
                                                       InterBlockReduceOps inter_block_reduce_ops,
                                                       ck_tile::index_t num_blocks)
{
    const auto& x_lengths = x_tensor.mDesc.get_lengths();

    // Calculate total kept elements (product of all kept dimension lengths)
    index_t total_kept_elements = 1;
    static_for<0, kept_dim.size(), 1>{}(
        [&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });

    // Calculate total reduce elements (product of all reduce dimension lengths)
    index_t total_reduce_elements = 1;
    static_for<0, reduce_dims.size(), 1>{}(
        [&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });

    // Initialize output tensors
    static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
        auto& y_tensor = y_tensor_tuple.template at<i>();
        for(auto& val : y_tensor.mData)
        {
            val = inter_block_reduce_ops.template at<i>().template GetIdentityValue<YDataType>();
        }
    });

    auto f = [&](auto linear_kept_idx) {
        // Convert linear kept index to multi-dimensional kept indices
        std::vector<index_t> kept_indices(kept_dim.size());
        index_t temp_kept = linear_kept_idx;
        static_for<0, kept_dim.size(), 1>{}([&](auto i) {
            constexpr auto dim_idx = kept_dim.size() - 1 - i;
            constexpr auto dim = kept_dim.at(dim_idx);
            const auto len = x_lengths[dim];
            kept_indices[dim_idx] = temp_kept % len;
            temp_kept /= len;
        });

        // Calculate output tensor index using kept indices
        std::vector<std::size_t> y_indices(kept_dim.size());
        static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; });

        const auto max_element_per_block = (total_reduce_elements + num_blocks - 1) / num_blocks;

        for(index_t block_id = 0; block_id < num_blocks; ++block_id)
        {
            // Initialize accumulators for each reduction operation for the current block
            auto v_acc_tuple = ck_tile::generate_tuple(
                [&](auto i) {
                    return reduce_ops.template at<i>().template GetIdentityValue<ComputeDataType>();
                },
                number<reduce_ops.size()>{});

            const index_t element_offset = block_id * max_element_per_block;
            const index_t element_end =
                std::min(element_offset + max_element_per_block, total_reduce_elements);

            for(index_t linear_reduce_idx = element_offset; linear_reduce_idx < element_end;
                ++linear_reduce_idx)
            {
                // Convert linear reduce index to multi-dimensional reduce indices
                std::vector<index_t> reduce_indices(reduce_dims.size());
                index_t temp_reduce = linear_reduce_idx;
                static_for<0, reduce_dims.size(), 1>{}([&](auto i) {
                    constexpr auto dim_idx = reduce_dims.size() - 1 - i;
                    constexpr auto dim = reduce_dims.at(dim_idx);
                    const auto len = x_lengths[dim];
                    reduce_indices[dim_idx] = temp_reduce % len;
                    temp_reduce /= len;
                });

                // Build full input tensor indices by combining kept and reduce indices
                std::vector<std::size_t> full_indices(x_lengths.size(), 0);
                static_for<0, kept_dim.size(), 1>{}(
                    [&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
                static_for<0, reduce_dims.size(), 1>{}(
                    [&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });

                // Access input tensor element
                const auto v_a_in = type_convert<ComputeDataType>(x_tensor(full_indices));

                // Apply each reduction operation
                static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
                    auto v_a = v_a_in;
                    // Apply element-wise operation before reduction
                    elementwise_ops.at(i)(v_a, v_a);

                    v_acc_tuple.template at<i>() =
                        reduce_ops.template at<i>()(v_acc_tuple.template at<i>(), v_a);
                });
            }

            static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
                // Apply accumulator element-wise operation after reduction
                accumulator_ops.at(i)(v_acc_tuple.template at<i>(), v_acc_tuple.template at<i>());

                // Update the output tensor with the partial result from this block
                auto& y_tensor = y_tensor_tuple.template at<i>();
                auto& y_val = y_tensor(y_indices);
                y_val = inter_block_reduce_ops.template at<i>()(
                    y_val, type_convert<YDataType>(v_acc_tuple.template at<i>()));
            });
        }
    };

    make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency());
}

} // namespace ck_tile