/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_permute.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_permute.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_permute.hpp Source File
reference_permute.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 #include <thread>
9 #include <numeric>
10 #include <functional>
11 
12 namespace ck_tile {
13 
14 /*
15  this will do permute + contiguous like functionality in pytorch
16 */
17 template <typename DataType>
18 CK_TILE_HOST void
19 reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
20 {
21  const auto x_len = x.mDesc.get_lengths();
22  const auto y_len = y.mDesc.get_lengths();
23  assert(x_len.size() == y_len.size());
24  index_t rank = x_len.size();
25  const auto x_elm = std::accumulate(x_len.begin(), x_len.end(), 1, std::multiplies<index_t>());
26  const auto y_elm = std::accumulate(y_len.begin(), y_len.end(), 1, std::multiplies<index_t>());
27  assert(x_elm == y_elm);
28  (void)y_elm;
29 
30  auto f = [&](auto i_element) {
31  std::vector<size_t> y_coord = [&]() {
32  std::vector<size_t> tmp(rank, 0);
33  size_t r = i_element;
34  for(index_t i = rank - 1; i >= 0; i--)
35  {
36  tmp[i] = r % y_len[i];
37  r = r / y_len[i];
38  }
39  return tmp;
40  }();
41 
42  std::vector<size_t> x_coord = [&]() {
43  std::vector<size_t> tmp(rank, 0);
44  for(index_t i = 0; i < rank; i++)
45  {
46  tmp[perm[i]] = y_coord[i];
47  }
48  return tmp;
49  }();
50 
51  // do permute
52  y(y_coord) = x(x_coord);
53  };
54 
55  make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
56 }
57 
58 template <typename DataType>
59 CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
60 {
61  auto x_shape = x.get_lengths();
62  ck_tile::index_t rank = perm.size();
63  std::vector<ck_tile::index_t> y_shape = [&]() {
64  std::vector<ck_tile::index_t> tmp(rank, 0);
65  for(int i = 0; i < static_cast<int>(rank); i++)
66  {
67  tmp[i] = x_shape[perm[i]];
68  }
69  return tmp;
70  }();
71 
72  HostTensor<DataType> y(y_shape);
73  reference_permute(x, y, perm);
74  return y;
75 }
76 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
__host__ constexpr __device__ auto rank([[maybe_unused]] const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition: layout_utils.hpp:310
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST void reference_permute(const HostTensor< DataType > &x, HostTensor< DataType > &y, std::vector< index_t > perm)
Definition: reference_permute.hpp:19
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
Descriptor mDesc
Definition: host_tensor.hpp:800