block_reduce2d.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

// BlockReduce2d implements a hierarchical 2D reduction operator that reduces data along the second
// dimension using a user-specified reduction function.
//
// The reduction is performed in a three-stage hierarchical approach:
//
// STAGE 1: Thread-level reduction (BlockReduce2d)
// ===============================================
// - Each thread processes multiple elements from the input tensor within its assigned data
//   partition
// - Reduction is performed locally within each thread by iterating over assigned elements
// - ReducePacksPerXDim controls how many elements sweep_tile processes in one iteration per
//   dimension
//   (e.g., {1,1} = 1 element at a time from each dimension, {2,4} = 2 from dim0, 4 from dim1)
// - Results are accumulated into a thread-local output tensor stored in registers
// - The output tensor distribution is derived from the input tensor's distribution using
//   make_reduce_tile_distribution_encoding() to handle dimension reduction
//
// STAGE 2: Warp-level reduction (BlockReduce2dSync)
// =================================================
// - Performs inter-thread reduction within each warp
// - Uses warp shuffle operations to exchange data between threads in the same warp
// - Implements a tree-reduction pattern with power-of-2 stages
// - Only reduces along dimensions that map to lane IDs within the warp
//
// STAGE 3: Cross-warp reduction (BlockReduce2dCrossWarpSync)
// ==========================================================
// - Performs reduction across multiple warps within the same thread block
// - Uses shared memory (LDS) to facilitate data exchange between warps
// - Each warp's lane-0 thread stores its partial results to shared memory
// - All threads participate in loading and reducing data from shared memory
// - Implements block-level synchronization to ensure memory consistency
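//
// A minimal usage sketch of the three stages chained together (not part of this file;
// `Problem`, `x_tensor`, `reduce_init`, and `smem` are assumed to be provided by the
// calling kernel, and `f_max` is just an example reduction functor):
//
//   auto block_reduce2d                 = BlockReduce2d<Problem>{};
//   auto block_reduce2d_sync            = BlockReduce2dSync<Problem>{};
//   auto block_reduce2d_cross_warp_sync = BlockReduce2dCrossWarpSync<Problem>{};
//
//   const auto f_max = [](auto a, auto b) { return a > b ? a : b; };
//
//   // Stage 1: per-thread partial reduction into a register tile
//   auto y_tensor = block_reduce2d(x_tensor, reduce_init, f_max);
//   // Stage 2: reduce across lanes of the same warp
//   block_reduce2d_sync(y_tensor, f_max);
//   // Stage 3: reduce across warps through LDS
//   block_reduce2d_cross_warp_sync(y_tensor, smem, f_max);
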
// BlockReduce2d: Thread-level reduction (Stage 1)
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2d
{
    // Thread-level reduction implementation
    using Problem         = remove_cvref_t<Problem_>;
    using XDataType       = typename Problem::XDataType;
    using ComputeDataType = typename Problem::ComputeDataType;

    CK_TILE_DEVICE constexpr BlockReduce2d() {}

    private:
    template <bool kProcessIndex,
              typename XDistributedTensor_,
              typename YDistributedTensor_,
              typename YIndexDistributedTensor_,
              typename ReduceFunc,
              typename IndexCalculatorFunc,
              typename ReducePacksPerXDim>
    CK_TILE_DEVICE void reduce_impl(const XDistributedTensor_& x_tensor,
                                    YDistributedTensor_& y_tensor,
                                    YIndexDistributedTensor_& y_index_tensor,
                                    const ReduceFunc& reduce_func,
                                    const IndexCalculatorFunc& index_calculator,
                                    ReducePacksPerXDim)
    {
        sweep_tile<XDistributedTensor_>(
            [&](auto... idx_) {
                constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);

                (..., [&](auto idx) {
                    auto val = ck_tile::type_convert<ComputeDataType>(x_tensor[idx]);

                    if constexpr(kProcessIndex)
                    {
                        const auto x_indices = get_x_indices_from_distributed_indices(
                            XDistributedTensor_::get_tile_distribution(), idx);
                        const auto new_idx = index_calculator(x_indices);
                        auto current_idx   = y_index_tensor(idx_0);

                        AccumulateWithIndex{}(
                            reduce_func, y_tensor(idx_0), current_idx, val, new_idx);

                        y_index_tensor(idx_0) =
                            type_convert<typename YIndexDistributedTensor_::DataType>(current_idx);
                    }
                    else
                    {
                        Accumulate{}(reduce_func, y_tensor(idx_0), val);
                    }
                }(idx_));
            },
            ReducePacksPerXDim{});
    }

    public:
    // Overload for non-index tracking
    template <
        typename XDistributedTensor_,
        typename YDistributedTensor_,
        typename ReduceFunc,
        typename ReducePacksPerXDim =
            uniform_sequence_gen_t<2, 1>> // {1,1} = process 1 element at a time from each dimension
    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
                                   YDistributedTensor_& y_tensor,
                                   const ReduceFunc& reduce_func,
                                   ReducePacksPerXDim = {})
    {
        reduce_impl<false>(
            x_tensor,
            y_tensor,
            y_tensor, // dummy
            reduce_func,
            [](auto) { return 0; }, // dummy
            ReducePacksPerXDim{});
    }

    // Overload for index tracking
    template <typename XDistributedTensor_,
              typename YDistributedTensor_,
              typename YIndexDistributedTensor_,
              typename ReduceFunc,
              typename IndexCalculatorFunc,
              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
                                   YDistributedTensor_& y_tensor,
                                   YIndexDistributedTensor_& y_index_tensor,
                                   const ReduceFunc& reduce_func,
                                   const IndexCalculatorFunc& index_calculator,
                                   ReducePacksPerXDim = {})
    {
        reduce_impl<Problem::kOutputIndex>(x_tensor,
                                           y_tensor,
                                           y_index_tensor,
                                           reduce_func,
                                           index_calculator,
                                           ReducePacksPerXDim{});
    }
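
    // Sketch of the index-tracking path (e.g. an argmax along the reduced dimension).
    // Not part of this file: `x_tensor`, `reduce_init` and `f_max` are assumed to exist in
    // the calling kernel, Problem::kOutputIndex is assumed to be true, and the
    // `.at(number<1>{})` accessor is an illustrative way to pick the reduced-dim index:
    //
    //   using XTensor       = remove_cvref_t<decltype(x_tensor)>;
    //   auto y_tensor       = BlockReduce2d<Problem>::MakeYBlockTile<XTensor>();
    //   auto y_index_tensor = BlockReduce2d<Problem>::MakeYIndexBlockTile<XTensor>();
    //   set_tile(y_tensor, reduce_init);
    //   set_tile(y_index_tensor, 0);
    //
    //   BlockReduce2d<Problem>{}(x_tensor,
    //                            y_tensor,
    //                            y_index_tensor,
    //                            f_max,
    //                            [](auto x_indices) { return x_indices.at(number<1>{}); });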

#if 0
    constexpr auto I0 = number<0>{};
    constexpr auto I1 = number<1>{};
    constexpr auto spans = XDistributedTensor_::get_distributed_spans();

    // FIXME: hard coded to reduce 2nd axis
    sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
        constexpr auto y_dstr_idx = make_tuple(dstr_idx_i0);

        auto y = y_tensor[y_dstr_idx];

        sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
            constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
            const auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);

            y = reduce_func(y, x);
        });

        y_tensor(y_dstr_idx) = y;
    });
#endif

    template <typename XDistributedTensor_>
    CK_TILE_DEVICE static auto MakeYBlockTile()
    {
        static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");

        // FIXME: hard coded to reduce 2nd axis
        constexpr auto reduce_dims = sequence<1>{};

        constexpr auto dstr =
            make_static_tile_distribution(make_reduce_tile_distribution_encoding(
                XDistributedTensor_::get_tile_distribution()
                    .get_static_tile_distribution_encoding(),
                reduce_dims));

        auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);

        return tensor;
    }

    template <typename XDistributedTensor_, typename IndexDataType = index_t>
    CK_TILE_DEVICE static auto MakeYIndexBlockTile()
    {
        static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");

        // FIXME: hard coded to reduce 2nd axis
        constexpr auto reduce_dims = sequence<1>{};

        constexpr auto dstr =
            make_static_tile_distribution(make_reduce_tile_distribution_encoding(
                XDistributedTensor_::get_tile_distribution()
                    .get_static_tile_distribution_encoding(),
                reduce_dims));

        auto tensor = make_static_distributed_tensor<IndexDataType>(dstr);

        return tensor;
    }

    // uniform_sequence_gen_t<NSize, Value> generates a sequence of NSize elements filled with Value
    // e.g., uniform_sequence_gen_t<2, 1> → {1, 1} and uniform_sequence_gen_t<3, 4> → {4, 4, 4}
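    //
    // For example (illustrative only): the call below sweeps 4 elements of the reduced
    // dimension per iteration while keeping dim0 at 1 element per iteration; `x_tensor`,
    // `reduce_init` and `f_max` are assumed to exist in the calling kernel:
    //
    //   auto y_tensor = block_reduce2d(x_tensor, reduce_init, f_max, sequence<1, 4>{});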
    template <typename XDistributedTensor_,
              typename ReduceFunc,
              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
    CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor,
                                   const ComputeDataType& reduce_init,
                                   const ReduceFunc& reduce_func,
                                   ReducePacksPerXDim = {})
    {
        auto y_tensor = MakeYBlockTile<XDistributedTensor_>();
        set_tile(y_tensor, reduce_init);
        (*this)(x_tensor, y_tensor, reduce_func, ReducePacksPerXDim{});

        return y_tensor;
    }
};

// BlockReduce2dSync: Warp-level reduction (Stage 2)
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dSync
{
    using Problem = remove_cvref_t<Problem_>;

    private:
    template <bool kProcessIndex,
              typename YDistributedTensor_,
              typename YIndexDistributedTensor_,
              typename ReduceFunc>
    CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
                                    YIndexDistributedTensor_& y_index_tensor,
                                    const ReduceFunc& reduce_func)
    {
        using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
        using DstrEncode       = typename Dstr::DstrEncode;
        using DstrEncodeDetail = typename DstrEncode::detail;

        constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
        constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

        constexpr index_t idim_p_lane = NDimP - 1;

        // const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
        // const auto rs_idx =
        //     y_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        // loop over thread data
        static_for<0, thread_buf_size, 1>{}([&](auto i) {
            auto v_local = y_tensor.get_thread_buffer()[i];

            using IndexDataType = typename YIndexDistributedTensor_::DataType;
            IndexDataType idx_local{};

            if constexpr(kProcessIndex)
            {
                idx_local = y_index_tensor.get_thread_buffer()[i];
            }

            // cross-lane reduce for replication
            // only reduce on the R dimensions that correspond to the lane
            // (lane id maps to these R dimensions)
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                // FIXME: nasty to use does_p_own_r_
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                    constexpr index_t lid_over_rid_derivative =
                        DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];

                    static_assert(is_power_of_two_integer(r_length),
                                  "wrong! only support power of 2 reduction");

                    constexpr index_t nstage = integer_log2_floor(r_length);

                    // reduction sweep forward
                    static_for<0, nstage, 1>{}([&](auto istage) {
                        // xor
                        index_t src_lane =
                            (__lane_id()) ^
                            (number<lid_over_rid_derivative << istage.value>{}.value);

                        // pull data from remote lane
                        const auto v_remote = warp_shuffle(v_local, src_lane);

                        if constexpr(kProcessIndex)
                        {
                            const auto idx_remote = warp_shuffle(idx_local, src_lane);

                            AccumulateWithIndex{}(
                                reduce_func, v_local, idx_local, v_remote, idx_remote);
                        }
                        else
                        {
                            Accumulate{}(reduce_func, v_local, v_remote);
                        }
                    });
                }
            });

            // TODO - Do we need to broadcast to other lane?
            y_tensor.get_thread_buffer()(i) = v_local;

            if constexpr(kProcessIndex)
            {
                y_index_tensor.get_thread_buffer()(i) = idx_local;
            }
        });
    }
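
    // Worked example of the XOR butterfly above (the numbers are illustrative, not taken
    // from a specific configuration): with r_length = 4 and lid_over_rid_derivative = 1
    // there are nstage = 2 stages. In stage 0 each lane exchanges with lane ^ 1, in stage 1
    // with lane ^ 2, so after both stages lanes {0,1,2,3} all hold the reduction of the
    // four partial values. A larger lid_over_rid_derivative simply scales the XOR masks,
    // i.e. the partner lane becomes lane ^ (lid_over_rid_derivative << istage).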

    public:
    template <typename YDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
    {
        reduce_impl<false>(y_tensor, y_tensor, reduce_func);
    }

    template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
                                   YIndexDistributedTensor_& y_index_tensor,
                                   const ReduceFunc& reduce_func)
    {
        reduce_impl<Problem::kOutputIndex>(y_tensor, y_index_tensor, reduce_func);
    }
};

// BlockReduce2dCrossWarpSync: Cross-warp reduction (Stage 3)
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dCrossWarpSync
{
    using Problem    = remove_cvref_t<Problem_>;
    using BlockShape = typename Problem::BlockShape;

    template <typename YDistributedTensor_>
    CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
    {
        constexpr index_t num_reduce_warps = [&]() {
            using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
            using DstrEncode       = typename Dstr::DstrEncode;
            using DstrEncodeDetail = typename DstrEncode::detail;

            constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

            constexpr index_t idim_p_warp = 0;

            index_t len_ = 1;
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
                    len_ *= r_length;
                }
            });
            return len_;
        }();
        return num_reduce_warps;
    }

    // return size in bytes
    template <typename YDistributedTensor_>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        using DataType = typename YDistributedTensor_::DataType;
        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        // we need to store all data from every wave into smem
        // e.g. 2x2 reduce along N
        //     -------------> reduce N
        //     | w0 | w1 |   ___>   | w01 |
        //     | w2 | w3 |          | w23 |
        //
        // -> store data from every wave into LDS
        //
        //
        //     -------------> reduce N
        //     | w0 | w1 | w2 | w3 |  ----->  | w0123 |
        //
        // -> also store data from every wave into LDS
        constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
        return num_warps * thread_buf_size * sizeof(DataType);
    }
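
    // Worked sizing example (illustrative values, not from this file): with
    // BlockShape::BlockSize = 256 and get_warp_size() = 64 there are num_warps = 4 warps;
    // if each thread keeps thread_buf_size = 2 partial results of a 4-byte DataType, then
    // GetSmemSize() = 4 * 2 * 4 = 32 bytes of LDS are reserved for the value exchange.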

    // return size in bytes - separate shared memory size calculation for indices
    template <typename YIndexDistributedTensor_>
    CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
    {
        using IndexDataType = typename YIndexDistributedTensor_::DataType;
        constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
        constexpr index_t num_warps       = BlockShape::BlockSize / get_warp_size();
        return num_warps * thread_buf_size * sizeof(IndexDataType);
    }

    private:
    template <bool kProcessIndex,
              typename YDistributedTensor_,
              typename YIndexDistributedTensor_,
              typename ReduceFunc>
    CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
                                    YIndexDistributedTensor_& y_index_tensor,
                                    void* smem,
                                    void* smem_indices_ptr,
                                    const ReduceFunc& reduce_func)
    {
        using DataType      = typename YDistributedTensor_::DataType;
        using IndexDataType = typename YIndexDistributedTensor_::DataType;

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        DataType* smem_ptr          = reinterpret_cast<DataType*>(smem);
        IndexDataType* smem_indices = nullptr;
        if constexpr(kProcessIndex)
        {
            smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
        }

        const index_t lane_id = get_lane_id();
        const index_t warp_id = get_warp_id();

        constexpr index_t num_warps        = BlockShape::BlockSize / get_warp_size();
        constexpr index_t num_reduce_warps = GetReduceWarps<YDistributedTensor_>();

        if constexpr(num_reduce_warps == 1)
            return;

        // Each warp's lane 0 writes its partial results to shared memory
        const index_t smem_offset = warp_id;
        if(lane_id == 0)
        {
            static_for<0, thread_buf_size, 1>{}([&](auto i) {
                // Store the i-th element of this warp's thread_buffer into SMEM
                smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
                if constexpr(kProcessIndex)
                {
                    smem_indices[smem_offset + i * num_warps] =
                        y_index_tensor.get_thread_buffer()[i];
                }
            });
        }
        block_sync_lds();

        // We let each warp hold a duplicate of the data so every warp can do the reduction.
        const index_t local_warp_id = warp_id / num_reduce_warps;
        const index_t local_smem_os = local_warp_id * num_reduce_warps;

        static_for<0, thread_buf_size, 1>{}([&](auto i) {
            DataType v[num_reduce_warps];
            [[maybe_unused]] std::
                conditional_t<kProcessIndex, IndexDataType[num_reduce_warps], IndexDataType> idx_v;

            static_for<0, num_reduce_warps, 1>{}([&](auto idx) {
                v[idx] = smem_ptr[i * num_warps + local_smem_os + idx];
                if constexpr(kProcessIndex)
                {
                    idx_v[idx] = smem_indices[i * num_warps + local_smem_os + idx];
                }
            });

            static_assert(is_power_of_two_integer(num_reduce_warps),
                          "wrong! only support power of 2 reduction");

            constexpr index_t nstage = integer_log2_floor(num_reduce_warps);

            static_for<0, nstage, 1>{}([&](auto istage) {
                constexpr index_t stride = 1 << istage.value;
                static_for<0, num_reduce_warps, stride * 2>{}([&](auto idx_) {
                    constexpr index_t i0 = idx_();
                    constexpr index_t i1 = idx_ + stride;
                    if constexpr(i1 < num_reduce_warps)
                    {
                        if constexpr(kProcessIndex)
                        {
                            AccumulateWithIndex{}(reduce_func, v[i0], idx_v[i0], v[i1], idx_v[i1]);
                        }
                        else
                        {
                            Accumulate{}(reduce_func, v[i0], v[i1]);
                        }
                    }
                });
            });

            y_tensor.get_thread_buffer()(i) = v[0];
            if constexpr(kProcessIndex)
            {
                y_index_tensor.get_thread_buffer()(i) = idx_v[0];
            }
        });
    }
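
    // Worked layout example (illustrative values): with num_warps = 4 and
    // num_reduce_warps = 2, warps {0,1} and warps {2,3} form two independent reduction
    // groups. Element i of warp w is stored at smem_ptr[w + i * num_warps]; a warp in
    // group g (local_warp_id = w / 2) then reads back v[idx] from
    // smem_ptr[i * num_warps + g * 2 + idx] for idx in {0,1} and tree-reduces the pair.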

    public:
    template <typename YDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void
    operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
    {
        reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
    }

    template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
                                   YIndexDistributedTensor_& y_index_tensor,
                                   void* smem,
                                   void* smem_indices,
                                   const ReduceFunc& reduce_func)
    {
        reduce_impl<Problem::kOutputIndex>(
            y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
    }
};

// Cross-warp reduction variant that reduces the warp partials with a plain linear sweep
// instead of the staged tree reduction above.
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dCrossWarpSyncLinear // placeholder name: the original identifier could not be
                                        // recovered from the extracted source listing
{
    using Problem    = remove_cvref_t<Problem_>;
    using BlockShape = typename Problem::BlockShape;

    template <typename YDistributedTensor_>
    CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
    {
        constexpr index_t num_reduce_warps = [&]() {
            using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
            using DstrEncode       = typename Dstr::DstrEncode;
            using DstrEncodeDetail = typename DstrEncode::detail;

            constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

            constexpr index_t idim_p_warp = 0;

            index_t len_ = 1;
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
                    len_ *= r_length;
                }
            });
            return len_;
        }();
        return num_reduce_warps;
    }

    // return size in bytes
    template <typename YDistributedTensor_>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        using DataType = typename YDistributedTensor_::DataType;
        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        // we need to store all data from every wave into smem
        // e.g. 2x2 reduce along N
        //     -------------> reduce N
        //     | w0 | w1 |   ___>   | w01 |
        //     | w2 | w3 |          | w23 |
        //
        // -> store data from every wave into LDS
        //
        //
        //     -------------> reduce N
        //     | w0 | w1 | w2 | w3 |  ----->  | w0123 |
        //
        // -> also store data from every wave into LDS
        constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
        return num_warps * thread_buf_size * sizeof(DataType);
    }

    // return size in bytes - separate shared memory size calculation for indices
    template <typename YIndexDistributedTensor_>
    CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
    {
        using IndexDataType = typename YIndexDistributedTensor_::DataType;
        constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
        constexpr index_t num_warps       = BlockShape::BlockSize / get_warp_size();
        return num_warps * thread_buf_size * sizeof(IndexDataType);
    }

    private:
    template <bool kProcessIndex,
              typename YDistributedTensor_,
              typename YIndexDistributedTensor_,
              typename ReduceFunc>
    CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
                                    YIndexDistributedTensor_& y_index_tensor,
                                    void* smem,
                                    void* smem_indices_ptr,
                                    const ReduceFunc& reduce_func)
    {
        using DataType      = typename YDistributedTensor_::DataType;
        using IndexDataType = typename YIndexDistributedTensor_::DataType;

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        DataType* smem_ptr          = reinterpret_cast<DataType*>(smem);
        IndexDataType* smem_indices = nullptr;
        if constexpr(kProcessIndex)
        {
            smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
        }

        const index_t lane_id           = get_lane_id();
        const index_t warp_id           = get_warp_id();
        constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
        constexpr index_t num_warps     = BlockShape::BlockSize / get_warp_size();
        const index_t smem_offset       = warp_id;

        // skip if there is nothing to do
        if constexpr(num_reduce_warps == 1)
            return;

        // store into smem only for lane-0 within one warp
        if(lane_id == 0)
        {
            static_for<0, thread_buf_size, 1>{}([&](auto i) {
                smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
                if constexpr(kProcessIndex)
                {
                    smem_indices[smem_offset + i * num_warps] =
                        y_index_tensor.get_thread_buffer()[i];
                }
            });
        }
        block_sync_lds();

        // load from smem; here we let every thread do the compute :)
        index_t local_warp_id = warp_id / num_reduce_warps;
        index_t local_smem_os = local_warp_id * num_reduce_warps;

        DataType all_scratch[thread_buf_size * num_reduce_warps];
        [[maybe_unused]] std::conditional_t<kProcessIndex,
                                            IndexDataType[thread_buf_size * num_reduce_warps],
                                            IndexDataType> all_indices;

        // Load data from shared memory
        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
            static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
                all_scratch[i_0 * num_reduce_warps + i_1] =
                    smem_ptr[i_0 * num_warps + local_smem_os + i_1];

                if constexpr(kProcessIndex)
                {
                    all_indices[i_0 * num_reduce_warps + i_1] =
                        smem_indices[i_0 * num_warps + local_smem_os + i_1];
                }
            });
        });
        block_sync_lds(); // TODO: we don't need sync here

        // Perform reduction
        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
            // TODO: use descriptor for this
            auto v_local = all_scratch[i_0 * num_reduce_warps];

            IndexDataType idx_local{};
            if constexpr(kProcessIndex)
            {
                idx_local = all_indices[i_0 * num_reduce_warps];
            }

            // linearly fold in the partial results from the remaining warps
            static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
                constexpr auto i_1      = number<i_1_n1 + 1>{};
                const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];

                if constexpr(kProcessIndex)
                {
                    const IndexDataType idx_remote = all_indices[i_0 * num_reduce_warps + i_1];

                    bool changed = false;
                    v_local      = reduce_func(v_local, v_remote, changed);
                    if(changed)
                    {
                        idx_local = idx_remote;
                    }
                }
                else
                {
                    v_local = reduce_func(v_local, v_remote);
                }
            });

            y_tensor.get_thread_buffer()(i_0) = v_local;
            if constexpr(kProcessIndex)
            {
                y_index_tensor.get_thread_buffer()(i_0) = idx_local;
            }
        });
    }

    public:
    template <typename YDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void
    operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
    {
        reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
    }

    template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
                                   YIndexDistributedTensor_& y_index_tensor,
                                   void* smem,
                                   void* smem_indices,
                                   const ReduceFunc& reduce_func)
    {
        reduce_impl<Problem::kOutputIndex>(
            y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
    }
};

} // namespace ck_tile