21                                                   std::vector<ck_tile::long_index_t> conv_strides,
 
   22                                                   std::vector<ck_tile::long_index_t> conv_dilations,
 
   23                                                   std::vector<ck_tile::long_index_t> in_left_pads,
 
   24                                                   std::vector<ck_tile::long_index_t>)
 
   36         throw std::runtime_error(
"wrong! inconsistent dimension");
 
   39     if constexpr(NDimSpatial == 1)
 
   41         auto func = [&](
auto g, 
auto n, 
auto c, 
auto wi) {
 
   48             for(std::size_t x = 0; x < X; ++x)
 
   54                 if(w_tmp % conv_strides[0] == 0)
 
   59                     if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
 
   61                         for(std::size_t k = 0; k < K; ++k)
 
   63                             OutDataType v_out = output(g, n, k, wo);
 
   64                             WeiDataType v_wei = weight(g, k, c, x);
 
   65                             v_acc += ck_tile::type_convert<float>(v_out) *
 
   66                                      ck_tile::type_convert<float>(v_wei);
 
   71             InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
 
   72             input(g, n, c, wi)         = v_acc_converted;
 
   79                                    input.
get_lengths()[3])(std::thread::hardware_concurrency());
 
   81     else if constexpr(NDimSpatial == 2)
 
   83         auto func = [&](
auto g, 
auto n, 
auto c, 
auto hi, 
auto wi) {
 
   93             for(std::size_t y = 0; y < Y; ++y)
 
   98                 if(h_tmp % conv_strides[0] == 0)
 
  102                     if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
 
  104                         for(std::size_t x = 0; x < X; ++x)
 
  109                             if(w_tmp % conv_strides[1] == 0)
 
  114                                 if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
 
  116                                     for(std::size_t k = 0; k < K; ++k)
 
  118                                         OutDataType v_out = output(g, n, k, ho, wo);
 
  119                                         WeiDataType v_wei = weight(g, k, c, y, x);
 
  120                                         v_acc += ck_tile::type_convert<float>(v_out) *
 
  121                                                  ck_tile::type_convert<float>(v_wei);
 
  129             InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
 
  130             input(g, n, c, hi, wi)     = v_acc_converted;
 
  138                                    input.
get_lengths()[4])(std::thread::hardware_concurrency());
 
  140     else if constexpr(NDimSpatial == 3)
 
  142         auto func = [&](
auto g, 
auto n, 
auto c, 
auto di, 
auto hi, 
auto wi) {
 
  154             for(std::size_t z = 0; z < Z; ++z)
 
  159                 if(d_tmp % conv_strides[0] == 0)
 
  163                     if(do_ >= 0 && ck_tile::type_convert<std::size_t>(do_) < Do)
 
  165                         for(std::size_t y = 0; y < Y; ++y)
 
  170                             if(h_tmp % conv_strides[1] == 0)
 
  174                                 if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
 
  176                                     for(std::size_t x = 0; x < X; ++x)
 
  184                                         if(w_tmp % conv_strides[2] == 0)
 
  190                                                ck_tile::type_convert<std::size_t>(wo) < Wo)
 
  192                                                 for(std::size_t k = 0; k < K; ++k)
 
  195                                                         output(g, n, k, do_, ho, wo);
 
  196                                                     WeiDataType v_wei = weight(g, k, c, z, y, x);
 
  197                                                     v_acc += ck_tile::type_convert<float>(v_out) *
 
  198                                                              ck_tile::type_convert<float>(v_wei);
 
  209             InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
 
  210             input(g, n, c, di, hi, wi) = v_acc_converted;
 
  219                                    input.
get_lengths()[5])(std::thread::hardware_concurrency());
 
  223         throw std::runtime_error(
 
  224             "Ref_conv_bwd_data: number of dimensions must be between 1 and 3.");
 
#define CK_TILE_HOST
Definition: config.hpp:40
 
ck_tile
Definition: cluster_descriptor.hpp:13
 
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
 
int32_t index_t
Definition: integer.hpp:9
 
CK_TILE_HOST void reference_grouped_conv_bwd_data(HostTensor< InDataType > &input, const HostTensor< WeiDataType > &weight, const HostTensor< OutDataType > &output, std::vector< ck_tile::long_index_t > conv_strides, std::vector< ck_tile::long_index_t > conv_dilations, std::vector< ck_tile::long_index_t > in_left_pads, std::vector< ck_tile::long_index_t >)
Definition: reference_grouped_conv_bwd_data.hpp:18
 
int64_t long_index_t
Definition: integer.hpp:11
 
ck_tile::HostTensor
Definition: host_tensor.hpp:336
 
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
 
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:396