/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/reduce/kernel/multi_reduce2d_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/reduce/kernel/multi_reduce2d_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/reduce/kernel/multi_reduce2d_kernel.hpp Source File
multi_reduce2d_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
14 
15 // Multi Reduce2d Unified Kernel:
16 // =======================================
17 // This kernel implements multiple 2D reduction operations that reduce data along the specified
18 // dimensions of a matrix. It supports both single-block (threadwise) and multi-block
19 
20 namespace ck_tile {
21 
22 template <typename Problem_,
23  typename Policy_ = Reduce2dDefaultPolicy,
24  bool ForceMultiBlock_ = false>
26 {
29 
30  static constexpr bool ForceMultiBlock = ForceMultiBlock_; // false: threadwise, true: multiblock
31 
35 
37 
38  static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
39 
40  CK_TILE_HOST static constexpr auto BlockSize()
41  {
42  return is_wave32() ? kBlockSize / 2 : kBlockSize;
43  }
44 
45  private:
46  // Helper function to calculate optimal vector size for input tensor
47  template <typename InputShape, typename ReduceDims>
48  static constexpr index_t CalculateInputVectorSize()
49  {
50  using S = typename Problem::BlockShape;
51  constexpr index_t memory_vector_size = 16 / sizeof(XDataType); // Vectorization
52  constexpr index_t thread_tile_vector_size =
53  S::ThreadTile_N; // In the continuous dimension, within the tile
54 
55  constexpr auto innermost_reduce_dim = ReduceDims{}.at(number<ReduceDims{}.size() - 1>{});
56  constexpr bool is_innermost_contiguous = (innermost_reduce_dim == InputShape{}.size() - 1);
57 
58  constexpr index_t stride_based_vector_size =
59  is_innermost_contiguous
60  ? ck_tile::min(memory_vector_size, thread_tile_vector_size)
61  : 1; // Move at "vectorization" steps if continuous otherwise 1 step
62 
63  return stride_based_vector_size;
64  }
65 
66  static constexpr index_t CalculateOutputVectorSize()
67  {
68  using S = typename Problem::BlockShape;
69  constexpr index_t memory_vector_size = 16 / sizeof(YDataType);
70  constexpr index_t thread_tile_vector_size = S::ThreadTile_M;
71  constexpr index_t vector_size = ck_tile::min(memory_vector_size, thread_tile_vector_size);
72 
73  return vector_size;
74  }
75 
76  public:
77  // Overload for threadwise version (no InterblockReduceOps parameter)
78  // This version uses the same reduce_ops for interblock reduction
79  template <typename InputShape,
80  typename InputStrides,
81  typename KeptDim,
82  typename ReduceDims,
83  typename ElementwiseOps,
84  typename AccumulatorOps>
86  YDataType* p_y_tuple,
87  InputShape input_shape,
88  InputStrides input_strides,
89  KeptDim kept_dim,
90  ReduceDims reduce_dims,
91  index_t output_tensor_offset,
92  ElementwiseOps elementwise_ops,
93  AccumulatorOps accumulator_ops) const
94  {
95  // For single-block case, use the same reduce ops for interblock reduction
96  // (though they won't be used since block_group_size will be 1)
97  auto reduce_ops = typename Problem::ReduceOp{};
98  (*this)(p_x,
99  p_y_tuple,
100  input_shape,
101  input_strides,
102  kept_dim,
103  reduce_dims,
104  output_tensor_offset,
105  elementwise_ops,
106  accumulator_ops,
107  reduce_ops); // Use reduce_ops as interblock_reduce_ops
108  }
109 
110  // Main operator overload
111  template <typename InputShape,
112  typename InputStrides,
113  typename KeptDim,
114  typename ReduceDims,
115  typename ElementwiseOps,
116  typename AccumulatorOps,
117  typename InterblockReduceOps>
119  YDataType* p_y_tuple,
120  InputShape input_shape,
121  InputStrides input_strides,
122  KeptDim kept_dim,
123  ReduceDims reduce_dims,
124  index_t output_tensor_offset,
125  ElementwiseOps elementwise_ops,
126  AccumulatorOps accumulator_ops,
127  InterblockReduceOps interblock_reduce_ops) const
128  {
129  static_assert(
130  ElementwiseOps::size() == Problem::ReduceOp::size() &&
131  AccumulatorOps::size() == Problem::ReduceOp::size() &&
132  InterblockReduceOps::size() == Problem::ReduceOp::size(),
133  "Error: All operations tuple size must match the number of reduction operations");
134 
135  using S = typename Problem::BlockShape;
136  auto reduce_ops = typename Problem::ReduceOp{};
137 
138  const auto number_operations = reduce_ops.size();
139 
140  static_assert(number_operations > 0,
141  "Error: At least one reduction operation must be specified!");
142 
143  static_assert(kept_dim.size() + reduce_dims.size() == InputShape::size(),
144  "Size of kept dimensions + reduced dimensions must equal input tensor rank");
145 
146  const auto kept_lens = [&]() {
147  return generate_tuple([&](auto I) { return input_shape.at(number<kept_dim.at(I)>{}); },
148  number<kept_dim.size()>{});
149  }();
150  const auto reduce_lens = [&]() {
151  return generate_tuple(
152  [&](auto I) { return input_shape.at(number<reduce_dims.at(I)>{}); },
153  number<reduce_dims.size()>{});
154  }();
155 
156  // Calculate total reduction length
157  int total_reduce_len = 1;
158  static_for<0, reduce_lens.size(), 1>{}(
159  [&](auto i) { total_reduce_len *= reduce_lens.at(i); });
160 
161  // Early exit for empty tensors (reduce_total_length == 0)
162  // This can happen when any dimension in reduce_lens is 0
163  if(total_reduce_len == 0)
164  {
165  return;
166  }
167 
168  const TilePartitioner partitioner{total_reduce_len};
169 
170  // Determine strategy: single-block or multi-block
171  auto [num_n_tile_iteration, block_group_size] = partitioner.GetBlockGroupParams();
172 
173  constexpr index_t output_vector_size = CalculateOutputVectorSize();
174 
175  const auto block_global_id = get_block_id(); // Hardware block id
176 
177  // Get tile indices
178  index_t block_group_id;
179  if constexpr(ForceMultiBlock)
180  {
181  const auto [tile_idx, local_idx] =
182  partitioner.GetOutputTileIndexMultiBlock(block_global_id, block_group_size);
183  block_group_id = tile_idx;
184  }
185  else
186  {
187  block_group_id = partitioner.GetOutputTileIndex(block_global_id);
188  }
189 
190  const auto kept_merge_transform =
191  make_merge_transform(kept_lens); // Dimension(s) not reduced are being flattened
192  const auto reduce_merge_transform =
193  make_merge_transform(reduce_lens); // Dimension(s) to reduce are being flattened
194 
195  const auto custom_padding_values = ck_tile::apply(
196  [](auto... args) {
197  return ck_tile::make_tuple(args.template GetIdentityValue<XDataType>()...);
198  },
199  reduce_ops); // Get the identity element for each operation
200 
201  constexpr auto x_tensor_vector_size = CalculateInputVectorSize<InputShape, ReduceDims>();
202 
203  auto desc = make_naive_tensor_descriptor(
204  input_shape, input_strides, number<x_tensor_vector_size>{}, number<1>{});
205 
206  __shared__ char smem[Policy::template GetSmemSize<Problem>()];
207 
208  auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
209  auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
210  auto block_reduce2d_cross_warp_sync =
211  Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
212 
213  auto [m_offset, n_offset] = partitioner.GetInputTileOffsets(
214  block_global_id, block_group_size, num_n_tile_iteration);
215 
217  auto buffer_view = make_buffer_view<address_space_enum::global>(
218  p_x, desc.get_element_space_size(), custom_padding_values.get(number<i>{}));
219 
220  const auto x_tensor =
221  tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
222  const auto transformed_x_tensor = pad_tensor_view(
223  transform_tensor_view(x_tensor,
224  make_tuple(kept_merge_transform, reduce_merge_transform),
225  make_tuple(kept_dim, reduce_dims),
228  sequence<0, 1>{});
229 
230  auto x_window =
231  make_tile_window(transformed_x_tensor,
233  {m_offset, n_offset},
234  Policy::template MakeXBlockTileDistribution<Problem>());
235 
236  using ComputeDataTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
237 
238  auto y_compute = block_reduce2d.template MakeYBlockTile<ComputeDataTensorType>();
239 
240  set_tile(y_compute,
241  reduce_ops.get(number<i>{}).template GetIdentityValue<ComputeDataType>());
242 
243  // Reduction loop
244  for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
245  {
246  auto x = load_tile(x_window);
247  auto x_compute = cast_tile<ComputeDataType>(x);
248 
249  tile_elementwise_inout(elementwise_ops.get(number<i>{}), x_compute, x_compute);
250  block_reduce2d(x_compute, y_compute, reduce_ops.get(number<i>{}));
251 
252  move_tile_window(x_window, {0, S::Block_N});
253  }
254 
255  block_reduce2d_sync(y_compute, reduce_ops.get(number<i>{}));
256  block_reduce2d_cross_warp_sync(
257  y_compute, static_cast<void*>(smem), reduce_ops.get(number<i>{}));
258 
259  // Determine if this thread should perform the output operation
260  // We want threads that handle the first elements in the N (reduction) dimension
261  const auto tile_dist = y_compute.get_tile_distribution();
262  const auto ps_idx = get_partition_index(tile_dist);
263  const auto rs_idx = tile_dist.calculate_rs_index_from_ps_index(ps_idx);
264 
265  // Check if this thread is responsible for the first N-dimension element
266  // In the tile distribution, dimension 1 corresponds to the N dimension
267  const bool is_first_n_thread = (rs_idx[number<1>{}] == 0);
268 
269  if(is_first_n_thread)
270  {
271  tile_elementwise_inout(accumulator_ops.get(number<i>{}), y_compute, y_compute);
272  const index_t output_offset =
273  (i * output_tensor_offset) + // operation offset
274  partitioner.GetOutputTileOffset(block_group_id); // tile offset
275  // Single-block vs multi-block output strategy
276  if constexpr(!ForceMultiBlock)
277  {
278  // Single-block case: direct store without atomics
279  auto y_tensor_view = make_naive_tensor_view<address_space_enum::global>(
280  p_y_tuple + output_offset,
281  make_tuple(S::Block_M),
282  make_tuple(1),
284  number<1>{});
285 
286  auto y_window = make_tile_window(y_tensor_view,
288  {0},
289  y_compute.get_tile_distribution());
290 
291  auto y_output = cast_tile<YDataType>(y_compute);
292  store_tile(y_window, y_output); // Direct store, no atomics
293  }
294  else
295  {
296  // Multi-block case: use atomic operations for interblock reduction
297 
298  auto y_tensor_view =
299  make_naive_tensor_view<address_space_enum::global,
300  interblock_reduce_ops.get(number<i>{}).GetAtomic()>(
301  p_y_tuple + output_offset,
302  make_tuple(S::Block_M),
303  make_tuple(1),
305  number<1>{});
306 
307  auto y_window = make_tile_window(y_tensor_view,
309  {0},
310  y_compute.get_tile_distribution());
311 
312  auto y_output = cast_tile<YDataType>(y_compute);
313  update_tile(y_window, y_output); // Atomic update
314  }
315  }
316  });
317  }
318 
334  template <typename InputStrides>
335  CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim,
336  InputStrides input_strides)
337  {
338  using S = typename Problem::BlockShape;
339 
340  if(y_continous_dim % S::ThreadTile_N != 0)
341  {
342  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
343  {
344  CK_TILE_ERROR("Total reduction size should be a multiple of ThreadTile_N!");
345  }
346  return false;
347  }
348 
349  if(input_strides.at(number<input_strides.size() - 1>{}) != 1)
350  {
351  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
352  {
354  "Input tensor's last stride must be 1 to support correct vector access!");
355  }
356  return false;
357  }
358 
359  return true;
360  }
361 };
362 
363 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
Definition: cluster_descriptor.hpp:13
constexpr decltype(auto) apply(F &&f, Tuple &&t)
Definition: tuple.hpp:526
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:274
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition: tile_elementwise.hpp:95
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition: tile_elementwise.hpp:23
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:526
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1690
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constant< v > number
Definition: integral_constant.hpp:37
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_view.hpp:486
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:24
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:206
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:36
Definition: multi_reduce2d_kernel.hpp:26
static constexpr CK_TILE_HOST auto BlockSize()
Definition: multi_reduce2d_kernel.hpp:40
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: multi_reduce2d_kernel.hpp:33
CK_TILE_DEVICE void operator()(const XDataType *p_x, YDataType *p_y_tuple, InputShape input_shape, InputStrides input_strides, KeptDim kept_dim, ReduceDims reduce_dims, index_t output_tensor_offset, ElementwiseOps elementwise_ops, AccumulatorOps accumulator_ops, InterblockReduceOps interblock_reduce_ops) const
Definition: multi_reduce2d_kernel.hpp:118
static constexpr bool ForceMultiBlock
Definition: multi_reduce2d_kernel.hpp:30
ck_tile::remove_cvref_t< typename Problem::YDataType > YDataType
Definition: multi_reduce2d_kernel.hpp:34
static constexpr index_t kBlockSize
Definition: multi_reduce2d_kernel.hpp:38
CK_TILE_DEVICE void operator()(const XDataType *p_x, YDataType *p_y_tuple, InputShape input_shape, InputStrides input_strides, KeptDim kept_dim, ReduceDims reduce_dims, index_t output_tensor_offset, ElementwiseOps elementwise_ops, AccumulatorOps accumulator_ops) const
Definition: multi_reduce2d_kernel.hpp:85
ck_tile::remove_cvref_t< typename Problem::XDataType > XDataType
Definition: multi_reduce2d_kernel.hpp:32
ck_tile::remove_cvref_t< Policy_ > Policy
Definition: multi_reduce2d_kernel.hpp:28
static CK_TILE_HOST bool IsSupportedArgument(index_t y_continous_dim, InputStrides input_strides)
Validates if the given arguments are supported by the 2D multi reduction kernel.
Definition: multi_reduce2d_kernel.hpp:335
ck_tile::remove_cvref_t< Problem_ > Problem
Definition: multi_reduce2d_kernel.hpp:27
TilePartitioner for 2D reduction operations.
Definition: multi_reduce2d_tile_partitioner.hpp:13
CK_TILE_HOST_DEVICE auto GetBlockGroupParams() const noexcept -> tuple< index_t, index_t >
Calculate the number of iterations and the number of blocks required to perform the reduction.
Definition: multi_reduce2d_tile_partitioner.hpp:53
Definition: buffer_view.hpp:35
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
#define CK_TILE_ENV(name)
Definition: env.hpp:145