Composable Kernel: include/ck_tile/host/reference/reference_topk.hpp
reference_topk.hpp — host-side reference implementation of top-k (Doxygen source-page extract)
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include "ck_tile/core.hpp"
// NOTE(review): the extraction dropped original line 7; HostTensor requires this header — verify path
#include "ck_tile/host/host_tensor.hpp"

#include <algorithm>
#include <functional>
#include <numeric>
#include <thread>
#include <utility>
13 
14 namespace ck_tile {
15 
16 /*
17  similar to torch.topk()
18  x (Tensor) - the input tensor.
19  k (int) - the k in "top-k"
20  dim (int, optional) - the dimension to sort along
21  largest (bool, optional) - largest or smallest elements
22  sorted (bool, optional) - elements in sorted order or not
23 
24  output:
25  y_values
26  y_indices
27 
28  https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/TopKImpl.h
29 */
30 template <typename DataType, typename IndexType = index_t>
32  HostTensor<DataType>& y_values,
33  HostTensor<IndexType>& y_indices,
34  index_t k,
35  index_t dim = -1,
36  bool largest = true,
37  bool sorted = true)
38 {
39  // rank must be the same
41  assert(rank == y_values.get_num_of_dimension());
42  assert(rank == y_indices.get_num_of_dimension());
43  assert(dim == -1 || dim < rank);
44 
45  index_t topk_dim = dim == -1 ? (rank - 1) : dim;
46  index_t topk_src_len = x.get_length(topk_dim);
47  auto x_len = x.get_lengths();
48 
49  assert(k <= topk_src_len);
50  assert(k == y_values.get_length(topk_dim) && k == y_indices.get_length(topk_dim));
51 
52  index_t n_parallel = x.get_element_size() / topk_src_len;
53 
54  // clang-format off
55  auto f = [&](auto i_element) {
56  std::vector<size_t> topk_coord = [&](){
57  std::vector<size_t> t_(rank, 0);
58  size_t r = i_element;
59  for(index_t i = rank - 1; i >= 0; i--) {
60  if(i == topk_dim) continue; // topk dim should be zero
61  t_[i] = r % x_len[i]; r = r / x_len[i];
62  }
63  return t_;
64  }();
65 
66  using elem_t = std::pair<DataType, IndexType>;
67  std::vector<elem_t> q = [&](){
68  std::vector<elem_t> t_(topk_src_len);
69  for(index_t i = 0; i < topk_src_len; i++) {
70  auto c_ = topk_coord; c_[topk_dim] = i;
71  t_[i].first = x(c_); t_[i].second = i;
72  }
73  return t_;
74  }();
75 
76  // run topk
77  if(largest) {
78  std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
79  [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
80  if(sorted) {
81  std::sort(q.begin(), q.begin() + k - 1,
82  [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
83  }
84  } else {
85  std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
86  [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
87  if(sorted) {
88  std::sort(q.begin(), q.begin() + k - 1,
89  [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
90  }
91  }
92 
93  // write out
94  for(index_t i = 0; i < k; i++) {
95  auto c_ = topk_coord; c_[topk_dim] = i;
96  y_values(c_) = q[i].first; y_indices(c_) = q[i].second;
97  }
98  };
99  // clang-format on
100 
101  make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
102 }
103 
104 // TODO: if using this method, the return tensor would be dense(no stride)
105 template <typename DataType, typename IndexType = index_t>
107  index_t k,
108  index_t dim = -1,
109  bool largest = true,
110  bool sorted = true)
111 {
112  auto lens = x.get_lengths();
113  index_t target_dim = (dim == -1) ? (lens.size() - 1) : dim;
114  assert(target_dim < lens.size());
115  assert(k <= lens[target_dim]);
116  lens[target_dim] = k;
117  HostTensor<DataType> y_values(lens);
118  HostTensor<IndexType> y_indices(lens);
119 
120  reference_topk<DataType, IndexType>(x, y_values, y_indices, k, dim, largest, sorted);
121 
122  return ck_tile::make_tuple(y_values, y_indices);
123 }
124 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:39
__host__ constexpr __device__ auto rank([[maybe_unused]] const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition: layout_utils.hpp:310
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:272
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST void reference_topk(const HostTensor< DataType > &x, HostTensor< DataType > &y_values, HostTensor< IndexType > &y_indices, index_t k, index_t dim=-1, bool largest=true, bool sorted=true)
Definition: reference_topk.hpp:31
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
Definition: host_tensor.hpp:279
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:333
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:339
std::size_t get_length(std::size_t dim) const
Definition: host_tensor.hpp:331
std::size_t get_element_size() const
Definition: host_tensor.hpp:341