block_reduce2d.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

// BlockReduce2d implements a hierarchical 2D reduction operator that reduces data along the
// second dimension using a user-specified reduction function.
//
// The reduction is performed in a three-stage hierarchical approach:
//
// STAGE 1: Thread-level reduction (BlockReduce2d)
// ===============================================
// - Each thread processes multiple elements of the input tensor within its assigned data
//   partition
// - Reduction is performed locally within each thread by iterating over the assigned elements
// - ReducePacksPerXDim controls how many elements sweep_tile processes per iteration in each
//   dimension (e.g., {1,1} = 1 element at a time from each dimension, {2,4} = 2 from dim0,
//   4 from dim1)
// - Results are accumulated into a thread-local output tensor stored in registers
// - The output tensor distribution is derived from the input tensor's distribution using
//   make_reduce_tile_distribution_encoding() to handle dimension reduction
//
// STAGE 2: Warp-level reduction (BlockReduce2dSync)
// =================================================
// - Performs inter-thread reduction within each warp
// - Uses warp shuffle operations to exchange data between threads in the same warp
// - Implements a tree-reduction pattern with power-of-2 stages
// - Only reduces along dimensions that map to lane IDs within the warp
//
// STAGE 3: Cross-warp reduction (BlockReduce2dCrossWarpSync)
// ==========================================================
// - Performs reduction across multiple warps within the same thread block
// - Uses shared memory (LDS) to facilitate data exchange between warps
// - Each warp's lane-0 thread stores its partial results to shared memory
// - All threads participate in loading and reducing data from shared memory
// - Uses block-level synchronization to ensure memory consistency
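
// Usage sketch (illustration only, not part of this header): how the three stages are
// typically chained inside a block-level kernel. The "Problem" type is assumed to provide
// XDataType, ComputeDataType and BlockShape, and "x_tensor" is assumed to be a distributed
// input tile already held in registers; the sum-reduction functor is just an example.
//
//   using BlockReduce          = BlockReduce2d<Problem>;
//   using BlockReduceSync      = BlockReduce2dSync<Problem>;
//   using BlockReduceCrossWarp = BlockReduce2dCrossWarpSync<Problem>;
//
//   const auto f_add = [](auto acc, auto x) { return acc + x; };
//
//   // Stage 1: per-thread partial reduction along the second dimension
//   auto y_tensor = BlockReduce{}(x_tensor, type_convert<ComputeDataType>(0), f_add);
//
//   // Stage 2: combine partial results between lanes of the same warp
//   BlockReduceSync{}(y_tensor, f_add);
//
//   // Stage 3: combine partial results across warps through LDS
//   __shared__ char smem[BlockReduceCrossWarp::GetSmemSize<decltype(y_tensor)>()];
//   BlockReduceCrossWarp{}(y_tensor, smem, f_add);
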
// BlockReduce2d: Thread-level reduction (Stage 1)
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2d
{
    // Thread-level reduction implementation
    using Problem         = remove_cvref_t<Problem_>;
    using XDataType       = typename Problem::XDataType;
    using ComputeDataType = typename Problem::ComputeDataType;

    constexpr CK_TILE_DEVICE BlockReduce2d() {}

    template <typename XDistributedTensor_,
              typename YDistributedTensor_,
              typename ReduceFunc,
              typename ReducePacksPerXDim =
                  uniform_sequence_gen_t<2, 1>> // {1,1} = process 1 element at a time from each dimension
    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
                                   YDistributedTensor_& y_tensor,
                                   const ReduceFunc& reduce_func,
                                   ReducePacksPerXDim = {})
    {
        sweep_tile<XDistributedTensor_>(
            [&](auto... idx_) {
                constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
                y_tensor(idx_0) = reduce_func(
                    y_tensor(idx_0), ck_tile::type_convert<ComputeDataType>(x_tensor[idx_])...);
            },
            ReducePacksPerXDim{});

#if 0
        constexpr auto I0 = number<0>{};
        constexpr auto I1 = number<1>{};
        constexpr auto spans = XDistributedTensor_::get_distributed_spans();

        // FIXME: hard coded to reduce 2nd axis
        sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
            constexpr auto y_dstr_idx = make_tuple(dstr_idx_i0);

            auto y = y_tensor[y_dstr_idx];

            sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
                constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
                const auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);

                y = reduce_func(y, x);
            });

            y_tensor(y_dstr_idx) = y;
        });
#endif
    }

    template <typename XDistributedTensor_>
    static CK_TILE_DEVICE auto MakeYBlockTile()
    {
        static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");

        // FIXME: hard coded to reduce 2nd axis
        constexpr auto reduce_dims = sequence<1>{};

        constexpr auto dstr =
            make_static_tile_distribution(make_reduce_tile_distribution_encoding(
                XDistributedTensor_::get_tile_distribution()
                    .get_static_tile_distribution_encoding(),
                reduce_dims));

        auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);

        return tensor;
    }

    // uniform_sequence_gen_t<NSize, Value> generates a sequence of NSize elements filled with
    // Value, e.g., uniform_sequence_gen_t<2, 1> → {1, 1} and uniform_sequence_gen_t<3, 4> → {4, 4, 4}
    template <typename XDistributedTensor_,
              typename ReduceFunc,
              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
    CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor,
                                   const ComputeDataType& reduce_init,
                                   const ReduceFunc& reduce_func,
                                   ReducePacksPerXDim = {})
    {
        auto y_tensor = MakeYBlockTile<XDistributedTensor_>();
        set_tile(y_tensor, reduce_init);
        (*this)(x_tensor, y_tensor, reduce_func, ReducePacksPerXDim{});

        return y_tensor;
    }
};

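// ReducePacksPerXDim example (illustration only): with sequence<1, 2>, sweep_tile hands two
// elements from the reduced dimension to reduce_func per call (assuming the reduced dimension
// length is divisible by the pack size), so the functor takes (acc, x0, x1). The functor and
// tensor names below are hypothetical.
//
//   const auto f_add2 = [](auto acc, auto x0, auto x1) { return acc + x0 + x1; };
//   BlockReduce2d<Problem>{}(x_tensor, y_tensor, f_add2, sequence<1, 2>{});
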
// BlockReduce2dSync: Warp-level reduction (Stage 2)
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dSync
{
    using Problem = remove_cvref_t<Problem_>;

    template <typename YDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
    {
        using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
        using DstrEncode       = typename Dstr::DstrEncode;
        using DstrEncodeDetail = typename DstrEncode::detail;

        constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
        constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

        constexpr index_t idim_p_lane = NDimP - 1;

        // const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
        // const auto rs_idx =
        //     y_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        // loop over thread data
        static_for<0, thread_buf_size, 1>{}([&](auto i) {
            auto v_local = y_tensor.get_thread_buffer()[i];

            // cross-lane reduce for replication
            // only reduce on the R dimensions that correspond to the lane
            // (lane id maps to this R dimension)
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                // FIXME: nasty to use does_p_own_r_
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                    constexpr index_t lid_over_rid_derivative =
                        DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];

                    static_assert(is_power_of_two_integer(r_length),
                                  "wrong! only support power of 2 reduction");

                    constexpr index_t nstage = integer_log2_floor(r_length);

                    // reduction sweep forward
                    static_for<0, nstage, 1>{}([&](auto istage) {
                        // xor
                        index_t src_lane =
                            (__lane_id()) ^
                            (number<lid_over_rid_derivative << istage.value>{}.value);

                        // pull data from remote lane
                        const auto v_remote = warp_shuffle(v_local, src_lane);
                        v_local             = reduce_func(v_local, v_remote);
                    });
                }
            });

            // TODO - Do we need to broadcast to other lanes?
            y_tensor.get_thread_buffer()(i) = v_local;
        });
    }
};
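
// Shuffle-butterfly example (illustration only): if a lane-mapped R dimension has
// r_length = 4 and lid_over_rid_derivative = 16, the sweep runs nstage = 2 stages,
//   stage 0: combine with lane ^ 16
//   stage 1: combine with lane ^ 32
// so each lane ends up with the reduction over the 4 replicas held by lanes
// {l, l ^ 16, l ^ 32, l ^ 48}.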

// BlockReduce2dCrossWarpSync: Cross-warp reduction (Stage 3)
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dCrossWarpSync
{
    using Problem    = remove_cvref_t<Problem_>;
    using BlockShape = typename Problem::BlockShape;

    template <typename YDistributedTensor_>
    CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
    {
        constexpr index_t num_reduce_warps = [&]() {
            using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
            using DstrEncode       = typename Dstr::DstrEncode;
            using DstrEncodeDetail = typename DstrEncode::detail;

            constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

            constexpr index_t idim_p_warp = 0;

            index_t len_ = 1;
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
                    len_ *= r_length;
                }
            });
            return len_;
        }();
        return num_reduce_warps;
    }

    // returns the size in bytes
    template <typename YDistributedTensor_>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        using DataType = typename YDistributedTensor_::DataType;
        // constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        // we need to store the data from every wave into smem
        // e.g. 2x2 waves, reduce along N
        //      -------------> reduce N
        //      | w0 | w1 |   ___>   | w01 |
        //      | w2 | w3 |          | w23 |
        //
        //   -> store data from every wave into LDS
        //
        //      -------------> reduce N
        //      | w0 | w1 | w2 | w3 |  ----->  | w0123 |
        //
        //   -> also store data from every wave into LDS
        constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
        return num_warps * thread_buf_size * sizeof(DataType);
    }
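
    // Example (illustration only): with BlockShape::BlockSize = 256 and a warp size of 64,
    // num_warps = 4; with thread_buf_size = 2 and fp32 data this reserves
    // 4 * 2 * sizeof(float) = 32 bytes of LDS.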

    template <typename YDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void
    operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
    {
        using DataType = typename YDistributedTensor_::DataType;

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        DataType* smem_ptr              = reinterpret_cast<DataType*>(smem);
        const index_t lane_id           = get_lane_id();
        const index_t warp_id           = get_warp_id();
        constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
        constexpr index_t num_warps     = BlockShape::BlockSize / get_warp_size();
        const index_t smem_offset       = warp_id;

        // skip if there is nothing to do
        if constexpr(num_reduce_warps == 1)
            return;

        // store into smem only from lane 0 of each warp
        if(lane_id == 0)
        {
            static_for<0, thread_buf_size, 1>{}([&](auto i) {
                smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
            });
        }
        block_sync_lds();

        // load from smem. here we let every thread do the compute :)
        index_t local_warp_id = warp_id / num_reduce_warps;
        index_t local_smem_os = local_warp_id * num_reduce_warps;
        DataType all_scratch[thread_buf_size * num_reduce_warps];
        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
            static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
                all_scratch[i_0 * num_reduce_warps + i_1] =
                    smem_ptr[i_0 * num_warps + local_smem_os + i_1];
            });
        });
        block_sync_lds(); // TODO: we don't need this sync here

        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
            // TODO: use a descriptor for this
            auto v_local = all_scratch[i_0 * num_reduce_warps];

            // further reduce the partial results
            static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
                constexpr auto i_1      = number<i_1_n1 + 1>{};
                const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
                v_local                 = reduce_func(v_local, v_remote);
            });

            y_tensor.get_thread_buffer()(i_0) = v_local;
        });
    }
};
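
// LDS layout note (illustration only): element i of warp w is stored at
// smem_ptr[w + i * num_warps], so the per-element partials from all warps are contiguous.
// With num_warps = 4, thread_buf_size = 2 and num_reduce_warps = 2, warps 2 and 3
// (local_warp_id = 1) read offsets {2, 3} for element 0 and {6, 7} for element 1.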

// Alternative cross-warp reduction: lane 0 of each warp stages its partial results in LDS,
// then the partials are reloaded and combined with warp shuffles.
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dCrossWarpSyncV2 // NOTE: placeholder name
{
    using Problem    = remove_cvref_t<Problem_>;
    using BlockShape = typename Problem::BlockShape;

    template <typename YDistributedTensor_>
    CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
    {
        constexpr index_t num_reduce_warps = [&]() {
            using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
            using DstrEncode       = typename Dstr::DstrEncode;
            using DstrEncodeDetail = typename DstrEncode::detail;

            constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

            constexpr index_t idim_p_warp = 0;

            index_t len_ = 1;
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
                    len_ *= r_length;
                }
            });
            return len_;
        }();
        return num_reduce_warps;
    }

    // returns the size in bytes
    template <typename YDistributedTensor_>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        using DataType                    = typename YDistributedTensor_::DataType;
        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        // we need to store the data from every wave into smem
        // e.g. 2x2 waves, reduce along N
        //      -------------> reduce N
        //      | w0 | w1 |   ___>   | w01 |
        //      | w2 | w3 |          | w23 |
        //
        //   -> store data from every wave into LDS
        //
        //      -------------> reduce N
        //      | w0 | w1 | w2 | w3 |  ----->  | w0123 |
        //
        //   -> also store data from every wave into LDS
        constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
        return num_warps * thread_buf_size * sizeof(DataType);
    }

    template <typename YDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void
    operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
    {
        using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
        using DstrEncode       = typename Dstr::DstrEncode;
        using DstrEncodeDetail = typename DstrEncode::detail;
        using DataType         = typename YDistributedTensor_::DataType;

        constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
        constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

        constexpr index_t idim_p_lane     = NDimP - 1;
        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        DataType* smem_ptr    = reinterpret_cast<DataType*>(smem);
        const index_t lane_id = get_lane_id();
        const index_t warp_id = get_warp_id();

        constexpr index_t num_warps        = BlockShape::BlockSize / get_warp_size();
        constexpr index_t num_reduce_warps = GetReduceWarps<YDistributedTensor_>();

        if constexpr(num_reduce_warps == 1)
            return;

        // Each warp's lane 0 writes its partial results to shared memory
        const index_t smem_offset = warp_id;
        if(lane_id == 0)
        {
            static_for<0, thread_buf_size, 1>{}([&](auto i) {
                // Store the i-th element of this warp's thread buffer into SMEM
                smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
            });
        }
        block_sync_lds();

        // Each warp holds a duplicate of the partial results and performs the reduction.
        static_for<0, thread_buf_size, 1>{}([&](auto i) {
            DataType v = 0;
            if(lane_id < num_reduce_warps)
            {
                v = smem_ptr[lane_id + i * num_warps];
            }

            // cross-lane reduce for replication
            // only reduce on the R dimensions that correspond to the lane
            // (lane id maps to this R dimension)
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                // FIXME: nasty to use does_p_own_r_
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                    constexpr index_t lid_over_rid_derivative =
                        DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];

                    static_assert(is_power_of_two_integer(r_length),
                                  "wrong! only support power of 2 reduction");

                    constexpr index_t nstage = integer_log2_floor(r_length);

                    // reduction sweep forward
                    static_for<0, nstage, 1>{}([&](auto istage) {
                        // pull data from remote lane
                        const auto o =
                            __shfl_xor(v, number<lid_over_rid_derivative << istage.value>{}.value);

                        // reduce
                        v = reduce_func(v, o);
                    });
                }
            });

            y_tensor.get_thread_buffer()(i) = v;
        });
    }
};

} // namespace ck_tile