30 template <
typename DataType, 
typename IndexType = index_t>
 
   43     assert(dim == -1 || dim < 
rank);
 
   49     assert(k <= topk_src_len);
 
   50     assert(
static_cast<size_t>(k) == y_values.
get_length(topk_dim) &&
 
   51            static_cast<size_t>(k) == y_indices.
get_length(topk_dim));
 
   56     auto f = [&](
auto i_element) {
 
   57         std::vector<size_t> topk_coord = [&](){
 
   58             std::vector<size_t> t_(
rank, 0);
 
   61                 if(i == topk_dim)          
continue; 
 
   62                 t_[i] = r % x_len[i];      r = r / x_len[i];
 
   67         using elem_t = std::pair<DataType, IndexType>;
 
   68         std::vector<elem_t> q = [&](){
 
   69             std::vector<elem_t> t_(topk_src_len);
 
   70             for(
index_t i = 0; i < topk_src_len; i++) {
 
   71                 auto c_ = topk_coord;  c_[topk_dim] = i;
 
   72                 t_[i].first = x(c_);   t_[i].second = i;
 
   79             std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
 
   80             [](
const elem_t& lhs, 
const elem_t& rhs) -> 
bool { return lhs.first > rhs.first; });
 
   82                 std::sort(q.begin(), q.begin() + k - 1,
 
   83                 [](
const elem_t& lhs, 
const elem_t& rhs) -> 
bool { return lhs.first > rhs.first; });
 
   86             std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
 
   87             [](
const elem_t& lhs, 
const elem_t& rhs) -> 
bool { return lhs.first < rhs.first; });
 
   89                 std::sort(q.begin(), q.begin() + k - 1,
 
   90                 [](
const elem_t& lhs, 
const elem_t& rhs) -> 
bool { return lhs.first < rhs.first; });
 
   95         for(
index_t i = 0; i < k; i++) {
 
   96             auto c_ = topk_coord;  c_[topk_dim] = i;
 
   97             y_values(c_) = q[i].first;  y_indices(c_) = q[i].second;
 
  106 template <
typename DataType, 
typename IndexType = index_t>
 
  114     index_t target_dim = (dim == -1) ? (lens.size() - 1) : dim;
 
  115     assert(target_dim < lens.size());
 
  116     assert(k <= lens[target_dim]);
 
  117     lens[target_dim] = k;
 
  121     reference_topk<DataType, IndexType>(x, y_values, y_indices, k, dim, largest, sorted);
 
#define CK_TILE_HOST
Definition: config.hpp:40
 
__host__ constexpr __device__ auto rank([[maybe_unused]] const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition: layout_utils.hpp:310
 
Definition: cluster_descriptor.hpp:13
 
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
 
int32_t index_t
Definition: integer.hpp:9
 
CK_TILE_HOST void reference_topk(const HostTensor< DataType > &x, HostTensor< DataType > &y_values, HostTensor< IndexType > &y_indices, index_t k, index_t dim=-1, bool largest=true, bool sorted=true)
Definition: reference_topk.hpp:31
 
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
 
Definition: host_tensor.hpp:336
 
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
 
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:396
 
std::size_t get_length(std::size_t dim) const
Definition: host_tensor.hpp:388
 
std::size_t get_element_size() const
Definition: host_tensor.hpp:398