Composable Kernel: include/ck_tile/ops/reduce/block/block_reduce.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include <tuple>

// This file does not support cross-warp reduce
namespace ck_tile {

/*
 * TODO: block_tile_reduce_sync() currently has a limitation:
 * the Y dims must contain at least one dim that is not reduced
 */
// synchronize reduce result (cross-lane reduction and broadcast on replicated dimension)
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
                                           const ReduceFunc& reduce_func,
                                           bool_constant<WithBroadcast> = {})
{
    using Dstr             = typename AccDistributedTensor_::StaticTileDistribution;
    using DstrEncode       = typename Dstr::DstrEncode;
    using DstrEncodeDetail = typename DstrEncode::detail;

    constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
    constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

    constexpr index_t idim_p_lane = NDimP - 1;

    const auto ps_idx = detail::get_partition_index(acc_tensor.get_tile_distribution());
    const auto rs_idx = acc_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);

    constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();

    // loop over thread data
    static_for<0, thread_buf_size, 1>{}([&](auto i) {
        auto v_local = acc_tensor.get_thread_buffer()[i];

        // cross-lane reduce for replication
        // only reduce on the R dimension that corresponds to the lane
        // (lane id maps to this R dimension)
        static_for<0, NDimR, 1>{}([&](auto idim_r) {
            // FIXME: nasty to use does_p_own_r_
            if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
            {
                constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                constexpr index_t lid_over_rid_derivative =
                    DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];

                static_assert(is_power_of_two_integer(r_length),
                              "wrong! only support power of 2 reduction");

                constexpr index_t nstage = integer_log2_floor(r_length);

                // reduction sweep forward
                static_for<0, nstage, 1>{}([&](auto istage) {
                    constexpr index_t lid_delta =
                        lid_over_rid_derivative * (1 << (nstage - istage - 1));

                    // pull data from remote lane
                    const auto v_remote = warp_shuffle_down(v_local, lid_delta);

                    // reduce
                    v_local = reduce_func(v_local, v_remote);
                });
            }
        });

        if constexpr(WithBroadcast)
        {
            // cross-lane broadcast for replication
            // only broadcast on the R dimension that corresponds to the lane
            // (lane id maps to this R dimension)
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                // FIXME: nasty to use does_p_own_r_
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
                {
                    const index_t r_id = rs_idx[idim_r];

                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                    constexpr index_t lid_over_rid_derivative =
                        DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r];

                    static_assert(is_power_of_two_integer(r_length),
                                  "wrong! only support power of 2 reduction");

                    constexpr index_t nstage = integer_log2_floor(r_length);

                    // broadcast sweep backward
                    static_for<0, nstage, 1>{}([&](auto istage) {
                        // do I hold reduced data?
                        const bool do_i_hold_reduced_data = r_id < (1 << istage);

                        constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage);

                        // pull data from remote lane
                        const auto v_remote = warp_shuffle_up(v_local, lid_delta);

                        // decide whether to update local data with remote data
                        v_local = do_i_hold_reduced_data ? v_local : v_remote;
                    });
                }
            });
        }

        acc_tensor.get_thread_buffer()(i) = v_local;
    });
}
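
// Illustrative usage sketch (not part of this header; the tensor name and functor are
// hypothetical): once every thread holds a partially reduced distributed tile `acc`,
// block_tile_reduce_sync() combines the partial results across the lanes that replicate the
// reduced dimension and, by default, broadcasts the final value back to all of them.
//
//   const auto f_max = [](auto a, auto b) { return a > b ? a : b; };
//
//   // cross-lane reduce followed by broadcast (WithBroadcast defaults to true)
//   block_tile_reduce_sync(acc, f_max);
//
//   // cross-lane reduce only, skipping the broadcast sweep
//   block_tile_reduce_sync(acc, f_max, bool_constant<false>{});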

/*
 * this version is faster: it uses xor to do the reduce, so no broadcast is needed anymore
 * TODO: the limitation is that the to-be-reduced P dim can only map to one R dim?
 */
template <typename AccDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
                                               const ReduceFunc& reduce_func)
{
    using Dstr             = typename AccDistributedTensor_::StaticTileDistribution;
    using DstrEncode       = typename Dstr::DstrEncode;
    using DstrEncodeDetail = typename DstrEncode::detail;

    constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
    constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

    constexpr index_t idim_p_lane = NDimP - 1;

    constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();

    // loop over thread data
    static_for<0, thread_buf_size, 1>{}([&](auto i) {
        auto v_local = acc_tensor.get_thread_buffer()[i];

        // cross-lane reduce for replication
        // only reduce on the R dimension that corresponds to the lane
        // (lane id maps to this R dimension)
        static_for<0, NDimR, 1>{}([&](auto idim_r) {
            // FIXME: nasty to use does_p_own_r_
            if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
            {
                constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                constexpr index_t lid_over_rid_derivative =
                    DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];

                static_assert(is_power_of_two_integer(r_length),
                              "wrong! only support power of 2 reduction");

                constexpr index_t nstage = integer_log2_floor(r_length);

                // reduction sweep forward
                static_for<0, nstage, 1>{}([&](auto istage) {
                    // xor
                    index_t src_lane =
                        __lane_id() ^ (number<lid_over_rid_derivative << istage.value>{}.value);

                    // pull data from remote lane
                    const auto v_remote = warp_shuffle(v_local, src_lane);

                    // reduce
                    v_local = reduce_func(v_local, v_remote);
                });
            }
        });

        acc_tensor.get_thread_buffer()(i) = v_local;
    });
}
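
// A minimal host-side sketch (illustrative, not part of this header) of the butterfly pattern
// used above, with a lane-id/r-id derivative of 1. Because the exchange partner is chosen by
// xor-ing the lane id, every lane ends up with the full reduction after log2(r_length) stages,
// so no separate broadcast sweep is needed:
//
//   std::vector<int> lanes = {1, 2, 3, 4, 5, 6, 7, 8}; // r_length = 8, one value per "lane"
//   for(int stage = 0; stage < 3; ++stage)             // 3 == integer_log2_floor(8) stages
//   {
//       std::vector<int> next = lanes;
//       for(int lid = 0; lid < 8; ++lid)
//           next[lid] = lanes[lid] + lanes[lid ^ (1 << stage)]; // partner lane via xor
//       lanes = next;
//   }
//   // every entry of `lanes` is now 36, the sum over all lanes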

// FIXME: this is for 2D to 1D reduce only, needs to support n-D
template <typename AccDistributedTensor_,
          typename InDistributedTensor_,
          index_t... InReduceDims,
          typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
                                      const InDistributedTensor_& in_tensor,
                                      sequence<InReduceDims...>,
                                      const ReduceFunc& reduce_func)
{
    constexpr auto I0 = number<0>{};
    constexpr auto I1 = number<1>{};

#if 0
    constexpr auto in_reduce_dims = sequence<InReduceDims...>{};

    constexpr index_t ndim_in        = InDistributedTensor_::get_num_of_dimension();
    constexpr index_t ndim_in_reduce = in_reduce_dims.size();
    constexpr index_t ndim_in_free   = ndim_in - ndim_in_reduce;

    constexpr auto in_free_dims_arr = [&] {
        array<bool, ndim_in> is_free_dims{true};

        for(index_t i = 0; i < ndim_in_reduce; i++)
        {
            is_free_dims(in_reduce_dims[i]) = false;
        }

        array<index_t, ndim_in_free> in_free_dims{-1};

        index_t cnt = 0;

        for(index_t i = 0; i < ndim_in; i++)
        {
            if(is_free_dims[i])
            {
                in_free_dims(cnt) = i;

                cnt++;
            }
        }

        return in_free_dims;
    }();

    constexpr auto in_free_dims = TO_SEQUENCE(in_free_dims_arr, ndim_in_free);
#else

    constexpr auto spans = InDistributedTensor_::get_distributed_spans();

    // in-thread reduction
    // FIXME: hard coded to be 2D to 1D reduction
    sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
        constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);

        auto acc = acc_tensor[acc_dstr_idx];

        // FIXME
        sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
            constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);

            const auto in = in_tensor[in_dstr_idx];

            acc = reduce_func(acc, in);
        });

        acc_tensor(acc_dstr_idx) = acc;
    });
#endif
}
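
// In effect (illustrative, for the hard-coded 2D -> 1D case): for every un-reduced index i0 owned
// by this thread, the call folds all of the thread's i1 values into the accumulator,
//
//   acc(i0) = reduce_func(... reduce_func(acc(i0), in(i0, 0)) ..., in(i0, N1 - 1))
//
// so acc_tensor must already hold a valid init value; the overload below performs that
// initialization before calling this one.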

/*
 * TODO: block_tile_reduce() currently has a limitation:
 * the Y dims must contain at least one dim that is not reduced
 */
template <typename AccDataType_,
          typename InDistributedTensor_,
          index_t... InReduceDims,
          typename ReduceFunc,
          typename InDataType_>
CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
                                      sequence<InReduceDims...> in_reduce_dims,
                                      const ReduceFunc& reduce_func,
                                      const InDataType_& reduce_init)
{
    using InDataType  = typename InDistributedTensor_::DataType;
    using AccDataType = remove_cvref_t<AccDataType_>;

    static_assert(std::is_same_v<InDataType, remove_cvref_t<InDataType_>>, "wrong!");

    // declare acc_tensor
    constexpr auto acc_dstr = make_static_tile_distribution(
        detail::make_reduce_tile_distribution_encoding(
            InDistributedTensor_::get_tile_distribution().get_static_tile_distribution_encoding(),
            sequence<InReduceDims...>{}));

    auto acc_tensor = make_static_distributed_tensor<AccDataType>(acc_dstr);

    // init acc_tensor
    tile_elementwise_inout([&](auto& acc) { acc = type_convert<AccDataType>(reduce_init); },
                           acc_tensor);

    // warp reduce
    block_tile_reduce(acc_tensor, in_tensor, in_reduce_dims, reduce_func);

    return acc_tensor;
}
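
// Illustrative usage sketch (names are hypothetical, not part of this header): a row-wise sum of
// a 2D distributed tile `x` whose data type is XDataType, reducing X-dim 1, followed by the
// cross-lane synchronization.
//
//   const auto f_sum = [](auto a, auto b) { return a + b; };
//
//   // per-thread partial reduction; the init value must have the input's data type
//   auto row_sum = block_tile_reduce<AccDataType>(x, sequence<1>{}, f_sum, XDataType{0});
//
//   // combine the partial results across lanes and broadcast them back
//   block_tile_reduce_sync(row_sum, f_sum);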

// this version only supports 2D->1D reduce (reduce-dim=seq<0, 1>)
// this version only supports the case where the in/acc/out data types are the same
// this version will call thread/warp reduce + sync in one function call
//
template <typename InDistributedTensor_>
struct BlockReduce2D
{
    using InDistributedTensor = remove_cvref_t<InDistributedTensor_>;
    using InDataType          = typename InDistributedTensor::DataType;

    CK_TILE_HOST_DEVICE BlockReduce2D(const InDistributedTensor& t_, const InDataType& reduce_init_)
        : t(t_), reduce_init(reduce_init_)
    {
    }

    CK_TILE_HOST_DEVICE constexpr auto MakeDstBlockTile() const
    {
        using ReduceDim = sequence<1>; // hard coded
        constexpr auto acc_dstr =
            make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
                InDistributedTensor::get_tile_distribution()
                    .get_static_tile_distribution_encoding(),
                ReduceDim{}));

        auto dst_ = make_static_distributed_tensor<InDataType>(acc_dstr);
        // init acc_tensor
        tile_elementwise_inout([&](auto& x_) { x_ = type_convert<InDataType>(reduce_init); }, dst_);
        return dst_;
    }

    // return the number of pixels each lane needs to reduce
    CK_TILE_HOST_DEVICE constexpr auto get_reduce_length_y() const
    {
        constexpr auto spans = InDistributedTensor::get_distributed_spans();
        // product of the to-be-reduced Y lengths along the row (X-dim 1) span
        constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
        return reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
    }

    // Here ReducePacksPerXDim does not have the same meaning as in static_uford/sweep_tile_uspan:
    // it is the number of packs along each X dim, and the unpacks along the Y dims are computed
    // internally.
    // For simplicity, only packing along the row dimension is supported, so ReducePacksPerXDim
    // always has 2 elements and the first element is always ignored. Also for simplicity, the
    // search for the Y dim to split always goes from right to left.
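    //
    // Worked example (illustrative): if the row (X-dim 1) Y lengths multiply to row_y_size = 8
    // and ReducePacksPerXDim = sequence<1, 2>, then row_y_packs = 2 and the row Y lengths are
    // sliced into chunks of row_y_slice_size = 8 / 2 = 4 elements; row_y_size must be divisible
    // by row_y_packs (see the static_assert below).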
    template <typename ReduceFunc,
              typename ReduceSyncFunc,
              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
    CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func,
                                        const ReduceSyncFunc& reduce_sync_func,
                                        ReducePacksPerXDim = {}) const
    {
        constexpr auto spans = InDistributedTensor::get_distributed_spans();

        constexpr auto row_y_unpacks = [&]() {
            constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
            constexpr auto row_y_size =
                reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
            constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});

            static_assert(row_y_size % row_y_packs == 0);

            constexpr auto row_y_slice_size = row_y_size / row_y_packs;

            constexpr auto slice_info = slice_sequence(row_y_lengths, number<row_y_slice_size>{});
            constexpr auto unpacks    = slice_info[number<1>{}];
            return unpacks;
        }();

        auto acc_tensor = MakeDstBlockTile();

        // in-thread reduction
        // FIXME: hard coded to be 2D to 1D reduction
        sweep_tile_span(spans[number<0>{}], [&](auto dstr_idx_i0) {
            constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);

            auto acc = acc_tensor[acc_dstr_idx];

            sweep_tile_uspan(
                spans[number<1>{}],
                [&](auto... dstr_idx_i1) {
                    acc = reduce_func(acc, t[make_tuple(dstr_idx_i0, dstr_idx_i1)]...);
                },
                row_y_unpacks);

            acc_tensor(acc_dstr_idx) = acc;
        });

        // TODO: always use xor to do cross-lane reduce
        block_tile_reduce_xor_sync(acc_tensor, reduce_sync_func);

        return acc_tensor;
    }

    template <typename ReduceFunc>
    CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func) const
    {
        return operator()(reduce_func, reduce_func);
    }

    InDistributedTensor t;
    InDataType reduce_init;
};

// deduction guide
template <typename T>
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&) -> BlockReduce2D<T>;
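
// Illustrative usage sketch (the tile `x` and the XDataType name are hypothetical, not part of
// this header): reduce a 2D distributed tile row-wise with one call; with the single-functor
// overload the same binary functor is used for the in-thread reduce and for the cross-lane xor
// reduce, assuming the default ReducePacksPerXDim.
//
//   const auto f_sum = [](auto a, auto b) { return a + b; };
//
//   BlockReduce2D reduce(x, XDataType{0}); // XDataType == decltype(x)::DataType, init value 0
//   auto row_sum = reduce(f_sum);          // same as reduce(f_sum, f_sum)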

} // namespace ck_tile