/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_pool.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_pool.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_pool.hpp Source File
reference_pool.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 #include <thread>
10 #include <cmath>
11 
12 namespace ck_tile {
13 
14 template <typename InDataType,
15  typename ComputeDataType,
16  typename OutDataType,
17  typename IndexDataType,
18  typename ReduceOp,
19  typename TensorShape,
20  typename WindowShape,
21  bool OutputIndex = false>
24  HostTensor<IndexDataType>& output_index,
26  ReduceOp reduce_op)
27 {
28  const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
29  const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<1>{});
30  const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<2>{});
31  const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<3>{});
32 
33  const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<1>{});
34  const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<2>{});
35 
38 
39  const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<0>{});
40  const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<1>{});
41 
44 
45  const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<0>{});
46  const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<1>{});
47  // Right padding is handled implicitly by bounds checking
48 
49  auto f = [&](auto n, auto ho, auto wo, auto c) {
50  ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
51 
52  IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
53 
54  for(ck_tile::index_t y = 0; y < Y; ++y)
55  {
56  // Calculate input height index with stride, dilation, and padding
57  ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
58 
59  for(ck_tile::index_t x = 0; x < X; ++x)
60  {
61  // Calculate input width index with stride, dilation, and padding
62  ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
63 
64  if(hi >= 0 && hi < H && wi >= 0 && wi < W)
65  {
66  const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
67 
68  if constexpr(OutputIndex)
69  {
70  IndexDataType flat_index = input.GetOffsetFromMultiIndex(n, hi, wi, c);
71  bool changed = false;
72  v_acc = reduce_op(v_acc, v_in, changed);
73  if(changed)
74  {
75  current_index = flat_index;
76  }
77  }
78  else
79  {
80  v_acc = reduce_op(v_acc, v_in);
81  }
82  }
83  // For positions outside bounds, we implicitly use identity value
84  }
85  }
86 
87  output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
88 
89  if constexpr(OutputIndex)
90  {
91  output_index(n, ho, wo, c) = current_index;
92  }
93  };
94 
95  // Parallelize over all output dimensions
96  make_ParallelTensorFunctor(f, N, Ho, Wo, C)(std::thread::hardware_concurrency());
97 }
98 
99 template <typename InDataType,
100  typename ComputeDataType,
101  typename OutDataType,
102  typename IndexDataType,
103  typename ReduceOp,
104  typename TensorShape,
105  typename WindowShape,
106  bool OutputIndex = false>
108  HostTensor<OutDataType>& output,
109  HostTensor<IndexDataType>& output_index,
111  ReduceOp reduce_op)
112 {
113  const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
114  const ck_tile::index_t D = kargs.input_shape.at(ck_tile::number<1>{});
115  const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<2>{});
116  const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<3>{});
117  const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<4>{});
118 
119  const ck_tile::index_t Do = kargs.output_shape.at(ck_tile::number<1>{});
120  const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<2>{});
121  const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<3>{});
122 
123  const ck_tile::index_t Z = kargs.window_lengths.at(ck_tile::number<0>{});
124  const ck_tile::index_t Y = kargs.window_lengths.at(ck_tile::number<1>{});
125  const ck_tile::index_t X = kargs.window_lengths.at(ck_tile::number<2>{});
126 
127  const ck_tile::index_t Sz = kargs.window_strides.at(ck_tile::number<0>{});
128  const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<1>{});
129  const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<2>{});
130 
134 
135  const ck_tile::index_t LeftPz = kargs.input_left_pads.at(ck_tile::number<0>{});
136  const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<1>{});
137  const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<2>{});
138  // Right padding is handled implicitly by bounds checking
139 
140  auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) {
141  ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
142 
143  IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
144 
145  for(ck_tile::index_t z = 0; z < Z; ++z)
146  {
147  // Calculate input depth index with stride, dilation, and padding
148  ck_tile::index_t di = do_ * Sz + z * Dz - LeftPz;
149 
150  for(ck_tile::index_t y = 0; y < Y; ++y)
151  {
152  // Calculate input height index with stride, dilation, and padding
153  ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
154 
155  for(ck_tile::index_t x = 0; x < X; ++x)
156  {
157  // Calculate input width index with stride, dilation, and padding
158  ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
159 
160  if(di >= 0 && di < D && hi >= 0 && hi < H && wi >= 0 && wi < W)
161  {
162  const ComputeDataType v_in =
163  type_convert<ComputeDataType>(input(n, di, hi, wi, c));
164 
165  if constexpr(OutputIndex)
166  {
167  IndexDataType flat_index =
168  input.GetOffsetFromMultiIndex(n, di, hi, wi, c);
169  bool changed = false;
170  v_acc = reduce_op(v_acc, v_in, changed);
171  if(changed)
172  {
173  current_index = flat_index;
174  }
175  }
176  else
177  {
178  v_acc = reduce_op(v_acc, v_in);
179  }
180  }
181  // For positions outside bounds, we implicitly use identity value
182  }
183  }
184  }
185 
186  output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
187 
188  if constexpr(OutputIndex)
189  {
190 
191  output_index(n, do_, ho, wo, c) = current_index;
192  }
193  };
194 
195  // Parallelize over all output dimensions
196  make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
197 }
198 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
CK_TILE_HOST void reference_pool2d(const HostTensor< InDataType > &input, HostTensor< OutDataType > &output, HostTensor< IndexDataType > &output_index, PoolKernelArgs< TensorShape, WindowShape > kargs, ReduceOp reduce_op)
Definition: reference_pool.hpp:22
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST void reference_pool3d(const HostTensor< InDataType > &input, HostTensor< OutDataType > &output, HostTensor< IndexDataType > &output_index, PoolKernelArgs< TensorShape, WindowShape > kargs, ReduceOp reduce_op)
Definition: reference_pool.hpp:107
Definition: host_tensor.hpp:336
std::size_t GetOffsetFromMultiIndex(Is... is) const
Definition: host_tensor.hpp:531
Kernel arguments for pooling operations.
Definition: pool_kernel.hpp:63
TensorShape output_shape
Definition: pool_kernel.hpp:68
WindowShape window_lengths
Definition: pool_kernel.hpp:71
WindowShape window_dilations
Definition: pool_kernel.hpp:73
WindowShape input_left_pads
Definition: pool_kernel.hpp:74
TensorShape input_shape
Definition: pool_kernel.hpp:67
WindowShape window_strides
Definition: pool_kernel.hpp:72
Definition: integral_constant.hpp:13