14 template <
typename InDataType,
 
   15           typename ComputeDataType,
 
   17           typename IndexDataType,
 
   21           bool OutputIndex = 
false>
 
   49     auto f = [&](
auto n, 
auto ho, 
auto wo, 
auto c) {
 
   50         ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
 
   52         IndexDataType current_index = 0; 
 
   64                 if(hi >= 0 && hi < H && wi >= 0 && wi < W)
 
   66                     const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
 
   68                     if constexpr(OutputIndex)
 
   72                         v_acc                    = reduce_op(v_acc, v_in, changed);
 
   75                             current_index = flat_index;
 
   80                         v_acc = reduce_op(v_acc, v_in);
 
   87         output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
 
   89         if constexpr(OutputIndex)
 
   91             output_index(n, ho, wo, c) = current_index;
 
   99 template <
typename InDataType,
 
  100           typename ComputeDataType,
 
  101           typename OutDataType,
 
  102           typename IndexDataType,
 
  104           typename TensorShape,
 
  105           typename WindowShape,
 
  106           bool OutputIndex = 
false>
 
  140     auto f = [&](
auto n, 
auto do_, 
auto ho, 
auto wo, 
auto c) {
 
  141         ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
 
  143         IndexDataType current_index = 0; 
 
  160                     if(di >= 0 && di < D && hi >= 0 && hi < H && wi >= 0 && wi < W)
 
  162                         const ComputeDataType v_in =
 
  163                             type_convert<ComputeDataType>(input(n, di, hi, wi, c));
 
  165                         if constexpr(OutputIndex)
 
  167                             IndexDataType flat_index =
 
  169                             bool changed = 
false;
 
  170                             v_acc        = reduce_op(v_acc, v_in, changed);
 
  173                                 current_index = flat_index;
 
  178                             v_acc = reduce_op(v_acc, v_in);
 
  186         output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
 
  188         if constexpr(OutputIndex)
 
  191             output_index(n, do_, ho, wo, c) = current_index;
 
#define CK_TILE_HOST
Definition: config.hpp:40
 
Definition: cluster_descriptor.hpp:13
 
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
 
CK_TILE_HOST void reference_pool2d(const HostTensor< InDataType > &input, HostTensor< OutDataType > &output, HostTensor< IndexDataType > &output_index, PoolKernelArgs< TensorShape, WindowShape > kargs, ReduceOp reduce_op)
Definition: reference_pool.hpp:22
 
int32_t index_t
Definition: integer.hpp:9
 
CK_TILE_HOST void reference_pool3d(const HostTensor< InDataType > &input, HostTensor< OutDataType > &output, HostTensor< IndexDataType > &output_index, PoolKernelArgs< TensorShape, WindowShape > kargs, ReduceOp reduce_op)
Definition: reference_pool.hpp:107
 
Definition: host_tensor.hpp:336
 
std::size_t GetOffsetFromMultiIndex(Is... is) const
Definition: host_tensor.hpp:531
 
Kernel arguments for pooling operations.
Definition: pool_kernel.hpp:63
 
TensorShape output_shape
Definition: pool_kernel.hpp:68
 
WindowShape window_lengths
Definition: pool_kernel.hpp:71
 
WindowShape window_dilations
Definition: pool_kernel.hpp:73
 
WindowShape input_left_pads
Definition: pool_kernel.hpp:74
 
TensorShape input_shape
Definition: pool_kernel.hpp:67
 
WindowShape window_strides
Definition: pool_kernel.hpp:72
 
Definition: integral_constant.hpp:13