/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp Source File
reference_grouped_conv_bwd_data.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include <algorithm>
#include <cstdlib>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

12 namespace ck_tile {
13 
14 template <ck_tile::index_t NDimSpatial,
15  typename InDataType,
16  typename WeiDataType,
17  typename OutDataType>
19  const HostTensor<WeiDataType>& weight,
20  const HostTensor<OutDataType>& output,
21  std::vector<ck_tile::long_index_t> conv_strides,
22  std::vector<ck_tile::long_index_t> conv_dilations,
23  std::vector<ck_tile::long_index_t> in_left_pads,
24  std::vector<ck_tile::long_index_t>)
25 {
26  if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
27  weight.get_num_of_dimension() == NDimSpatial + 3 &&
28  output.get_num_of_dimension() == NDimSpatial + 3))
29  {
30 
31  printf("%lu %lu %lu",
32  input.get_num_of_dimension(),
33  weight.get_num_of_dimension(),
34  output.get_num_of_dimension());
35 
36  throw std::runtime_error("wrong! inconsistent dimension");
37  }
38 
39  if constexpr(NDimSpatial == 1)
40  {
41  auto func = [&](auto g, auto n, auto c, auto wi) {
42  std::size_t K = weight.get_lengths()[1];
43  std::size_t X = weight.get_lengths()[3];
44 
45  std::size_t Wo = output.get_lengths()[3];
46  float v_acc = 0;
47 
48  for(std::size_t x = 0; x < X; ++x)
49  {
50  auto w_tmp = static_cast<ck_tile::long_index_t>(wi) +
51  static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
52  static_cast<ck_tile::long_index_t>(x * conv_dilations[0]);
53 
54  if(w_tmp % conv_strides[0] == 0)
55  {
56  auto wo = static_cast<ck_tile::long_index_t>(w_tmp) /
57  static_cast<ck_tile::long_index_t>(conv_strides[0]);
58 
59  if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
60  {
61  for(std::size_t k = 0; k < K; ++k)
62  {
63  OutDataType v_out = output(g, n, k, wo);
64  WeiDataType v_wei = weight(g, k, c, x);
65  v_acc += ck_tile::type_convert<float>(v_out) *
66  ck_tile::type_convert<float>(v_wei);
67  }
68  }
69  }
70  }
71  InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
72  input(g, n, c, wi) = v_acc_converted;
73  };
74 
76  input.get_lengths()[0],
77  input.get_lengths()[1],
78  input.get_lengths()[2],
79  input.get_lengths()[3])(std::thread::hardware_concurrency());
80  }
81  else if constexpr(NDimSpatial == 2)
82  {
83  auto func = [&](auto g, auto n, auto c, auto hi, auto wi) {
84  std::size_t K = weight.get_lengths()[1];
85  std::size_t Y = weight.get_lengths()[3];
86  std::size_t X = weight.get_lengths()[4];
87 
88  std::size_t Ho = output.get_lengths()[3];
89  std::size_t Wo = output.get_lengths()[4];
90 
91  float v_acc = 0;
92 
93  for(std::size_t y = 0; y < Y; ++y)
94  {
95  auto h_tmp = static_cast<ck_tile::long_index_t>(hi) +
96  static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
97  static_cast<ck_tile::long_index_t>(y * conv_dilations[0]);
98  if(h_tmp % conv_strides[0] == 0)
99  {
100  auto ho = static_cast<ck_tile::long_index_t>(h_tmp) /
101  static_cast<ck_tile::long_index_t>(conv_strides[0]);
102  if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
103  {
104  for(std::size_t x = 0; x < X; ++x)
105  {
106  auto w_tmp = static_cast<ck_tile::long_index_t>(wi) +
107  static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
108  static_cast<ck_tile::long_index_t>(x * conv_dilations[1]);
109  if(w_tmp % conv_strides[1] == 0)
110  {
111  auto wo = static_cast<ck_tile::long_index_t>(w_tmp) /
112  static_cast<ck_tile::long_index_t>(conv_strides[1]);
113 
114  if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
115  {
116  for(std::size_t k = 0; k < K; ++k)
117  {
118  OutDataType v_out = output(g, n, k, ho, wo);
119  WeiDataType v_wei = weight(g, k, c, y, x);
120  v_acc += ck_tile::type_convert<float>(v_out) *
121  ck_tile::type_convert<float>(v_wei);
122  }
123  }
124  }
125  }
126  }
127  }
128  }
129  InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
130  input(g, n, c, hi, wi) = v_acc_converted;
131  };
132 
134  input.get_lengths()[0],
135  input.get_lengths()[1],
136  input.get_lengths()[2],
137  input.get_lengths()[3],
138  input.get_lengths()[4])(std::thread::hardware_concurrency());
139  }
140  else if constexpr(NDimSpatial == 3)
141  {
142  auto func = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) {
143  std::size_t K = weight.get_lengths()[1];
144  std::size_t Z = weight.get_lengths()[3];
145  std::size_t Y = weight.get_lengths()[4];
146  std::size_t X = weight.get_lengths()[5];
147 
148  std::size_t Do = output.get_lengths()[3];
149  std::size_t Ho = output.get_lengths()[4];
150  std::size_t Wo = output.get_lengths()[5];
151 
152  float v_acc = 0;
153 
154  for(std::size_t z = 0; z < Z; ++z)
155  {
156  auto d_tmp = static_cast<ck_tile::long_index_t>(di) +
157  static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
158  static_cast<ck_tile::long_index_t>(z * conv_dilations[0]);
159  if(d_tmp % conv_strides[0] == 0)
160  {
161  auto do_ = static_cast<ck_tile::long_index_t>(d_tmp) /
162  static_cast<ck_tile::long_index_t>(conv_strides[0]);
163  if(do_ >= 0 && ck_tile::type_convert<std::size_t>(do_) < Do)
164  {
165  for(std::size_t y = 0; y < Y; ++y)
166  {
167  auto h_tmp = static_cast<ck_tile::long_index_t>(hi) +
168  static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
169  static_cast<ck_tile::long_index_t>(y * conv_dilations[1]);
170  if(h_tmp % conv_strides[1] == 0)
171  {
172  auto ho = static_cast<ck_tile::long_index_t>(h_tmp) /
173  static_cast<ck_tile::long_index_t>(conv_strides[1]);
174  if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
175  {
176  for(std::size_t x = 0; x < X; ++x)
177  {
178  auto w_tmp =
179  static_cast<ck_tile::long_index_t>(wi) +
180  static_cast<ck_tile::long_index_t>(in_left_pads[2]) -
181  static_cast<ck_tile::long_index_t>(x *
182  conv_dilations[2]);
183 
184  if(w_tmp % conv_strides[2] == 0)
185  {
186  auto wo =
187  static_cast<ck_tile::long_index_t>(w_tmp) /
188  static_cast<ck_tile::long_index_t>(conv_strides[2]);
189  if(wo >= 0 &&
190  ck_tile::type_convert<std::size_t>(wo) < Wo)
191  {
192  for(std::size_t k = 0; k < K; ++k)
193  {
194  OutDataType v_out =
195  output(g, n, k, do_, ho, wo);
196  WeiDataType v_wei = weight(g, k, c, z, y, x);
197  v_acc += ck_tile::type_convert<float>(v_out) *
198  ck_tile::type_convert<float>(v_wei);
199  }
200  }
201  }
202  }
203  }
204  }
205  }
206  }
207  }
208  }
209  InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
210  input(g, n, c, di, hi, wi) = v_acc_converted;
211  };
212 
214  input.get_lengths()[0],
215  input.get_lengths()[1],
216  input.get_lengths()[2],
217  input.get_lengths()[3],
218  input.get_lengths()[4],
219  input.get_lengths()[5])(std::thread::hardware_concurrency());
220  }
221  else
222  {
223  throw std::runtime_error(
224  "Ref_conv_bwd_data: number of dimensions must be between 1 and 3.");
225  }
226 }
227 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST void reference_grouped_conv_bwd_data(HostTensor< InDataType > &input, const HostTensor< WeiDataType > &weight, const HostTensor< OutDataType > &output, std::vector< ck_tile::long_index_t > conv_strides, std::vector< ck_tile::long_index_t > conv_dilations, std::vector< ck_tile::long_index_t > in_left_pads, std::vector< ck_tile::long_index_t >)
Definition: reference_grouped_conv_bwd_data.hpp:18
int64_t long_index_t
Definition: integer.hpp:11
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:396