include/ck_tile/ops/reduce/block/block_reduce2d.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename Problem_, typename Policy_ = void>
struct BlockReduce2d
{
    // in-thread reduction
    using Problem         = remove_cvref_t<Problem_>;
    using XDataType       = typename Problem::XDataType;
    using ComputeDataType = typename Problem::ComputeDataType;

    CK_TILE_DEVICE constexpr BlockReduce2d() {}

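    // In-thread reduction: fold the elements each thread owns along the reduced (2nd)
    // axis into y_tensor with reduce_func. ReducePacksPerXDim controls how the sweep
    // packs x elements per reduce_func call (default: one per X dimension).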
    template <typename XDistributedTensor_,
              typename YDistributedTensor_,
              typename ReduceFunc,
              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
                                   YDistributedTensor_& y_tensor,
                                   const ReduceFunc& reduce_func,
                                   ReducePacksPerXDim = {})
    {
        sweep_tile<XDistributedTensor_>(
            [&](auto... idx_) {
                constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
                y_tensor(idx_0) = reduce_func(
                    y_tensor(idx_0), ck_tile::type_convert<ComputeDataType>(x_tensor[idx_])...);
            },
            ReducePacksPerXDim{});
#if 0
        constexpr auto I0    = number<0>{};
        constexpr auto I1    = number<1>{};
        constexpr auto spans = XDistributedTensor_::get_distributed_spans();

        // FIXME: hard coded to reduce 2nd axis
        sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
            constexpr auto y_dstr_idx = make_tuple(dstr_idx_i0);

            auto y = y_tensor[y_dstr_idx];

            sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
                constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
                const auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);

                y = reduce_func(y, x);
            });

            y_tensor(y_dstr_idx) = y;
        });
#endif
    }

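    // Create the Y (output) distributed tensor: its distribution is derived from the X
    // distribution by removing the reduced dimension (currently hard coded to axis 1).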
    template <typename XDistributedTensor_>
    CK_TILE_DEVICE static auto MakeYBlockTile()
    {
        static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");

        // FIXME: hard coded to reduce 2nd axis
        constexpr auto reduce_dims = sequence<1>{};

        constexpr auto dstr =
            make_static_tile_distribution(make_reduce_tile_distribution_encoding(
                XDistributedTensor_::get_tile_distribution()
                    .get_static_tile_distribution_encoding(),
                reduce_dims));

        auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);

        return tensor;
    }

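    // Convenience overload: build the Y tile, fill it with reduce_init, run the
    // in-thread reduction, and return the per-thread partial results.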
    template <typename XDistributedTensor_,
              typename ReduceFunc,
              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
    CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor,
                                   const ComputeDataType& reduce_init,
                                   const ReduceFunc& reduce_func,
                                   ReducePacksPerXDim = {})
    {
        auto y_tensor = MakeYBlockTile<XDistributedTensor_>();
        set_tile(y_tensor, reduce_init);
        (*this)(x_tensor, y_tensor, reduce_func, ReducePacksPerXDim{});

        return y_tensor;
    }
};

template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dSync
{
    using Problem = remove_cvref_t<Problem_>;

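    // Intra-warp reduction: for every replication (R) dimension owned by the lane-id
    // P dimension, exchange partial results between lanes with XOR (butterfly) warp
    // shuffles and combine them with reduce_func, so every lane ends up holding the
    // warp-level result.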
    template <typename YDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
    {
        using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
        using DstrEncode       = typename Dstr::DstrEncode;
        using DstrEncodeDetail = typename DstrEncode::detail;

        constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
        constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

        constexpr index_t idim_p_lane = NDimP - 1;

        // const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
        // const auto rs_idx =
        //     y_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        // loop over thread data
        static_for<0, thread_buf_size, 1>{}([&](auto i) {
            auto v_local = y_tensor.get_thread_buffer()[i];

            // cross-lane reduce for replication
            // only reduce on the R dimension that corresponds to the lane
            // (lane id maps to this R dimension)
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                // FIXME: nasty to use does_p_own_r_
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                    constexpr index_t lid_over_rid_derivative =
                        DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];

                    static_assert(is_power_of_two_integer(r_length),
                                  "wrong! only support power of 2 reduction");

                    constexpr index_t nstage = integer_log2_floor(r_length);

                    // reduction sweep forward
                    static_for<0, nstage, 1>{}([&](auto istage) {
                        // xor
                        index_t src_lane =
                            (__lane_id()) ^
                            (number<lid_over_rid_derivative << istage.value>{}.value);

                        // pull data from remote lane
                        const auto v_remote = warp_shuffle(v_local, src_lane);

                        // reduce
                        v_local = reduce_func(v_local, v_remote);
                    });
                }
            });

            // TODO - Do we need to broadcast to other lane?
            y_tensor.get_thread_buffer()(i) = v_local;
        });
    }
};

template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dCrossWarpSync
{
    using Problem    = remove_cvref_t<Problem_>;
    using BlockShape = typename Problem::BlockShape;

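    // number of warps that hold partial results of the same reduction, i.e. the product
    // of the R lengths owned by the warp-id P dimension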
    template <typename YDistributedTensor_>
    CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
    {
        constexpr index_t num_reduce_warps = [&]() {
            using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
            using DstrEncode       = typename Dstr::DstrEncode;
            using DstrEncodeDetail = typename DstrEncode::detail;

            constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

            constexpr index_t idim_p_warp = 0;

            index_t len_ = 1;
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
                    len_ *= r_length;
                }
            });
            return len_;
        }();
        return num_reduce_warps;
    }

    // returns the smem size in bytes
    template <typename YDistributedTensor_>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        using DataType = typename YDistributedTensor_::DataType;
        // constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        // we need to store all data from every wave into smem
        // e.g. 2x2 waves, reduce along N
        // -------------> reduce N
        // | w0 | w1 |   --->   | w01 |
        // | w2 | w3 |          | w23 |
        //
        // -> store data from every wave into LDS
        //
        // e.g. 1x4 waves, reduce along N
        // -------------> reduce N
        // | w0 | w1 | w2 | w3 |   --->   | w0123 |
        //
        // -> also store data from every wave into LDS
        constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
        return num_warps * thread_buf_size * sizeof(DataType);
    }

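    // Cross-warp reduction: lane 0 of every warp spills its partial results to LDS, then
    // every thread reads back the partials of the warps in its reduction group and
    // combines them with reduce_func. smem must provide at least
    // GetSmemSize<YDistributedTensor_>() bytes.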
    template <typename YDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void
    operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
    {
        using DataType = typename YDistributedTensor_::DataType;

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        DataType* smem_ptr              = reinterpret_cast<DataType*>(smem);
        const index_t lane_id           = get_lane_id();
        const index_t warp_id           = get_warp_id();
        constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
        constexpr index_t num_warps     = BlockShape::BlockSize / warpSize;
        const index_t smem_offset       = warp_id;

        // skip if there is nothing to do
        if constexpr(num_reduce_warps == 1)
            return;

        // store into smem only for lane-0 within one warp
        if(lane_id == 0)
        {
            static_for<0, thread_buf_size, 1>{}([&](auto i) {
                smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
            });
        }
        block_sync_lds();

        // load from smem. here we let every thread do the compute :)
        index_t local_warp_id = warp_id / num_reduce_warps;
        index_t local_smem_os = local_warp_id * num_reduce_warps;
        DataType all_scratch[thread_buf_size * num_reduce_warps];
        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
            static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
                all_scratch[i_0 * num_reduce_warps + i_1] =
                    smem_ptr[i_0 * num_warps + local_smem_os + i_1];
            });
        });
        block_sync_lds(); // TODO: we don't need sync here

        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
            // TODO: use descriptor for this
            auto v_local = all_scratch[i_0 * num_reduce_warps];

            // further reduce the partial results from the other warps
            static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
                constexpr auto i_1      = number<i_1_n1 + 1>{};
                const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];

                // reduce
                v_local = reduce_func(v_local, v_remote);
            });

            y_tensor.get_thread_buffer()(i_0) = v_local;
        });
    }
};

} // namespace ck_tile
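
The three structs above form a three-stage pipeline: per-thread partial reduction (BlockReduce2d), intra-warp combination across lanes (BlockReduce2dSync), and cross-warp combination through LDS (BlockReduce2dCrossWarpSync). Below is a minimal usage sketch; the function name, the assumed Problem type, and the max functor are illustrative and not part of this header.

// Hedged usage sketch (not part of this header). Assumptions: `Problem` provides
// XDataType/ComputeDataType/BlockShape, `x_tensor` is a 2D distributed tensor
// (e.g. produced by load_tile), and `smem` points to LDS of at least
// BlockReduce2dCrossWarpSync<Problem>::GetSmemSize<...>() bytes.
template <typename Problem, typename XDistributedTensor>
CK_TILE_DEVICE auto block_row_reduce_sketch(const XDistributedTensor& x_tensor,
                                            typename Problem::ComputeDataType reduce_init,
                                            void* smem)
{
    ck_tile::BlockReduce2d<Problem>              block_reduce2d;
    ck_tile::BlockReduce2dSync<Problem>          block_reduce2d_sync;
    ck_tile::BlockReduce2dCrossWarpSync<Problem> block_reduce2d_cross_warp_sync;

    // example reduce op: max
    const auto f_max = [](auto a, auto b) { return a > b ? a : b; };

    // 1) in-thread reduction along the 2nd axis -> per-thread partial results
    auto y_tensor = block_reduce2d(x_tensor, reduce_init, f_max);

    // 2) combine partial results across lanes of the same warp (warp shuffles)
    block_reduce2d_sync(y_tensor, f_max);

    // 3) combine partial results across warps (through LDS)
    block_reduce2d_cross_warp_sync(y_tensor, smem, f_max);

    return y_tensor; // block-wide reduction result, one value per kept row
}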