#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

// Synchronize a per-thread reduction result: reduce across the lanes that
// replicate the data (the R dimensions owned by the lane-id P dimension),
// then optionally broadcast the reduced value back to every replicating lane.
template <typename AccDistributedTensor_,
          typename ReduceFunc,
          bool WithBroadcast = true,
          bool CrossWarp     = true>
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
                                           const ReduceFunc& reduce_func,
                                           bool_constant<WithBroadcast> = {},
                                           bool_constant<CrossWarp>     = {})
{
    using Dstr             = typename AccDistributedTensor_::StaticTileDistribution;
    using DstrEncode       = typename Dstr::DstrEncode;
    using DstrEncodeDetail = typename DstrEncode::detail;

    constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
    constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

    constexpr index_t idim_p_lane = NDimP - 1;

    const auto ps_idx = get_partition_index(acc_tensor.get_tile_distribution());
    const auto rs_idx = acc_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);

    constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();

    // loop over each element held in the thread buffer
    static_for<0, thread_buf_size, 1>{}([&](auto i) {
        auto v_local = acc_tensor.get_thread_buffer()[i];

        // cross-lane reduce: only on the R dimensions that the lane-id dimension maps onto
        static_for<0, NDimR, 1>{}([&](auto idim_r) {
            if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
            {
                constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                constexpr index_t lid_over_rid_derivative =
                    DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];

                static_assert(is_power_of_two_integer(r_length),
                              "wrong! only support power of 2 reduction");

                constexpr index_t nstage = integer_log2_floor(r_length);

                // tree-reduction sweep, halving the lane stride each stage
                static_for<0, nstage, 1>{}([&](auto istage) {
                    constexpr index_t lid_delta =
                        lid_over_rid_derivative * (1 << (nstage - istage - 1));

                    if constexpr(CrossWarp)
                    {
                        // pull data from the remote lane, then reduce
                        const auto v_remote = warp_shuffle_down(v_local, lid_delta);

                        v_local = reduce_func(v_local, v_remote);
                    }
                    else
                    {
                        // swap registers with the paired lane in one exchange, then reduce
                        const auto v_swapped_regs = warp_shuffle_down_pair<lid_delta>(v_local);

                        v_local = reduce_func(v_swapped_regs.at(0), v_swapped_regs.at(1));
                    }
                });
            }
        });

        // cross-lane broadcast: push the reduced value back over the replicating lanes
        if constexpr(WithBroadcast)
        {
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
                {
                    const index_t r_id = rs_idx[idim_r];

                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                    constexpr index_t lid_over_rid_derivative =
                        DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r];

                    static_assert(is_power_of_two_integer(r_length),
                                  "wrong! only support power of 2 reduction");

                    constexpr index_t nstage = integer_log2_floor(r_length);

                    // broadcast sweep, doubling the lane stride each stage
                    static_for<0, nstage, 1>{}([&](auto istage) {
                        // lanes whose r index is below the stage boundary already
                        // hold the reduced data
                        const bool do_i_hold_reduced_data = r_id < (1 << istage);

                        constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage);

                        const auto v_remote = warp_shuffle_up(v_local, lid_delta);

                        v_local = do_i_hold_reduced_data ? v_local : v_remote;
                    });
                }
            });
        }

        acc_tensor.get_thread_buffer()(i) = v_local;
    });
}
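// Usage sketch (illustrative; `acc` and `f_max` are hypothetical names): after
// an in-thread reduction, make the partial result consistent across the lanes
// that replicate it.
//
//   auto f_max = [](float a, float b) { return a > b ? a : b; };
//   block_tile_reduce_sync(acc, f_max);                         // reduce + broadcast
//   block_tile_reduce_sync(acc, f_max, bool_constant<false>{}); // reduce only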
// Same cross-lane reduction, but implemented with xor (butterfly) shuffles, so
// every lane ends up holding the fully reduced value without a separate
// broadcast sweep.
template <typename AccDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
                                               const ReduceFunc& reduce_func)
{
    using Dstr             = typename AccDistributedTensor_::StaticTileDistribution;
    using DstrEncode       = typename Dstr::DstrEncode;
    using DstrEncodeDetail = typename DstrEncode::detail;

    constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
    constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

    constexpr index_t idim_p_lane = NDimP - 1;

    constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();

    static_for<0, thread_buf_size, 1>{}([&](auto i) {
        auto v_local = acc_tensor.get_thread_buffer()[i];

        // cross-lane reduce: only on the R dimensions that the lane-id dimension maps onto
        static_for<0, NDimR, 1>{}([&](auto idim_r) {
            if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
            {
                constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                constexpr index_t lid_over_rid_derivative =
                    DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];

                static_assert(is_power_of_two_integer(r_length),
                              "wrong! only support power of 2 reduction");

                constexpr index_t nstage = integer_log2_floor(r_length);

                static_for<0, nstage, 1>{}([&](auto istage) {
                    // xor: each lane exchanges with its butterfly partner
                    const index_t src_lane =
                        __lane_id() ^ (number<lid_over_rid_derivative << istage.value>{}.value);

                    // pull data from the remote lane, then reduce
                    const auto v_remote = warp_shuffle(v_local, src_lane);

                    v_local = reduce_func(v_local, v_remote);
                });
            }
        });

        acc_tensor.get_thread_buffer()(i) = v_local;
    });
}
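// Usage sketch (illustrative; `acc` and `f_sum` are hypothetical names): the
// xor butterfly fuses reduce and broadcast into one sweep, so there are no
// flags to pass.
//
//   auto f_sum = [](float a, float b) { return a + b; };
//   block_tile_reduce_xor_sync(acc, f_sum); // every lane ends with the full sum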
// In-thread (per-lane) reduction: reduce in_tensor along the dimensions listed
// in InReduceDims and accumulate the result into acc_tensor.
// FIXME: currently handles 2D-to-1D reduction only; n-D support is pending.
template <typename AccDistributedTensor_,
          typename InDistributedTensor_,
          index_t... InReduceDims,
          typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
                                      const InDistributedTensor_& in_tensor,
                                      sequence<InReduceDims...>,
                                      const ReduceFunc& reduce_func)
{
    constexpr auto in_reduce_dims = sequence<InReduceDims...>{};

    constexpr index_t ndim_in        = InDistributedTensor_::get_num_of_dimension();
    constexpr index_t ndim_in_reduce = in_reduce_dims.size();
    constexpr index_t ndim_in_free   = ndim_in - ndim_in_reduce;

    // partition the dimensions into "reduce" and "free", keeping the free ones
    constexpr auto in_free_dims_arr = [&] {
        array<bool, ndim_in> is_free_dims{};

        for(index_t i = 0; i < ndim_in; i++)
        {
            is_free_dims(i) = true;
        }

        for(index_t i = 0; i < ndim_in_reduce; i++)
        {
            is_free_dims(in_reduce_dims[i]) = false;
        }

        array<index_t, ndim_in_free> in_free_dims{};

        index_t cnt = 0;

        for(index_t i = 0; i < ndim_in; i++)
        {
            if(is_free_dims[i])
            {
                in_free_dims(cnt) = i;
                cnt++;
            }
        }

        return in_free_dims;
    }();

    constexpr auto in_free_dims = TO_SEQUENCE(in_free_dims_arr, ndim_in_free);

    // in-thread reduction: sweep the free span, and within it the reduce span
    constexpr auto spans = InDistributedTensor_::get_distributed_spans();

    sweep_tile_span(spans[number<0>{}], [&](auto dstr_idx_i0) {
        constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);

        auto acc = acc_tensor[acc_dstr_idx];

        sweep_tile_span(spans[number<1>{}], [&](auto dstr_idx_i1) {
            constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);

            const auto in = in_tensor[in_dstr_idx];

            acc = reduce_func(acc, in);
        });

        acc_tensor(acc_dstr_idx) = acc;
    });
}
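// Usage sketch (illustrative; `acc_tile` and `in_tile` are hypothetical tiles):
// reduce dimension 1 of a 2D tile into a pre-initialized 1D accumulator, then
// follow with block_tile_reduce_sync for the cross-lane part.
//
//   block_tile_reduce(acc_tile, in_tile, sequence<1>{},
//                     [](float a, float b) { return a + b; });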
// Convenience overload: build an accumulator tile with the reduced distribution,
// initialize it with reduce_init, run the in-thread reduction, and return it.
template <typename AccDataType_,
          typename InDistributedTensor_,
          index_t... InReduceDims,
          typename ReduceFunc,
          typename InDataType_>
CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
                                      sequence<InReduceDims...> in_reduce_dims,
                                      const ReduceFunc& reduce_func,
                                      const InDataType_& reduce_init)
{
    using InDataType  = typename InDistributedTensor_::DataType;
    using AccDataType = remove_cvref_t<AccDataType_>;

    static_assert(is_same_v<remove_cvref_t<InDataType_>, InDataType>, "wrong!");

    // accumulator tile whose distribution has the reduce dimensions removed
    constexpr auto acc_dstr =
        make_static_tile_distribution(make_reduce_tile_distribution_encoding(
            InDistributedTensor_::get_tile_distribution().get_static_tile_distribution_encoding(),
            sequence<InReduceDims...>{}));

    auto acc_tensor = make_static_distributed_tensor<AccDataType>(acc_dstr);

    // initialize, then reduce in-thread (the cross-lane sync is a separate call)
    tile_elementwise_inout([&](auto& acc) { acc = static_cast<AccDataType>(reduce_init); },
                           acc_tensor);

    block_tile_reduce(acc_tensor, in_tensor, in_reduce_dims, reduce_func);

    return acc_tensor;
}
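// Usage sketch (illustrative; `in_tile` is a hypothetical fp32 tile): one-call
// row sum with fp32 accumulation, then the cross-lane sync.
//
//   auto f_sum   = [](float a, float b) { return a + b; };
//   auto row_sum = block_tile_reduce<float>(in_tile, sequence<1>{}, f_sum, 0.f);
//   block_tile_reduce_sync(row_sum, f_sum);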
// Functor-style 2D row reduction: reduces the second (row) axis of a 2D
// distributed tensor into a 1D distributed tensor.
template <typename InDistributedTensor_>
struct BlockReduce2D
{
    using InDistributedTensor = remove_cvref_t<InDistributedTensor_>;
    using InDataType          = typename InDistributedTensor::DataType;

    CK_TILE_HOST_DEVICE BlockReduce2D(const InDistributedTensor& t_,
                                      const InDataType& reduce_init_)
        : t(t_), reduce_init(reduce_init_)
    {
    }

    // build the destination tile (row axis reduced away), filled with reduce_init
    CK_TILE_HOST_DEVICE constexpr auto MakeDstBlockTile() const
    {
        constexpr auto acc_dstr =
            make_static_tile_distribution(make_reduce_tile_distribution_encoding(
                InDistributedTensor::get_tile_distribution()
                    .get_static_tile_distribution_encoding(),
                sequence<1>{})); // reduce dimension 1 (the row axis)

        auto dst_ = make_static_distributed_tensor<InDataType>(acc_dstr);

        tile_elementwise_inout([&](auto& x) { x = reduce_init; }, dst_);

        return dst_;
    }

    // total length of the row (Y) axis being reduced
    CK_TILE_HOST_DEVICE constexpr auto get_reduce_length_y() const
    {
        constexpr auto spans = InDistributedTensor::get_distributed_spans();

        constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};

        return reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
    }
    // reduce with a caller-chosen pack count along each X dimension, followed by
    // a caller-supplied cross-lane synchronization step
    template <typename ReduceFunc,
              typename ReduceSyncFunc,
              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
    CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func,
                                        const ReduceSyncFunc& reduce_sync_func,
                                        ReducePacksPerXDim = {}) const
    {
        constexpr auto spans = InDistributedTensor::get_distributed_spans();

        // derive the per-sweep unpack factors for the requested row packing
        constexpr auto row_y_unpacks = [&]() {
            constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};

            constexpr auto row_y_size =
                reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});

            constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});

            static_assert(row_y_size % row_y_packs == 0);

            constexpr auto row_y_slice_size = row_y_size / row_y_packs;

            constexpr auto slice_info = slice_sequence(row_y_lengths, number<row_y_slice_size>{});
            constexpr auto unpacks    = slice_info[number<1>{}];

            return unpacks;
        }();

        auto acc_tensor = MakeDstBlockTile();

        // in-thread reduction: sweep the kept axis, unpack-sweep the row axis
        sweep_tile_span(spans[number<0>{}], [&](auto dstr_idx_i0) {
            constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);

            auto acc = acc_tensor[acc_dstr_idx];

            sweep_tile_uspan(
                spans[number<1>{}],
                [&](auto... dstr_idx_i1) {
                    acc = reduce_func(acc, t[make_tuple(dstr_idx_i0, dstr_idx_i1)]...);
                },
                row_y_unpacks);

            acc_tensor(acc_dstr_idx) = acc;
        });

        // cross-lane synchronization, e.g. block_tile_reduce_sync
        reduce_sync_func(acc_tensor);

        return acc_tensor;
    }
    // reduce with the default cross-lane sync (block_tile_reduce_sync)
    template <typename ReduceFunc>
    CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func) const
    {
        return operator()(reduce_func,
                          [&](auto& acc_) { block_tile_reduce_sync(acc_, reduce_func); });
    }

    InDistributedTensor t;
    InDataType reduce_init;
};

// class template argument deduction: infer the tensor type from the constructor
template <typename T>
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&) -> BlockReduce2D<T>;

} // namespace ck_tile
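// Usage sketch (illustrative; `in_tile` is a hypothetical 2D fp32 tile): the
// deduction guide infers the tensor type from the constructor arguments. A
// variadic fold lambda is used because the packed sweep may hand reduce_func
// several row elements per call.
//
//   BlockReduce2D reduce{in_tile, 0.f}; // reduce_init = 0
//   auto f_sum   = [](auto acc, auto... xs) { return (acc + ... + xs); };
//   auto row_sum = reduce(f_sum);       // in-thread reduce + default sync
//   // explicit sync step and 4 packs along the row axis:
//   auto row_sum2 = reduce(f_sum,
//                          [&](auto& t_) { block_tile_reduce_sync(t_, f_sum); },
//                          sequence<1, 4>{});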