// Reference reduction over the second index of x_m_n: for every index m, the
// values x_m_n(m, n) are folded with reduce_op and the result is stored in y_m(m).
// NOTE(review): this listing omits several original lines (function signature,
// braces, the definition of N, and the code that invokes `f`); the comments
// below describe only what is visible here.
12 template <
typename XDataType,
typename ComputeDataType,
typename YDataType,
typename ReduceOp>
// Worker for one output element; `m` selects which entry of y_m is produced.
// Captures everything by reference.
16 auto f = [&](
auto m) {
// Seed the accumulator with the reduction operator's identity value,
// expressed in ComputeDataType.
19 ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
// Visit n in [0, N). N is defined on a line not shown in this listing --
// presumably the length of x_m_n's second dimension; TODO confirm.
21 for(
int n = 0; n < N; ++n)
// Convert the input element to ComputeDataType before combining it.
23 const ComputeDataType v_a = type_convert<ComputeDataType>(x_m_n(m, n));
// Fold the element into the running result.
25 v_acc = reduce_op(v_acc, v_a);
// Convert the accumulated value to YDataType and store it at position m.
28 y_m(m) = ck_tile::type_convert<YDataType>(v_acc);
// Multi-dimensional reference reduction: every combination of indices along the
// kept dimensions yields one output element, obtained by folding reduce_op over
// all combinations of indices along the reduced dimensions.
// NOTE(review): this listing is fragmentary -- the start of the template header,
// the function name and most of its parameter list, the calls that enclose the
// bare `[&](auto i)` lambdas, and all closing braces are not visible. Presumably
// this is an overload of reference_reduce; confirm against the full file.
37 typename ComputeDataType,
47 ReduceDims reduce_dims)
// Number of output elements: product of x_lengths over the kept dimensions.
52 index_t total_kept_elements = 1;
// Lambda applied for each i in kept_dim (the enclosing call is on a line not shown).
54 [&](
auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });
// Number of input elements folded into each output: product of x_lengths over
// the reduced dimensions.
57 index_t total_reduce_elements = 1;
// Lambda applied for each i in reduce_dims (the enclosing call is on a line not shown).
59 [&](
auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });
// Worker for one output element, identified by a linearized index over the kept
// dimensions. Captures everything by reference.
61 auto f = [&](
auto linear_kept_idx) {
// Seed the accumulator with the reduction operator's identity value.
62 ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
// Decompose linear_kept_idx into one index per kept dimension.
65 std::vector<index_t> kept_indices(kept_dim.size());
66 index_t temp_kept = linear_kept_idx;
// Walk the kept dimensions from last to first (dim_idx = size - 1 - i), taking
// the remainder by that dimension's length at each step.
// NOTE(review): the step that divides temp_kept by len is not visible in this
// listing (presumably on the omitted line after the modulo); confirm.
67 static_for<0, kept_dim.size(), 1>{}([&](
auto i) {
68 constexpr
auto dim_idx = kept_dim.size() - 1 - i;
69 constexpr
auto dim = kept_dim.at(dim_idx);
70 const auto len = x_lengths[dim];
// NOTE(review): temp_kept is index_t while len comes from x_lengths; if that
// holds std::size_t the modulo is evaluated unsigned -- assumes non-negative
// indices; TODO confirm.
71 kept_indices[dim_idx] = temp_kept % len;
// Visit every combination of indices along the reduced dimensions.
75 for(
index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
// Decompose reduce_idx into one index per reduced dimension, using the same
// last-to-first scheme as for the kept dimensions.
// NOTE(review): this vector appears to be constructed on every loop iteration;
// hoisting it could avoid repeated allocation -- verify against the full source.
78 std::vector<index_t> reduce_indices(reduce_dims.size());
79 index_t temp_reduce = reduce_idx;
80 static_for<0, reduce_dims.size(), 1>{}([&](
auto i) {
81 constexpr
auto dim_idx = reduce_dims.size() - 1 - i;
82 constexpr
auto dim = reduce_dims.at(dim_idx);
83 const auto len = x_lengths[dim];
// NOTE(review): as above, the matching division of temp_reduce is not visible here.
84 reduce_indices[dim_idx] = temp_reduce % len;
// Assemble the full input coordinate: one slot per entry of x_lengths,
// zero-initialized, then filled from the kept and reduced indices.
89 std::vector<std::size_t> full_indices(x_lengths.size(), 0);
// Place each kept index at its original dimension position.
91 [&](
auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
// Place each reduced index at its original dimension position.
93 [&](
auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });
// Read the input element, convert it to ComputeDataType, and fold it in.
96 const auto v_a = type_convert<ComputeDataType>(x_tensor(full_indices));
98 v_acc = reduce_op(v_acc, v_a);
// Output coordinate: the kept indices, in kept_dim order, copied into a
// std::size_t vector.
103 std::vector<std::size_t> y_indices(kept_dim.size());
104 static_for<0, kept_dim.size(), 1>{}([&](
auto i) { y_indices[i] = kept_indices[i]; });
// Convert the accumulated value to YDataType and store it in the output tensor.
106 y_tensor(y_indices) = type_convert<YDataType>(v_acc);
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST void reference_reduce(const HostTensor< XDataType > &x_m_n, HostTensor< YDataType > &y_m, ReduceOp reduce_op)
Definition: reference_reduce.hpp:14
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
Definition: host_tensor.hpp:336
Descriptor mDesc
Definition: host_tensor.hpp:800
Definition: functional.hpp:43