Composable Kernel: include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp — Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

// Reduce2d Kernel:
// =======================================
// This kernel implements a 2D reduction operation that reduces data along the second dimension
// of a matrix. The reduction is performed in multiple hierarchical stages.
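//
// Example (illustrative, not part of the original source): with kept_dim = {0, 1}
// and reduce_dims = {2}, an input x of shape [M0, M1, N] is viewed as a 2D
// [M0*M1, N] problem and reduced to
//
//     y[m0, m1] = reduce_n x[m0, m1, n]
//
// e.g. a sum over the last axis when the ReduceOp is addition.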

namespace ck_tile {

template <typename Problem_, typename Policy_ = Reduce2dDefaultPolicy>
struct Reduce
{
    using Problem = ck_tile::remove_cvref_t<Problem_>;
    using Policy  = ck_tile::remove_cvref_t<Policy_>;

    using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using YDataType       = ck_tile::remove_cvref_t<typename Problem::YDataType>;

    static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;

    private:
    // Helper function to calculate optimal vector size for input tensor
    template <typename InputShape, typename ReduceDims>
    static constexpr index_t CalculateInputVectorSize()
    {
        using S = typename Problem::BlockShape;
        constexpr index_t memory_vector_size      = 16 / sizeof(XDataType);
        constexpr index_t thread_tile_vector_size = S::ThreadTile_N;

        // Check if innermost reduce dimension is the last dimension (stride 1).
        constexpr auto innermost_reduce_dim = ReduceDims{}.at(number<ReduceDims{}.size() - 1>{});
        constexpr bool is_innermost_contiguous =
            (innermost_reduce_dim == InputShape{}.size() - 1);

        // If innermost reduce dimension is not the last dim (not contiguous), limit vectorization
        constexpr index_t stride_based_vector_size =
            is_innermost_contiguous ? ck_tile::min(memory_vector_size, thread_tile_vector_size)
                                    : 1;

        return stride_based_vector_size;
    }
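
    // Example (illustrative): with XDataType = fp16, memory_vector_size = 16 / 2 = 8
    // elements per 16-byte access; if ThreadTile_N = 4 the result is min(8, 4) = 4,
    // and it falls back to 1 whenever the innermost reduce dimension is not the last
    // (contiguous) input dimension.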

    // Helper function to calculate optimal vector size for output tensor
    static constexpr index_t CalculateOutputVectorSize()
    {
        using S = typename Problem::BlockShape;
        constexpr index_t memory_vector_size      = 16 / sizeof(YDataType);
        constexpr index_t thread_tile_vector_size = S::ThreadTile_M;
        constexpr index_t vector_size = ck_tile::min(memory_vector_size, thread_tile_vector_size);

        return vector_size;
    }

    public:
    template <typename InputShape, typename InputStrides, typename KeptDim, typename ReduceDims>
    CK_TILE_DEVICE void operator()(const XDataType* p_x,
                                   YDataType* p_y,
                                   InputShape input_shape,
                                   InputStrides input_strides,
                                   KeptDim kept_dim,
                                   ReduceDims reduce_dims) const
    {
        using S       = typename Problem::BlockShape;
        const auto iM = get_block_id() * S::Block_M;

        static_assert(kept_dim.size() + reduce_dims.size() == InputShape::size(),
                      "Size of kept dimensions + reduced dimensions must equal input tensor rank");

        // Extract lengths based on kept and reduced dimensions
        const auto kept_lens = [&]() {
            return generate_tuple([&](auto I) { return input_shape.at(number<kept_dim.at(I)>{}); },
                                  number<kept_dim.size()>{});
        }();
        const auto reduce_lens = [&]() {
            return generate_tuple(
                [&](auto I) { return input_shape.at(number<reduce_dims.at(I)>{}); },
                number<reduce_dims.size()>{});
        }();

        const auto kept_merge_transform   = make_merge_transform(kept_lens);
        const auto reduce_merge_transform = make_merge_transform(reduce_lens);

        auto reduce_func = typename Problem::ReduceOp{};
        const XDataType custom_padding_value =
            type_convert<XDataType>(reduce_func.template GetIdentityValue<ComputeDataType>());

        // Calculate optimal vector size for input tensor
        constexpr auto x_tensor_vector_size = CalculateInputVectorSize<InputShape, ReduceDims>();

        // Create input tensor descriptor
        auto desc = make_naive_tensor_descriptor(
            input_shape, input_strides, number<x_tensor_vector_size>{}, number<1>{});

        // Create buffer view with custom padding value
        auto buffer_view = make_buffer_view<address_space_enum::global>(
            p_x, desc.get_element_space_size(), custom_padding_value);

        // Create tensor view with custom padding
        const auto x_tensor = tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
        const auto transformed_x_tensor = pad_tensor_view(
            transform_tensor_view(x_tensor,
                                  make_tuple(kept_merge_transform, reduce_merge_transform),
                                  make_tuple(kept_dim, reduce_dims),
                                  make_tuple(sequence<0>{}, sequence<1>{})),
            make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
            sequence<0, 1>{});

        // Calculate strides for output tensor based on its own dimensions
        const auto kept_strides = [&]() {
            return generate_tuple(
                [&](auto I) {
                    // Calculate stride for dimension I as product of all following dimensions
                    index_t stride = 1;
                    static_for<I + 1, kept_dim.size(), 1>{}(
                        [&](auto J) { stride *= kept_lens.at(number<J>{}); });
                    return stride;
                },
                number<kept_dim.size()>{});
        }();
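
        // Example (illustrative): kept_lens = (2, 3, 4) yields packed row-major
        // strides (12, 4, 1), i.e. stride[I] is the product of kept_lens[I+1..].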

        // Calculate optimal vector size for output tensor
        constexpr auto y_tensor_vector_size = CalculateOutputVectorSize();

        const auto y_m = make_naive_tensor_view<address_space_enum::global>(
            p_y, kept_lens, kept_strides, number<y_tensor_vector_size>{}, number<1>{});

        // Transform output tensor to 1D merged view
        // This creates a view compatible with the 2D reduction pattern
        const auto y_merged = transform_tensor_view(
            y_m,
            make_tuple(kept_merge_transform),
            make_tuple(typename arithmetic_sequence_gen<0, kept_dim.size(), 1>::type{}),
            make_tuple(sequence<0>{}));

        auto x_window = make_tile_window(transformed_x_tensor,
                                         make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
                                         {iM, 0},
                                         Policy::template MakeXBlockTileDistribution<Problem>());

        auto y_window = make_tile_window(y_merged, make_tuple(number<S::Block_M>{}), {iM});

        __shared__ char smem[Policy::template GetSmemSize<Problem>()];

        // Get the merged dimension size from the transformed tensor
        const auto merged_reduce_len =
            transformed_x_tensor.get_tensor_descriptor().get_lengths().at(number<1>{});
        index_t num_n_tile_iteration =
            __builtin_amdgcn_readfirstlane(integer_divide_ceil(merged_reduce_len, S::Block_N));

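        // Reduction pipeline (summary of the calls below): each thread first
        // accumulates its slice of every N tile into y_compute in registers;
        // block_reduce2d_sync then combines the partials held by threads of the
        // same warp; block_reduce2d_cross_warp_sync finally merges the per-warp
        // results through the shared-memory buffer smem.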
        auto block_reduce2d      = Policy::template GetBlockReduce2d<Problem>();
        auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
        auto block_reduce2d_cross_warp_sync =
            Policy::template GetBlockReduce2dCrossWarpSync<Problem>();

        using XTensorType = decltype(load_tile(x_window));
        auto y_compute    = block_reduce2d.template MakeYBlockTile<XTensorType>();
        set_tile(y_compute, reduce_func.template GetIdentityValue<ComputeDataType>());

        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
        {
            const auto x = load_tile(x_window);
            block_reduce2d(x, y_compute, reduce_func);
            move_tile_window(x_window, {0, S::Block_N});
        }

        block_reduce2d_sync(y_compute, reduce_func);
        block_reduce2d_cross_warp_sync(y_compute, smem, reduce_func);

        store_tile(y_window, cast_tile<YDataType>(y_compute));
    }

    /// @brief Validates if the given arguments are supported by the 2D reduction kernel.
    template <typename InputStrides>
    CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim,
                                                 InputStrides input_strides)
    {
        using S = typename Problem::BlockShape;

        if(y_continous_dim % S::ThreadTile_N != 0)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("Total reduction size should be a multiple of ThreadTile_N!");
            }
            return false;
        }

        if(input_strides.at(number<input_strides.size() - 1>{}) != 1)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR(
                    "Input tensor's last stride must be 1 to support correct vector access!");
            }
            return false;
        }

        return true;
    }
};

} // namespace ck_tile
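
Usage sketch (not part of the header above): a minimal example of how this functor
might be launched for a row-major [M, N] sum over dimension 1. The problem type
"MyReduceProblem", the umbrella include "ck_tile/ops/reduce.hpp", and the launch
configuration are illustrative assumptions; real problem and policy definitions
are supplied alongside Reduce2dDefaultPolicy and the ck_tile reduce example.

#include <hip/hip_runtime.h>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"

// Problem is assumed to supply XDataType/ComputeDataType/YDataType, ReduceOp and
// BlockShape, as required by ck_tile::Reduce above.
template <typename Problem>
__global__ void reduce2d_entry(const typename Problem::XDataType* p_x,
                               typename Problem::YDataType* p_y,
                               ck_tile::index_t M,
                               ck_tile::index_t N)
{
    // kept_dim = {0}, reduce_dims = {1}: y[m] = reduce_n x[m, n]
    ck_tile::Reduce<Problem>{}(p_x,
                               p_y,
                               ck_tile::make_tuple(M, N), // input_shape
                               ck_tile::make_tuple(N, 1), // input_strides (row-major)
                               ck_tile::sequence<0>{},    // kept_dim
                               ck_tile::sequence<1>{});   // reduce_dims
}

// Host side: one workgroup per Block_M rows, kBlockSize threads each, matching
// iM = get_block_id() * S::Block_M in operator() (grid math assumed):
//
//   dim3 grid(ck_tile::integer_divide_ceil(M, Block_M));
//   reduce2d_entry<MyReduceProblem>
//       <<<grid, ck_tile::Reduce<MyReduceProblem>::kBlockSize>>>(x_dev, y_dev, M, N);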