/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp Source File
reference_grouped_conv_bwd_data.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
#include <cinttypes>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <stdexcept>
#include <thread>
#include <vector>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
12 
13 namespace ck_tile {
14 
15 template <ck_tile::index_t NDimSpatial,
16  typename InDataType,
17  typename WeiDataType,
18  typename OutDataType>
20  const HostTensor<WeiDataType>& weight,
21  const HostTensor<OutDataType>& output,
22  std::vector<ck_tile::long_index_t> conv_strides,
23  std::vector<ck_tile::long_index_t> conv_dilations,
24  std::vector<ck_tile::long_index_t> in_left_pads,
25  std::vector<ck_tile::long_index_t>)
26 {
27  if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
28  weight.get_num_of_dimension() == NDimSpatial + 3 &&
29  output.get_num_of_dimension() == NDimSpatial + 3))
30  {
31 
32  printf("%" PRIu64 " %" PRIu64 " %" PRIu64,
33  input.get_num_of_dimension(),
34  weight.get_num_of_dimension(),
35  output.get_num_of_dimension());
36 
37  throw std::runtime_error("wrong! inconsistent dimension");
38  }
39 
40  if constexpr(NDimSpatial == 1)
41  {
42  auto func = [&](auto g, auto n, auto c, auto wi) {
43  std::size_t K = weight.get_lengths()[1];
44  std::size_t X = weight.get_lengths()[3];
45 
46  std::size_t Wo = output.get_lengths()[3];
47  float v_acc = 0;
48 
49  for(std::size_t x = 0; x < X; ++x)
50  {
51  auto w_tmp = static_cast<ck_tile::long_index_t>(wi) +
52  static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
53  static_cast<ck_tile::long_index_t>(x * conv_dilations[0]);
54 
55  if(w_tmp % conv_strides[0] == 0)
56  {
57  auto wo = static_cast<ck_tile::long_index_t>(w_tmp) /
58  static_cast<ck_tile::long_index_t>(conv_strides[0]);
59 
60  if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
61  {
62  for(std::size_t k = 0; k < K; ++k)
63  {
64  OutDataType v_out = output(g, n, k, wo);
65  WeiDataType v_wei = weight(g, k, c, x);
66  v_acc += ck_tile::type_convert<float>(v_out) *
67  ck_tile::type_convert<float>(v_wei);
68  }
69  }
70  }
71  }
72  InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
73  input(g, n, c, wi) = v_acc_converted;
74  };
75 
77  input.get_lengths()[0],
78  input.get_lengths()[1],
79  input.get_lengths()[2],
80  input.get_lengths()[3])(std::thread::hardware_concurrency());
81  }
82  else if constexpr(NDimSpatial == 2)
83  {
84  auto func = [&](auto g, auto n, auto c, auto hi, auto wi) {
85  std::size_t K = weight.get_lengths()[1];
86  std::size_t Y = weight.get_lengths()[3];
87  std::size_t X = weight.get_lengths()[4];
88 
89  std::size_t Ho = output.get_lengths()[3];
90  std::size_t Wo = output.get_lengths()[4];
91 
92  float v_acc = 0;
93 
94  for(std::size_t y = 0; y < Y; ++y)
95  {
96  auto h_tmp = static_cast<ck_tile::long_index_t>(hi) +
97  static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
98  static_cast<ck_tile::long_index_t>(y * conv_dilations[0]);
99  if(h_tmp % conv_strides[0] == 0)
100  {
101  auto ho = static_cast<ck_tile::long_index_t>(h_tmp) /
102  static_cast<ck_tile::long_index_t>(conv_strides[0]);
103  if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
104  {
105  for(std::size_t x = 0; x < X; ++x)
106  {
107  auto w_tmp = static_cast<ck_tile::long_index_t>(wi) +
108  static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
109  static_cast<ck_tile::long_index_t>(x * conv_dilations[1]);
110  if(w_tmp % conv_strides[1] == 0)
111  {
112  auto wo = static_cast<ck_tile::long_index_t>(w_tmp) /
113  static_cast<ck_tile::long_index_t>(conv_strides[1]);
114 
115  if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
116  {
117  for(std::size_t k = 0; k < K; ++k)
118  {
119  OutDataType v_out = output(g, n, k, ho, wo);
120  WeiDataType v_wei = weight(g, k, c, y, x);
121  v_acc += ck_tile::type_convert<float>(v_out) *
122  ck_tile::type_convert<float>(v_wei);
123  }
124  }
125  }
126  }
127  }
128  }
129  }
130  InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
131  input(g, n, c, hi, wi) = v_acc_converted;
132  };
133 
135  input.get_lengths()[0],
136  input.get_lengths()[1],
137  input.get_lengths()[2],
138  input.get_lengths()[3],
139  input.get_lengths()[4])(std::thread::hardware_concurrency());
140  }
141  else if constexpr(NDimSpatial == 3)
142  {
143  auto func = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) {
144  std::size_t K = weight.get_lengths()[1];
145  std::size_t Z = weight.get_lengths()[3];
146  std::size_t Y = weight.get_lengths()[4];
147  std::size_t X = weight.get_lengths()[5];
148 
149  std::size_t Do = output.get_lengths()[3];
150  std::size_t Ho = output.get_lengths()[4];
151  std::size_t Wo = output.get_lengths()[5];
152 
153  float v_acc = 0;
154 
155  for(std::size_t z = 0; z < Z; ++z)
156  {
157  auto d_tmp = static_cast<ck_tile::long_index_t>(di) +
158  static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
159  static_cast<ck_tile::long_index_t>(z * conv_dilations[0]);
160  if(d_tmp % conv_strides[0] == 0)
161  {
162  auto do_ = static_cast<ck_tile::long_index_t>(d_tmp) /
163  static_cast<ck_tile::long_index_t>(conv_strides[0]);
164  if(do_ >= 0 && ck_tile::type_convert<std::size_t>(do_) < Do)
165  {
166  for(std::size_t y = 0; y < Y; ++y)
167  {
168  auto h_tmp = static_cast<ck_tile::long_index_t>(hi) +
169  static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
170  static_cast<ck_tile::long_index_t>(y * conv_dilations[1]);
171  if(h_tmp % conv_strides[1] == 0)
172  {
173  auto ho = static_cast<ck_tile::long_index_t>(h_tmp) /
174  static_cast<ck_tile::long_index_t>(conv_strides[1]);
175  if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
176  {
177  for(std::size_t x = 0; x < X; ++x)
178  {
179  auto w_tmp =
180  static_cast<ck_tile::long_index_t>(wi) +
181  static_cast<ck_tile::long_index_t>(in_left_pads[2]) -
182  static_cast<ck_tile::long_index_t>(x *
183  conv_dilations[2]);
184 
185  if(w_tmp % conv_strides[2] == 0)
186  {
187  auto wo =
188  static_cast<ck_tile::long_index_t>(w_tmp) /
189  static_cast<ck_tile::long_index_t>(conv_strides[2]);
190  if(wo >= 0 &&
191  ck_tile::type_convert<std::size_t>(wo) < Wo)
192  {
193  for(std::size_t k = 0; k < K; ++k)
194  {
195  OutDataType v_out =
196  output(g, n, k, do_, ho, wo);
197  WeiDataType v_wei = weight(g, k, c, z, y, x);
198  v_acc += ck_tile::type_convert<float>(v_out) *
199  ck_tile::type_convert<float>(v_wei);
200  }
201  }
202  }
203  }
204  }
205  }
206  }
207  }
208  }
209  }
210  InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
211  input(g, n, c, di, hi, wi) = v_acc_converted;
212  };
213 
215  input.get_lengths()[0],
216  input.get_lengths()[1],
217  input.get_lengths()[2],
218  input.get_lengths()[3],
219  input.get_lengths()[4],
220  input.get_lengths()[5])(std::thread::hardware_concurrency());
221  }
222  else
223  {
224  throw std::runtime_error(
225  "Ref_conv_bwd_data: number of dimensions must be between 1 and 3.");
226  }
227 }
228 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:44
#define PRIu64
Definition: inttypes.h:143
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST void reference_grouped_conv_bwd_data(HostTensor< InDataType > &input, const HostTensor< WeiDataType > &weight, const HostTensor< OutDataType > &output, std::vector< ck_tile::long_index_t > conv_strides, std::vector< ck_tile::long_index_t > conv_dilations, std::vector< ck_tile::long_index_t > in_left_pads, std::vector< ck_tile::long_index_t >)
Definition: reference_grouped_conv_bwd_data.hpp:19
int64_t long_index_t
Definition: integer.hpp:11
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:396