22 std::vector<ck_tile::long_index_t> conv_strides,
23 std::vector<ck_tile::long_index_t> conv_dilations,
24 std::vector<ck_tile::long_index_t> in_left_pads,
25 std::vector<ck_tile::long_index_t>)
37 throw std::runtime_error(
"wrong! inconsistent dimension");
40 if constexpr(NDimSpatial == 1)
42 auto func = [&](
auto g,
auto n,
auto c,
auto wi) {
49 for(std::size_t x = 0; x < X; ++x)
55 if(w_tmp % conv_strides[0] == 0)
60 if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
62 for(std::size_t k = 0; k < K; ++k)
64 OutDataType v_out = output(g, n, k, wo);
65 WeiDataType v_wei = weight(g, k, c, x);
66 v_acc += ck_tile::type_convert<float>(v_out) *
67 ck_tile::type_convert<float>(v_wei);
72 InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
73 input(g, n, c, wi) = v_acc_converted;
80 input.
get_lengths()[3])(std::thread::hardware_concurrency());
82 else if constexpr(NDimSpatial == 2)
84 auto func = [&](
auto g,
auto n,
auto c,
auto hi,
auto wi) {
94 for(std::size_t y = 0; y < Y; ++y)
99 if(h_tmp % conv_strides[0] == 0)
103 if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
105 for(std::size_t x = 0; x < X; ++x)
110 if(w_tmp % conv_strides[1] == 0)
115 if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
117 for(std::size_t k = 0; k < K; ++k)
119 OutDataType v_out = output(g, n, k, ho, wo);
120 WeiDataType v_wei = weight(g, k, c, y, x);
121 v_acc += ck_tile::type_convert<float>(v_out) *
122 ck_tile::type_convert<float>(v_wei);
130 InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
131 input(g, n, c, hi, wi) = v_acc_converted;
139 input.
get_lengths()[4])(std::thread::hardware_concurrency());
141 else if constexpr(NDimSpatial == 3)
143 auto func = [&](
auto g,
auto n,
auto c,
auto di,
auto hi,
auto wi) {
155 for(std::size_t z = 0; z < Z; ++z)
160 if(d_tmp % conv_strides[0] == 0)
164 if(do_ >= 0 && ck_tile::type_convert<std::size_t>(do_) < Do)
166 for(std::size_t y = 0; y < Y; ++y)
171 if(h_tmp % conv_strides[1] == 0)
175 if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
177 for(std::size_t x = 0; x < X; ++x)
185 if(w_tmp % conv_strides[2] == 0)
191 ck_tile::type_convert<std::size_t>(wo) < Wo)
193 for(std::size_t k = 0; k < K; ++k)
196 output(g, n, k, do_, ho, wo);
197 WeiDataType v_wei = weight(g, k, c, z, y, x);
198 v_acc += ck_tile::type_convert<float>(v_out) *
199 ck_tile::type_convert<float>(v_wei);
210 InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
211 input(g, n, c, di, hi, wi) = v_acc_converted;
220 input.
get_lengths()[5])(std::thread::hardware_concurrency());
224 throw std::runtime_error(
225 "Ref_conv_bwd_data: number of dimensions must be between 1 and 3.");
#define CK_TILE_HOST
Definition: config.hpp:44
#define PRIu64
Definition: inttypes.h:143
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST void reference_grouped_conv_bwd_data(HostTensor< InDataType > &input, const HostTensor< WeiDataType > &weight, const HostTensor< OutDataType > &output, std::vector< ck_tile::long_index_t > conv_strides, std::vector< ck_tile::long_index_t > conv_dilations, std::vector< ck_tile::long_index_t > in_left_pads, std::vector< ck_tile::long_index_t >)
Definition: reference_grouped_conv_bwd_data.hpp:19
int64_t long_index_t
Definition: integer.hpp:11
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:396