22                                   std::vector<ck_tile::long_index_t> conv_strides,
 
   23                                   std::vector<ck_tile::long_index_t> conv_dilations,
 
   24                                   std::vector<ck_tile::long_index_t> in_left_pads,
 
   25                                   std::vector<ck_tile::long_index_t>)
 
   31         throw std::runtime_error(
"wrong! inconsistent dimension");
 
   34     if constexpr(NDimSpatial == 1)
 
   36         auto func = [&](
auto g, 
auto k, 
auto c, 
auto x) {
 
   39             for(std::size_t n = 0; n < output.
get_lengths()[1]; ++n)
 
   41                 for(std::size_t wo = 0; wo < output.
get_lengths()[3]; ++wo)
 
   47                     if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.
get_lengths()[3])
 
   49                         InDataType v_in   = input(g, n, c, wi);
 
   50                         OutDataType v_out = output(g, n, k, wo);
 
   51                         v_acc += ck_tile::type_convert<float>(v_out) *
 
   52                                  ck_tile::type_convert<float>(v_in);
 
   56             OutDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
 
   57             weight(g, k, c, x)          = v_acc_converted;
 
   64                                    weight.
get_lengths()[3])(std::thread::hardware_concurrency());
 
   66     else if constexpr(NDimSpatial == 2)
 
   68         auto func = [&](
auto g, 
auto k, 
auto c, 
auto y, 
auto x) {
 
   71             for(std::size_t n = 0; n < output.
get_lengths()[1]; ++n)
 
   73                 for(std::size_t ho = 0; ho < output.
get_lengths()[3]; ++ho)
 
   79                     for(std::size_t wo = 0; wo < output.
get_lengths()[4]; ++wo)
 
   86                            ck_tile::type_convert<std::size_t>(hi) < input.
get_lengths()[3] &&
 
   88                            ck_tile::type_convert<std::size_t>(wi) < input.
get_lengths()[4])
 
   90                             InDataType v_in   = input(g, n, c, hi, wi);
 
   91                             OutDataType v_out = output(g, n, k, ho, wo);
 
   93                             v_acc += ck_tile::type_convert<float>(v_out) *
 
   94                                      ck_tile::type_convert<float>(v_in);
 
   99             WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
 
  100             weight(g, k, c, y, x)       = v_acc_converted;
 
  108                                    weight.
get_lengths()[4])(std::thread::hardware_concurrency());
 
  110     else if constexpr(NDimSpatial == 3)
 
  112         auto func = [&](
auto g, 
auto k, 
auto c, 
auto z, 
auto y, 
auto x) {
 
  115             for(std::size_t n = 0; n < output.
get_lengths()[1]; ++n)
 
  117                 for(std::size_t do_ = 0; do_ < output.
get_lengths()[3]; ++do_)
 
  122                     for(std::size_t ho = 0; ho < output.
get_lengths()[4]; ++ho)
 
  127                         for(std::size_t wo = 0; wo < output.
get_lengths()[5]; ++wo)
 
  133                                ck_tile::type_convert<std::size_t>(di) < input.
get_lengths()[3] &&
 
  135                                ck_tile::type_convert<std::size_t>(hi) < input.
get_lengths()[4] &&
 
  137                                ck_tile::type_convert<std::size_t>(wi) < input.
get_lengths()[5])
 
  139                                 InDataType v_in   = input(g, n, c, di, hi, wi);
 
  140                                 OutDataType v_out = output(g, n, k, do_, ho, wo);
 
  142                                 v_acc += ck_tile::type_convert<float>(v_out) *
 
  143                                          ck_tile::type_convert<float>(v_in);
 
  149             WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
 
  150             weight(g, k, c, z, y, x)    = v_acc_converted;
 
  159                                    weight.
get_lengths()[5])(std::thread::hardware_concurrency());
 
  163         throw std::runtime_error(
 
  164             "Ref_conv_bwd_weight: number of dimensions must be between 1 and 3.");
 
#define CK_TILE_HOST
Definition: config.hpp:40
 
Definition: cluster_descriptor.hpp:13
 
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
 
int32_t index_t
Definition: integer.hpp:9
 
int64_t long_index_t
Definition: integer.hpp:11
 
CK_TILE_HOST void reference_grouped_conv_bwd_weight(const HostTensor< InDataType > &input, HostTensor< WeiDataType > &weight, const HostTensor< OutDataType > &output, std::vector< ck_tile::long_index_t > conv_strides, std::vector< ck_tile::long_index_t > conv_dilations, std::vector< ck_tile::long_index_t > in_left_pads, std::vector< ck_tile::long_index_t >)
Definition: reference_grouped_conv_bwd_weight.hpp:19
 
Definition: host_tensor.hpp:336
 
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
 
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:396