Composable Kernel: include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

// Reduce2d Kernel:
// =======================================
// This kernel implements a 2D reduction operation that reduces data along the second
// dimension of a matrix. The reduction is performed in multiple hierarchical stages.
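//
// Descriptive note (summarising operator() below, not part of the original comment): the
// hierarchical stages appear to be
//   1. per-thread accumulation over successive [Block_M, Block_N] input tiles (block_reduce2d),
//   2. a reduction of the per-thread partials within each warp (block_reduce2d_sync),
//   3. a cross-warp reduction through shared memory (block_reduce2d_cross_warp_sync),
// after which the block's Block_M results are stored to the output.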

namespace ck_tile {

template <typename Problem_, typename Policy_ = Reduce2dDefaultPolicy>
struct Reduce
{
    using Problem = ck_tile::remove_cvref_t<Problem_>;
    using Policy  = ck_tile::remove_cvref_t<Policy_>;

    using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using YDataType       = ck_tile::remove_cvref_t<typename Problem::YDataType>;
    static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;

    CK_TILE_HOST static constexpr auto BlockSize()
    {
        return is_wave32() ? kBlockSize / 2 : kBlockSize;
    }
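    // Descriptive note (an interpretation, not from the original comments): on wave32 targets
    // a wavefront has 32 lanes rather than 64, so halving kBlockSize presumably keeps the
    // number of wavefronts per workgroup the same across wave sizes.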

    private:
    // Helper function to calculate optimal vector size for input tensor
    template <typename InputShape, typename ReduceDims>
    static constexpr index_t CalculateInputVectorSize()
    {
        using S = typename Problem::BlockShape;
        constexpr index_t memory_vector_size      = 16 / sizeof(XDataType);
        constexpr index_t thread_tile_vector_size = S::ThreadTile_N;

        // Check if innermost reduce dimension is the last dimension (stride 1).
        constexpr auto innermost_reduce_dim = ReduceDims{}.at(number<ReduceDims{}.size() - 1>{});
        constexpr bool is_innermost_contiguous =
            (innermost_reduce_dim == InputShape{}.size() - 1);

        // If innermost reduce dimension is not the last dim (not contiguous), limit vectorization
        constexpr index_t stride_based_vector_size =
            is_innermost_contiguous ? ck_tile::min(memory_vector_size, thread_tile_vector_size)
                                    : 1;

        return stride_based_vector_size;
    }
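    // Illustrative example (hypothetical sizes, not from the original source): for
    // XDataType = fp16, memory_vector_size is 16 / 2 = 8 elements; with ThreadTile_N = 4 and a
    // contiguous innermost reduce dimension the result is min(8, 4) = 4, otherwise 1.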

    // Helper function to calculate optimal vector size for output tensor
    static constexpr index_t CalculateOutputVectorSize()
    {
        using S = typename Problem::BlockShape;
        constexpr index_t memory_vector_size      = 16 / sizeof(YDataType);
        constexpr index_t thread_tile_vector_size = S::ThreadTile_M;
        constexpr index_t vector_size = ck_tile::min(memory_vector_size, thread_tile_vector_size);

        return vector_size;
    }
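    // Illustrative example (hypothetical sizes, not from the original source): for
    // YDataType = fp32, memory_vector_size is 16 / 4 = 4 elements; with ThreadTile_M = 2 this
    // yields min(4, 2) = 2. No contiguity check is needed here because the output view is
    // built with packed strides in operator() below.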

    public:
    template <typename InputShape, typename InputStrides, typename KeptDim, typename ReduceDims>
    CK_TILE_DEVICE void operator()(const XDataType* p_x,
                                   YDataType* p_y,
                                   InputShape input_shape,
                                   InputStrides input_strides,
                                   KeptDim kept_dim,
                                   ReduceDims reduce_dims) const
    {
        using S       = typename Problem::BlockShape;
        const auto iM = get_block_id() * S::Block_M;

        static_assert(kept_dim.size() + reduce_dims.size() == InputShape::size(),
                      "Size of kept dimensions + reduced dimensions must equal input tensor rank");

        // Extract lengths based on kept and reduced dimensions
        const auto kept_lens = [&]() {
            return generate_tuple([&](auto I) { return input_shape.at(number<kept_dim.at(I)>{}); },
                                  number<kept_dim.size()>{});
        }();
        const auto reduce_lens = [&]() {
            return generate_tuple(
                [&](auto I) { return input_shape.at(number<reduce_dims.at(I)>{}); },
                number<reduce_dims.size()>{});
        }();

        const auto kept_merge_transform   = make_merge_transform(kept_lens);
        const auto reduce_merge_transform = make_merge_transform(reduce_lens);
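        // Illustrative example (hypothetical shapes, not from the original source): for a
        // rank-3 input of shape (A, B, C) with kept_dim = sequence<0> and
        // reduce_dims = sequence<1, 2>, the two merge transforms expose the data as a 2D
        // [A, B * C] problem, which the [Block_M, Block_N] tiling below iterates over.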

        auto reduce_func = typename Problem::ReduceOp{};
        const XDataType custom_padding_value =
            type_convert<XDataType>(reduce_func.template GetIdentityValue<ComputeDataType>());

        // Calculate optimal vector size for input tensor
        constexpr auto x_tensor_vector_size = CalculateInputVectorSize<InputShape, ReduceDims>();

        // Create input tensor descriptor
        auto desc = make_naive_tensor_descriptor(
            input_shape, input_strides, number<x_tensor_vector_size>{}, number<1>{});

        // Create buffer view with custom padding value
        auto buffer_view = make_buffer_view<address_space_enum::global>(
            p_x, desc.get_element_space_size(), custom_padding_value);

        // Create tensor view with custom padding
        const auto x_tensor = tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
        const auto transformed_x_tensor =
            pad_tensor_view(transform_tensor_view(x_tensor,
                                                  make_tuple(kept_merge_transform, reduce_merge_transform),
                                                  make_tuple(kept_dim, reduce_dims),
                                                  make_tuple(sequence<0>{}, sequence<1>{})),
                            make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
                            sequence<0, 1>{});

        // Calculate strides for output tensor based on its own dimensions
        const auto kept_strides = [&]() {
            return generate_tuple(
                [&](auto I) {
                    // Calculate stride for dimension I as the product of all following dimensions
                    index_t stride = 1;
                    static_for<I + 1, kept_dim.size(), 1>{}(
                        [&](auto J) { stride *= kept_lens.at(number<J>{}); });
                    return stride;
                },
                number<kept_dim.size()>{});
        }();
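        // Illustrative example (hypothetical lengths, not from the original source): for
        // kept_lens = (A, B, C) this produces the packed row-major strides (B * C, C, 1), so
        // the output is written densely even when the input strides are arbitrary.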

        // Calculate optimal vector size for output tensor
        constexpr auto y_tensor_vector_size = CalculateOutputVectorSize();

        const auto y_m = make_naive_tensor_view<address_space_enum::global>(
            p_y, kept_lens, kept_strides, number<y_tensor_vector_size>{}, number<1>{});

        // Transform output tensor to a 1D merged view
        // This creates a view compatible with the 2D reduction pattern
        const auto y_merged = transform_tensor_view(
            y_m,
            make_tuple(kept_merge_transform),
            make_tuple(typename arithmetic_sequence_gen<0, kept_dim.size(), 1>::type{}),
            make_tuple(sequence<0>{}));

        auto x_window = make_tile_window(transformed_x_tensor,
                                         make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
                                         {iM, 0},
                                         Policy::template MakeXBlockTileDistribution<Problem>());

        auto y_window = make_tile_window(y_merged, make_tuple(number<S::Block_M>{}), {iM});
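        // Descriptive note (not from the original comments): x_window starts at row block iM
        // and is advanced along the merged reduce dimension in Block_N steps by the loop
        // below; y_window covers the Block_M output elements owned by this workgroup.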

        __shared__ char smem[Policy::template GetSmemSize<Problem>()];

        // Get the merged dimension size from the transformed tensor
        const auto merged_reduce_len =
            transformed_x_tensor.get_tensor_descriptor().get_lengths().at(number<1>{});
        index_t num_n_tile_iteration =
            amd_wave_read_first_lane(integer_divide_ceil(merged_reduce_len, S::Block_N));
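        // Descriptive note (an interpretation, not from the original comments):
        // amd_wave_read_first_lane broadcasts lane 0's value across the wave, presumably so
        // the compiler can treat the trip count as wave-uniform.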

        auto block_reduce2d      = Policy::template GetBlockReduce2d<Problem>();
        auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
        auto block_reduce2d_cross_warp_sync =
            Policy::template GetBlockReduce2dCrossWarpSync<Problem>();

        using XTensorType = decltype(load_tile(x_window));
        auto y_compute    = block_reduce2d.template MakeYBlockTile<XTensorType>();
        set_tile(y_compute, reduce_func.template GetIdentityValue<ComputeDataType>());

        for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN)
        {
            const auto x = load_tile(x_window);
            block_reduce2d(x, y_compute, reduce_func);
            move_tile_window(x_window, {0, S::Block_N});
        }

        block_reduce2d_sync(y_compute, reduce_func);
        block_reduce2d_cross_warp_sync(y_compute, smem, reduce_func);

        store_tile(y_window, cast_tile<YDataType>(y_compute));
    }

    /// @brief Validates if the given arguments are supported by the 2D reduction kernel.
    template <typename InputStrides>
    CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim,
                                                 InputStrides input_strides)
    {
        using S = typename Problem::BlockShape;

        if(y_continous_dim % S::ThreadTile_N != 0)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("Total reduction size should be a multiple of ThreadTile_N!");
            }
            return false;
        }

        if(input_strides.at(number<input_strides.size() - 1>{}) != 1)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR(
                    "Input tensor's last stride must be 1 to support correct vector access!");
            }
            return false;
        }

        return true;
    }
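    // Illustrative host-side use (a sketch with hypothetical names, not from the original
    // source): a caller would typically gate the launch on this check, e.g.
    //   if(!Reduce<MyProblem>::IsSupportedArgument(total_reduce_len, input_strides))
    //   { /* fall back to a non-vectorized path */ }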
};

} // namespace ck_tile