// NOTE(review): this listing carries the original line number at the start of each line and
// omits several original lines (function signature, braces, the call that runs `f`).
// 2-D reference reduction. Per the reference_reduce signature listed in this file's
// cross-reference section, the arguments are x_m_n, y_m and reduce_op.
// For each index m, the values x_m_n(m, n) for n in [0, N) are folded with reduce_op in
// ComputeDataType, and the result is stored in y_m(m) after conversion to YDataType.
13 template <
typename XDataType,
typename ComputeDataType,
typename YDataType,
typename ReduceOp>
// Per-m worker; how it is invoked is not visible in this listing.
17 auto f = [&](
auto m) {
// Seed the accumulator with the reduction's identity value.
20 ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
// N is defined on a line not shown here - presumably the extent of x_m_n's second index; confirm.
22 for(
int n = 0; n < N; ++n)
// Convert the input element to ComputeDataType before combining it.
24 const ComputeDataType v_a = type_convert<ComputeDataType>(x_m_n(m, n));
26 v_acc = reduce_op(v_acc, v_a);
// Convert the accumulated value to the output type and store it at index m.
29 y_m(m) = ck_tile::type_convert<YDataType>(v_acc);
// NOTE(review): the template opener, function name and most of the parameter list are not
// visible in this listing; this looks like a generalized reduce taking kept_dim and
// reduce_dims - confirm against the full header.
// Reduces x_tensor over the dimensions listed in reduce_dims and writes one value per
// combination of kept_dim indices into y_tensor, accumulating in ComputeDataType.
38 typename ComputeDataType,
48 ReduceDims reduce_dims)
// Number of output elements: product of the input lengths along the kept dimensions.
// (x_lengths is defined on a line not shown - presumably x_tensor's lengths; confirm.)
53 index_t total_kept_elements = 1;
55 [&](
auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });
// Number of inputs folded into each output: product of the lengths along the reduced dimensions.
58 index_t total_reduce_elements = 1;
60 [&](
auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });
// Worker for a single output element, addressed by a linear index over the kept dimensions.
62 auto f = [&](
auto linear_kept_idx) {
63 ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
// Decompose the linear index into one coordinate per kept dimension, walking from the
// last entry of kept_dim to the first (dim_idx = size - 1 - i).
66 std::vector<index_t> kept_indices(kept_dim.size());
67 index_t temp_kept = linear_kept_idx;
68 static_for<0, kept_dim.size(), 1>{}([&](
auto i) {
69 constexpr
auto dim_idx = kept_dim.size() - 1 - i;
70 constexpr
auto dim = kept_dim.at(dim_idx);
71 const auto len = x_lengths[dim];
// Remainder by the dimension length gives this dimension's coordinate; the step that
// removes this dimension from temp_kept is not shown - presumably a division by len.
72 kept_indices[dim_idx] = temp_kept % len;
// Visit every position of the reduced sub-space for this output element.
76 for(
index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
// Same last-to-first decomposition, now over the reduced dimensions.
79 std::vector<index_t> reduce_indices(reduce_dims.size());
80 index_t temp_reduce = reduce_idx;
81 static_for<0, reduce_dims.size(), 1>{}([&](
auto i) {
82 constexpr
auto dim_idx = reduce_dims.size() - 1 - i;
83 constexpr
auto dim = reduce_dims.at(dim_idx);
84 const auto len = x_lengths[dim];
85 reduce_indices[dim_idx] = temp_reduce % len;
// Build the full input coordinate by placing each kept / reduced coordinate at the
// position of its dimension in x_tensor.
90 std::vector<std::size_t> full_indices(x_lengths.size(), 0);
92 [&](
auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
94 [&](
auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });
// Read the input element, convert it to ComputeDataType and fold it into the accumulator.
97 const auto v_a = type_convert<ComputeDataType>(x_tensor(full_indices));
99 v_acc = reduce_op(v_acc, v_a);
// The output coordinate is the kept coordinates in kept_dim order.
104 std::vector<std::size_t> y_indices(kept_dim.size());
105 static_for<0, kept_dim.size(), 1>{}([&](
auto i) { y_indices[i] = kept_indices[i]; });
107 y_tensor(y_indices) = type_convert<YDataType>(v_acc);
// NOTE(review): several original lines (signature, braces, the accumulator tuple declaration)
// are not visible in this listing.
// reference_multiple_reduce: per the signature listed in this file's cross-reference section,
// takes x_tensor, y_tensor_tuple, reduce_ops, kept_dim, reduce_dims, elementwise_ops and
// accumulator_ops. For every combination of kept_dim indices it runs each reduction in
// reduce_ops over the reduce_dims sub-space and writes result i into y_tensor_tuple.at<i>().
// elementwise_ops.at(i) is called on each input value before reduction i, and
// accumulator_ops.at(i) is called on accumulator i after the reduction loop.
113 template <
typename XDataType,
114 typename ComputeDataType,
122 typename ElementWiseOps,
123 typename AccElementWiseOps>
125 YRefTuple& y_tensor_tuple,
126 ReduceOps reduce_ops,
128 ReduceDims reduce_dims,
129 ElementWiseOps elementwise_ops,
130 AccElementWiseOps accumulator_ops)
// Number of output elements: product of the input lengths along the kept dimensions.
// (x_lengths is defined on a line not shown - presumably x_tensor.get_lengths(); confirm.)
135 index_t total_kept_elements = 1;
137 [&](
auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });
// Number of inputs folded into each output: product of the lengths along the reduced dimensions.
140 index_t total_reduce_elements = 1;
142 [&](
auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });
// Worker for a single output position, addressed by a linear index over the kept dimensions.
144 auto f = [&](
auto linear_kept_idx) {
// One accumulator per reduction, each starting from that reduction's identity value.
// (The declaration receiving this generated tuple is not shown; later lines call it v_acc_tuple.)
148 return reduce_ops.template at<i>().template GetIdentityValue<ComputeDataType>();
150 number<reduce_ops.size()>{});
// Decompose the linear index into one coordinate per kept dimension, walking from the
// last entry of kept_dim to the first (dim_idx = size - 1 - i).
153 std::vector<index_t> kept_indices(kept_dim.size());
154 index_t temp_kept = linear_kept_idx;
155 static_for<0, kept_dim.size(), 1>{}([&](
auto i) {
156 constexpr
auto dim_idx = kept_dim.size() - 1 - i;
157 constexpr
auto dim = kept_dim.at(dim_idx);
158 const auto len = x_lengths[dim];
// Remainder by the dimension length gives this dimension's coordinate; the step that
// removes this dimension from temp_kept is not shown - presumably a division by len.
159 kept_indices[dim_idx] = temp_kept % len;
// Visit every position of the reduced sub-space for this output position.
163 for(
index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
// Same last-to-first decomposition, now over the reduced dimensions.
166 std::vector<index_t> reduce_indices(reduce_dims.size());
167 index_t temp_reduce = reduce_idx;
168 static_for<0, reduce_dims.size(), 1>{}([&](
auto i) {
169 constexpr
auto dim_idx = reduce_dims.size() - 1 - i;
170 constexpr
auto dim = reduce_dims.at(dim_idx);
171 const auto len = x_lengths[dim];
172 reduce_indices[dim_idx] = temp_reduce % len;
// Build the full input coordinate by placing each kept / reduced coordinate at the
// position of its dimension in x_tensor.
177 std::vector<std::size_t> full_indices(x_lengths.size(), 0);
179 [&](
auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
181 [&](
auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });
// NOTE(review): v_a is declared once per input element and is passed as both arguments of
// elementwise_ops.at(i) inside the per-op loop below. If the op stores its result through
// one of those arguments, reduction i+1 would receive op i's transformed value instead of
// the raw input. The multiblock variant in this file reads the input into a const v_a_in
// instead - confirm whether sharing v_a across ops is intended here.
184 auto v_a = type_convert<ComputeDataType>(x_tensor(full_indices));
187 static_for<0, reduce_ops.size(), 1>{}([&](
auto i) {
189 elementwise_ops.at(i)(v_a, v_a);
// Fold the (possibly transformed) value into accumulator i with reduction i.
191 v_acc_tuple.template at<i>() =
192 reduce_ops.template at<i>()(v_acc_tuple.template at<i>(), v_a);
// After all inputs are folded, call accumulator op i with accumulator i as both arguments.
196 static_for<0, reduce_ops.size(), 1>{}([&](
auto i) {
198 accumulator_ops.at(i)(v_acc_tuple.template at<i>(), v_acc_tuple.template at<i>());
// The output coordinate is the kept coordinates in kept_dim order.
203 std::vector<std::size_t> y_indices(kept_dim.size());
204 static_for<0, kept_dim.size(), 1>{}([&](
auto i) { y_indices[i] = kept_indices[i]; });
// Store each accumulator, converted to YDataType, into the matching output tensor.
207 static_for<0, reduce_ops.size(), 1>{}([&](
auto i) {
208 y_tensor_tuple.template at<i>()(y_indices) =
209 type_convert<YDataType>(v_acc_tuple.template at<i>());
// NOTE(review): several original lines (signature, braces, some declarations) are not
// visible in this listing.
// reference_multiple_reduce_multiblock: per the signature listed in this file's
// cross-reference section, takes the same arguments as reference_multiple_reduce plus
// inter_block_reduce_ops and num_blocks. The reduced sub-space is split into num_blocks
// contiguous ranges of linear reduce indices; each range is reduced separately starting from
// the identity value, passed through accumulator_ops.at(i), and then merged into the output
// element with inter_block_reduce_ops.at<i>().
216 template <
typename XDataType,
217 typename ComputeDataType,
225 typename ElementWiseOps,
226 typename AccElementWiseOps,
227 typename InterBlockReduceOps>
229 YRefTuple& y_tensor_tuple,
230 ReduceOps reduce_ops,
232 ReduceDims reduce_dims,
233 ElementWiseOps elementwise_ops,
234 AccElementWiseOps accumulator_ops,
235 InterBlockReduceOps inter_block_reduce_ops,
// Number of output elements: product of the input lengths along the kept dimensions.
// (x_lengths is defined on a line not shown - presumably x_tensor.get_lengths(); confirm.)
241 index_t total_kept_elements = 1;
243 [&](
auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });
// Number of inputs folded into each output: product of the lengths along the reduced dimensions.
246 index_t total_reduce_elements = 1;
248 [&](
auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });
// Pre-fill every element of each output tensor with the identity value of its inter-block
// reduction, because per-block results are merged into the outputs rather than assigned.
251 static_for<0, reduce_ops.size(), 1>{}([&](
auto i) {
252 auto& y_tensor = y_tensor_tuple.template at<i>();
253 for(
auto& val : y_tensor.mData)
255 val = inter_block_reduce_ops.template at<i>().template GetIdentityValue<YDataType>();
// Worker for a single output position, addressed by a linear index over the kept dimensions.
259 auto f = [&](
auto linear_kept_idx) {
// Decompose the linear index into one coordinate per kept dimension, walking from the
// last entry of kept_dim to the first (dim_idx = size - 1 - i).
261 std::vector<index_t> kept_indices(kept_dim.size());
262 index_t temp_kept = linear_kept_idx;
263 static_for<0, kept_dim.size(), 1>{}([&](
auto i) {
264 constexpr
auto dim_idx = kept_dim.size() - 1 - i;
265 constexpr
auto dim = kept_dim.at(dim_idx);
266 const auto len = x_lengths[dim];
// Remainder by the dimension length gives this dimension's coordinate; the step that
// removes this dimension from temp_kept is not shown - presumably a division by len.
267 kept_indices[dim_idx] = temp_kept % len;
// The output coordinate is the kept coordinates in kept_dim order.
272 std::vector<std::size_t> y_indices(kept_dim.size());
273 static_for<0, kept_dim.size(), 1>{}([&](
auto i) { y_indices[i] = kept_indices[i]; });
// Elements per block, rounded up. num_blocks is a divisor here, so it must be non-zero.
275 const auto max_element_per_block = (total_reduce_elements + num_blocks - 1) / num_blocks;
277 for(
index_t block_id = 0; block_id < num_blocks; ++block_id)
// Fresh per-block accumulators, one per reduction, each starting from its identity value.
// (The declaration receiving this generated tuple is not shown; later lines call it v_acc_tuple.)
282 return reduce_ops.template at<i>().template GetIdentityValue<ComputeDataType>();
284 number<reduce_ops.size()>{});
// Linear reduce-index range handled by this block. The upper bound is clamped to
// total_reduce_elements; the declaration it initializes (used below as element_end) is not shown.
286 const index_t element_offset = block_id * max_element_per_block;
288 std::min(element_offset + max_element_per_block, total_reduce_elements);
290 for(
index_t linear_reduce_idx = element_offset; linear_reduce_idx < element_end;
// Same last-to-first decomposition, now over the reduced dimensions.
294 std::vector<index_t> reduce_indices(reduce_dims.size());
295 index_t temp_reduce = linear_reduce_idx;
296 static_for<0, reduce_dims.size(), 1>{}([&](
auto i) {
297 constexpr
auto dim_idx = reduce_dims.size() - 1 - i;
298 constexpr
auto dim = reduce_dims.at(dim_idx);
299 const auto len = x_lengths[dim];
300 reduce_indices[dim_idx] = temp_reduce % len;
// Build the full input coordinate by placing each kept / reduced coordinate at the
// position of its dimension in x_tensor.
305 std::vector<std::size_t> full_indices(x_lengths.size(), 0);
307 [&](
auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
309 [&](
auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });
// The raw input value is held const. The v_a used below is declared on a line not shown -
// presumably a per-op copy of v_a_in; confirm.
312 const auto v_a_in = type_convert<ComputeDataType>(x_tensor(full_indices));
315 static_for<0, reduce_ops.size(), 1>{}([&](
auto i) {
318 elementwise_ops.at(i)(v_a, v_a);
// Fold the (possibly transformed) value into this block's accumulator i.
320 v_acc_tuple.template at<i>() =
321 reduce_ops.template at<i>()(v_acc_tuple.template at<i>(), v_a);
// Once the block's range is folded, call accumulator op i with accumulator i as both
// arguments, then merge the converted result into the output element using inter-block
// reduction i.
325 static_for<0, reduce_ops.size(), 1>{}([&](
auto i) {
327 accumulator_ops.at(i)(v_acc_tuple.template at<i>(), v_acc_tuple.template at<i>());
330 auto& y_tensor = y_tensor_tuple.template at<i>();
331 auto& y_val = y_tensor(y_indices);
332 y_val = inter_block_reduce_ops.template at<i>()(
333 y_val, type_convert<YDataType>(v_acc_tuple.template at<i>()));
#define CK_TILE_HOST
Definition: config.hpp:44
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST void reference_multiple_reduce(const HostTensor< XDataType > &x_tensor, YRefTuple &y_tensor_tuple, ReduceOps reduce_ops, KeptDim kept_dim, ReduceDims reduce_dims, ElementWiseOps elementwise_ops, AccElementWiseOps accumulator_ops)
Definition: reference_reduce.hpp:124
CK_TILE_HOST void reference_multiple_reduce_multiblock(const HostTensor< XDataType > &x_tensor, YRefTuple &y_tensor_tuple, ReduceOps reduce_ops, KeptDim kept_dim, ReduceDims reduce_dims, ElementWiseOps elementwise_ops, AccElementWiseOps accumulator_ops, InterBlockReduceOps inter_block_reduce_ops, ck_tile::index_t num_blocks)
Definition: reference_reduce.hpp:228
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_HOST void reference_reduce(const HostTensor< XDataType > &x_m_n, HostTensor< YDataType > &y_m, ReduceOp reduce_op)
Definition: reference_reduce.hpp:15
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
Definition: host_tensor.hpp:336
Descriptor mDesc
Definition: host_tensor.hpp:801
Definition: integral_constant.hpp:13
Definition: functional.hpp:43