/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_im2col.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_im2col.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_im2col.hpp Source File
reference_im2col.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 #include <thread>
9 
10 namespace ck_tile {
11 
12 template <typename InDataType, typename OutDataType, index_t NDimSpatial>
14  HostTensor<OutDataType>& out_host,
15  const ck_tile::conv::ConvParam& conv_params)
16 {
17  const long_index_t G = in_host.get_lengths()[0];
18  const long_index_t N = in_host.get_lengths()[1];
19  const long_index_t C = in_host.get_lengths()[2];
20 
21  if constexpr(NDimSpatial == 1)
22  {
23  const long_index_t Wo = conv_params.output_spatial_lengths_[0];
24  auto func = [&](auto g, auto n, auto wo) {
25  long_index_t row = n * Wo + wo;
26  long_index_t column = 0;
27 
28  for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[0]; ++x)
29  {
30  auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[0]) +
31  static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[0]) -
32  static_cast<long_index_t>(conv_params.input_left_pads_[0]);
33 
34  for(long_index_t c = 0; c < C; ++c)
35  {
36  if(wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[3])
37  {
38  InDataType v_in = in_host(g, n, c, wi);
39  out_host(g, row, column) = type_convert<OutDataType>(v_in);
40  }
41  column++;
42  }
43  }
44  };
45 
46  make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
47  }
48  else if constexpr(NDimSpatial == 2)
49  {
50  const long_index_t Ho = conv_params.output_spatial_lengths_[0];
51  const long_index_t Wo = conv_params.output_spatial_lengths_[1];
52 
53  auto func = [&](auto g, auto n, auto ho, auto wo) {
54  long_index_t row = n * Ho * Wo + ho * Wo + wo;
55  long_index_t column = 0;
56 
57  for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[0]; ++y)
58  {
59  auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[0]) +
60  static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[0]) -
61  static_cast<long_index_t>(conv_params.input_left_pads_[0]);
62 
63  for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[1]; ++x)
64  {
65  auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[1]) +
66  static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[1]) -
67  static_cast<long_index_t>(conv_params.input_left_pads_[1]);
68 
69  for(long_index_t c = 0; c < C; ++c)
70  {
71 
72  if(hi >= 0 && type_convert<std::size_t>(hi) < in_host.get_lengths()[3] &&
73  wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[4])
74  {
75  InDataType v_in = in_host(g, n, c, hi, wi);
76  out_host(g, row, column) = type_convert<OutDataType>(v_in);
77  }
78  column++;
79  }
80  }
81  }
82  };
83 
84  make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
85  }
86  else if constexpr(NDimSpatial == 3)
87  {
88  const long_index_t Do = conv_params.output_spatial_lengths_[0];
89  const long_index_t Ho = conv_params.output_spatial_lengths_[1];
90  const long_index_t Wo = conv_params.output_spatial_lengths_[2];
91 
92  auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
93  long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
94  long_index_t column = 0;
95 
96  for(long_index_t z = 0; z < conv_params.filter_spatial_lengths_[0]; ++z)
97  {
98  auto di = static_cast<long_index_t>(d_o * conv_params.conv_filter_strides_[0]) +
99  static_cast<long_index_t>(z * conv_params.conv_filter_dilations_[0]) -
100  static_cast<long_index_t>(conv_params.input_left_pads_[0]);
101  for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[1]; ++y)
102  {
103  auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[1]) +
104  static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[1]) -
105  static_cast<long_index_t>(conv_params.input_left_pads_[1]);
106  for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[2]; ++x)
107  {
108  auto wi =
109  static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[2]) +
110  static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[2]) -
111  static_cast<long_index_t>(conv_params.input_left_pads_[2]);
112  for(long_index_t c = 0; c < C; ++c)
113  {
114  if(di >= 0 &&
115  type_convert<std::size_t>(di) < in_host.get_lengths()[3] &&
116  hi >= 0 &&
117  type_convert<std::size_t>(hi) < in_host.get_lengths()[4] &&
118  wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[5])
119  {
120  InDataType v_in = in_host(g, n, c, di, hi, wi);
121  out_host(g, row, column) = type_convert<OutDataType>(v_in);
122  }
123  column++;
124  }
125  }
126  }
127  }
128  };
129 
130  make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(std::thread::hardware_concurrency());
131  }
132 }
133 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
CK_TILE_HOST void reference_im2col(const HostTensor< InDataType > &in_host, HostTensor< OutDataType > &out_host, const ck_tile::conv::ConvParam &conv_params)
Definition: reference_im2col.hpp:13
int64_t long_index_t
Definition: integer.hpp:11
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
Definition: convolution_parameter.hpp:15
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134