/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_transpose.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_transpose.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_transpose.hpp Source File
reference_batched_transpose.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include "ck_tile/core.hpp"

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>
9 
10 namespace ck_tile {
11 
12 template <typename Type>
15  std::string layout_in = "NCHW",
16  std::string layout_out = "NHWC")
17 {
18  const int N = x.mDesc.get_lengths()[0];
19 
20  auto f = [&](auto batch) {
21  if(layout_in == "NCHW" && layout_out == "NHWC")
22  {
23  const int C = x.mDesc.get_lengths()[1];
24  const int H = x.mDesc.get_lengths()[2];
25  const int W = x.mDesc.get_lengths()[3];
26  for(int c = 0; c < C; ++c)
27  {
28  for(int h = 0; h < H; ++h)
29  {
30  for(int w = 0; w < W; ++w)
31  {
32  Type v_x = x(batch, c, h, w);
33  y(batch, h, w, c) = v_x;
34  }
35  }
36  }
37  }
38  else if(layout_in == "NHWC" && layout_out == "NCHW")
39  {
40  const int H = x.mDesc.get_lengths()[1];
41  const int W = x.mDesc.get_lengths()[2];
42  const int C = x.mDesc.get_lengths()[3];
43  for(int h = 0; h < H; ++h)
44  {
45  for(int w = 0; w < W; ++w)
46  {
47  for(int c = 0; c < C; ++c)
48  {
49  Type v_x = x(batch, h, w, c);
50  y(batch, c, h, w) = v_x;
51  }
52  }
53  }
54  }
55  };
56 
57  make_ParallelTensorFunctor(f, N)(std::thread::hardware_concurrency());
58 }
59 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
CK_TILE_HOST void reference_batched_transpose(const HostTensor< Type > &x, HostTensor< Type > &y, std::string layout_in="NCHW", std::string layout_out="NHWC")
Definition: reference_batched_transpose.hpp:13
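Type
Template parameter of reference_batched_transpose: the element type of the input and output tensors.
(Doxygen mis-linked this name to rapidjson's unrelated `Type`, documented as "Type of JSON value".)
Definition: rapidjson.h:729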
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
Definition: host_tensor.hpp:336
Descriptor mDesc
Definition: host_tensor.hpp:800