// Host-side reference reduction, shown as a Doxygen listing excerpt: the leading numbers are
// the original file's line numbers, and lines absent from the listing (the function
// signature, braces, and whatever invokes `f`) are not reproduced here.
//
// According to this page's cross-reference entry for reference_reduce, the parameters are
//   x_m_n     - const HostTensor<XDataType>&, read with two indices (m, n)
//   y_m       - HostTensor<YDataType>&, written with one index (m)
//   reduce_op - ReduceOp, callable as reduce_op(acc, value) and providing
//               GetIdentityValue<T>()
// ComputeDataType is the type in which the accumulation is carried out.
12 template <
typename XDataType, 
typename ComputeDataType, 
typename YDataType, 
typename ReduceOp>
 
// Worker for a single output index `m`; captures the surrounding locals by reference.
   16     auto f = [&](
auto m) {
 
// Seed the accumulator with the reduction's identity value in ComputeDataType.
   19         ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
 
// Walk the reduced index n. N is defined on a line not shown in this listing --
// presumably the extent of x_m_n's second dimension; TODO confirm.
   21         for(
int n = 0; n < N; ++n)
 
// Convert the input element to ComputeDataType before combining it.
   23             const ComputeDataType v_a = type_convert<ComputeDataType>(x_m_n(m, n));
 
// Fold the element into the running result.
   25             v_acc = reduce_op(v_acc, v_a);
 
// Convert the accumulated value to the output type and store it at y_m(m).
   28         y_m(m) = ck_tile::type_convert<YDataType>(v_acc);
 
   // Second reduction routine in this listing: reduces x_tensor over the dimensions named in
   // reduce_dims and writes one value per combination of the dimensions named in kept_dim
   // into y_tensor. Its name and most of its signature fall on lines missing from the
   // listing -- presumably an overload of reference_reduce; TODO confirm.
   37     typename ComputeDataType,
 
   47                                    ReduceDims reduce_dims)
 
// Number of output elements: product of the extents of the kept dimensions.
// NOTE(review): the construct that drives the lambda below (original line 53) is not
// shown -- presumably a static_for over kept_dim; confirm.
   52     index_t total_kept_elements = 1;
 
   54         [&](
auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });
 
// Number of input elements folded into each output: product of the reduced extents.
   57     index_t total_reduce_elements = 1;
 
   59         [&](
auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });
 
// Worker for one output element, identified by a flattened index over the kept dimensions.
   61     auto f = [&](
auto linear_kept_idx) {
 
// Seed the accumulator with the reduction's identity value in ComputeDataType.
   62         ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
 
// Decode linear_kept_idx into one index per kept dimension.
   65         std::vector<index_t> kept_indices(kept_dim.size());
 
   66         index_t temp_kept = linear_kept_idx;
 
// Visits the kept dimensions from the last entry of kept_dim to the first
// (dim_idx = size - 1 - i).
   67         static_for<0, kept_dim.size(), 1>{}([&](
auto i) {
 
   68             constexpr 
auto dim_idx = kept_dim.size() - 1 - i;
 
   69             constexpr 
auto dim     = kept_dim.at(dim_idx);
 
   70             const auto len         = x_lengths[dim];
 
// Remainder gives this dimension's index. NOTE(review): the step that shrinks temp_kept
// for the next dimension (original line 72) is not shown -- presumably `temp_kept /= len`.
   71             kept_indices[dim_idx]  = temp_kept % len;
 
// Visit every combination of indices along the reduced dimensions.
   75         for(
index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
 
// Decode reduce_idx the same way, last entry of reduce_dims first.
// NOTE(review): this vector and full_indices below appear to be constructed on every
// iteration of the loop; hoisting them may be worthwhile -- confirm against the full source.
   78             std::vector<index_t> reduce_indices(reduce_dims.size());
 
   79             index_t temp_reduce = reduce_idx;
 
   80             static_for<0, reduce_dims.size(), 1>{}([&](
auto i) {
 
   81                 constexpr 
auto dim_idx  = reduce_dims.size() - 1 - i;
 
   82                 constexpr 
auto dim      = reduce_dims.at(dim_idx);
 
   83                 const auto len          = x_lengths[dim];
 
// NOTE(review): the matching update of temp_reduce (original line 85) is not shown.
   84                 reduce_indices[dim_idx] = temp_reduce % len;
 
// Assemble the full input coordinate: one slot per dimension of x_lengths, zero-initialised,
// then filled from the kept and reduced indices at their respective dimension positions.
   89             std::vector<std::size_t> full_indices(x_lengths.size(), 0);
 
   91                 [&](
auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
 
   93                 [&](
auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });
 
// Read the input element at that coordinate and convert it to ComputeDataType.
   96             const auto v_a = type_convert<ComputeDataType>(x_tensor(full_indices));
 
// Fold the element into the running result.
   98             v_acc = reduce_op(v_acc, v_a);
 
// Output coordinate: the kept indices in kept_dim order, copied into a std::size_t vector.
  103         std::vector<std::size_t> y_indices(kept_dim.size());
 
  104         static_for<0, kept_dim.size(), 1>{}([&](
auto i) { y_indices[i] = kept_indices[i]; });
 
// Convert the accumulated value to the output type and store it.
  106         y_tensor(y_indices) = type_convert<YDataType>(v_acc);
 
#define CK_TILE_HOST
Definition: config.hpp:40
 
Definition: cluster_descriptor.hpp:13
 
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
 
int32_t index_t
Definition: integer.hpp:9
 
CK_TILE_HOST void reference_reduce(const HostTensor< XDataType > &x_m_n, HostTensor< YDataType > &y_m, ReduceOp reduce_op)
Definition: reference_reduce.hpp:14
 
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
 
Definition: host_tensor.hpp:336
 
Descriptor mDesc
Definition: host_tensor.hpp:800
 
Definition: functional.hpp:43