/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp Source File
naive_grouped_conv_fwd_gpu.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include "ck_tile/core.hpp"

#include <array>
#include <vector>

#include <hip/hip_runtime.h>
12 
13 namespace ck_tile {
14 
15 // Naive GPU reference kernel struct for forward grouped convolution
16 // Layout: Input=NDHWGC, Weight=GKZYXC, Output=NDHWGK (for 3D case)
17 // Input=NHWGC, Weight=GKYXC, Output=NHWGK (for 2D case)
18 // Input=NWGC, Weight=GKXC, Output=NWGK (for 1D case)
19 //
20 // One thread per output element, uses grid-stride loop pattern
21 
22 template <ck_tile::index_t NDimSpatial,
23  typename InDataType,
24  typename WeiDataType,
25  typename OutDataType>
27 {
28  static constexpr ck_tile::index_t kBlockSize = 256;
29 
30  __device__ void
31  operator()(const InDataType* __restrict__ p_in,
32  const WeiDataType* __restrict__ p_wei,
33  OutDataType* __restrict__ p_out,
34  // Tensor dimensions
35  ck_tile::index_t G, // number of groups
36  ck_tile::index_t N, // batch size
37  ck_tile::index_t K, // output channels per group
38  ck_tile::index_t C, // input channels per group
39  // Input spatial dimensions
40  const std::array<ck_tile::long_index_t, NDimSpatial>& in_spatial_lengths,
41  // Weight spatial dimensions
42  const std::array<ck_tile::long_index_t, NDimSpatial>& wei_spatial_lengths,
43  // Output spatial dimensions
44  const std::array<ck_tile::long_index_t, NDimSpatial>& out_spatial_lengths,
45  // Convolution parameters
46  const std::array<ck_tile::long_index_t, NDimSpatial>& conv_strides,
47  const std::array<ck_tile::long_index_t, NDimSpatial>& conv_dilations,
48  const std::array<ck_tile::long_index_t, NDimSpatial>& in_left_pads) const
49  {
50  const ck_tile::long_index_t tid = get_block_id() * blockDim.x + get_thread_id();
51  const ck_tile::long_index_t num_threads = blockDim.x * gridDim.x;
52 
53  // Calculate total output elements
54  ck_tile::long_index_t output_length = G * N * K;
55  for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
56  {
57  output_length *= out_spatial_lengths[i];
58  }
59 
60  // Calculate strides for output tensor (NDHWGK or NHWGK or NWGK)
61  std::array<ck_tile::long_index_t, NDimSpatial + 3> out_strides; // N, spatial dims, G, K
62  ck_tile::long_index_t stride = 1;
63  out_strides[NDimSpatial + 2] = stride; // K stride
64  stride *= K;
65  out_strides[NDimSpatial + 1] = stride; // G stride
66  stride *= G;
67  for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i) // Spatial strides (reversed)
68  {
69  out_strides[i + 1] = stride;
70  stride *= out_spatial_lengths[i];
71  }
72  out_strides[0] = stride; // N stride
73 
74  // Calculate strides for input tensor (NDHWGC or NHWGC or NWGC)
75  std::array<ck_tile::long_index_t, NDimSpatial + 3> in_strides;
76  stride = 1;
77  in_strides[NDimSpatial + 2] = stride; // C stride
78  stride *= C;
79  in_strides[NDimSpatial + 1] = stride; // G stride
80  stride *= G;
81  for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
82  {
83  in_strides[i + 1] = stride;
84  stride *= in_spatial_lengths[i];
85  }
86  in_strides[0] = stride; // N stride
87 
88  // Calculate strides for weight tensor (GKZYXC or GKYXC or GKXC)
89  std::array<ck_tile::long_index_t, NDimSpatial + 3> wei_strides;
90  stride = 1;
91  wei_strides[NDimSpatial + 2] = stride; // C stride
92  stride *= C;
93  for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
94  {
95  wei_strides[i + 2] = stride;
96  stride *= wei_spatial_lengths[i];
97  }
98  wei_strides[1] = stride; // K stride
99  stride *= K;
100  wei_strides[0] = stride; // G stride
101 
102  // Grid-stride loop over all output elements
103  for(ck_tile::long_index_t ii = tid; ii < output_length; ii += num_threads)
104  {
105  // Decode linear index to multi-dimensional indices
106  ck_tile::long_index_t tmp = ii;
107 
108  // Extract N (batch)
109  ck_tile::index_t n = tmp / out_strides[0];
110  tmp -= n * out_strides[0];
111 
112  // Extract spatial dimensions (D, H, W)
113  ck_tile::index_t out_spatial_idx[6]; // Max 6 spatial dimensions
114  for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
115  {
116  out_spatial_idx[i] = tmp / out_strides[i + 1];
117  tmp -= out_spatial_idx[i] * out_strides[i + 1];
118  }
119 
120  // Extract G (group)
121  ck_tile::index_t g = tmp / out_strides[NDimSpatial + 1];
122  tmp -= g * out_strides[NDimSpatial + 1];
123 
124  // Extract K (output channel)
125  ck_tile::index_t k = tmp;
126 
127  // Accumulate in float
128  float v_acc = 0.0f;
129 
130  // Loop over input channels
131  for(ck_tile::index_t c = 0; c < C; ++c)
132  {
133  // Loop over filter spatial dimensions
134  if constexpr(NDimSpatial == 1)
135  {
136  for(ck_tile::index_t x = 0; x < wei_spatial_lengths[0]; ++x)
137  {
138  // Calculate input spatial coordinate
140  static_cast<ck_tile::long_index_t>(out_spatial_idx[0] *
141  conv_strides[0]) +
142  static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
143  static_cast<ck_tile::long_index_t>(in_left_pads[0]);
144 
145  // Bounds check
146  if(wi >= 0 && wi < in_spatial_lengths[0])
147  {
148  std::array<ck_tile::index_t, 1> in_spatial = {static_cast<index_t>(wi)};
149  std::array<ck_tile::index_t, 1> wei_spatial = {x};
150  ck_tile::long_index_t in_idx =
151  detail::calculate_input_index<1>(n, g, c, in_spatial, in_strides);
152  ck_tile::long_index_t wei_idx = detail::calculate_weight_index<1>(
153  g, k, c, wei_spatial, wei_strides);
154 
155  v_acc += type_convert<float>(p_in[in_idx]) *
156  type_convert<float>(p_wei[wei_idx]);
157  }
158  }
159  }
160  else if constexpr(NDimSpatial == 2)
161  {
162  for(ck_tile::index_t y = 0; y < wei_spatial_lengths[0]; ++y)
163  {
165  static_cast<ck_tile::long_index_t>(out_spatial_idx[0] *
166  conv_strides[0]) +
167  static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
168  static_cast<ck_tile::long_index_t>(in_left_pads[0]);
169 
170  for(ck_tile::index_t x = 0; x < wei_spatial_lengths[1]; ++x)
171  {
173  static_cast<ck_tile::long_index_t>(out_spatial_idx[1] *
174  conv_strides[1]) +
175  static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
176  static_cast<ck_tile::long_index_t>(in_left_pads[1]);
177 
178  // Bounds check
179  if(hi >= 0 && hi < in_spatial_lengths[0] && wi >= 0 &&
180  wi < in_spatial_lengths[1])
181  {
182  std::array<ck_tile::index_t, 2> in_spatial = {
183  static_cast<index_t>(hi), static_cast<index_t>(wi)};
184  std::array<ck_tile::index_t, 2> wei_spatial = {y, x};
185  ck_tile::long_index_t in_idx = detail::calculate_input_index<2>(
186  n, g, c, in_spatial, in_strides);
187  ck_tile::long_index_t wei_idx = detail::calculate_weight_index<2>(
188  g, k, c, wei_spatial, wei_strides);
189 
190  v_acc += type_convert<float>(p_in[in_idx]) *
191  type_convert<float>(p_wei[wei_idx]);
192  }
193  }
194  }
195  }
196  else if constexpr(NDimSpatial == 3)
197  {
198  for(ck_tile::index_t z = 0; z < wei_spatial_lengths[0]; ++z)
199  {
201  static_cast<ck_tile::long_index_t>(out_spatial_idx[0] *
202  conv_strides[0]) +
203  static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
204  static_cast<ck_tile::long_index_t>(in_left_pads[0]);
205 
206  for(ck_tile::index_t y = 0; y < wei_spatial_lengths[1]; ++y)
207  {
209  static_cast<ck_tile::long_index_t>(out_spatial_idx[1] *
210  conv_strides[1]) +
211  static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
212  static_cast<ck_tile::long_index_t>(in_left_pads[1]);
213 
214  for(ck_tile::index_t x = 0; x < wei_spatial_lengths[2]; ++x)
215  {
217  static_cast<ck_tile::long_index_t>(out_spatial_idx[2] *
218  conv_strides[2]) +
219  static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
220  static_cast<ck_tile::long_index_t>(in_left_pads[2]);
221 
222  // Bounds check
223  if(di >= 0 && di < in_spatial_lengths[0] && hi >= 0 &&
224  hi < in_spatial_lengths[1] && wi >= 0 &&
225  wi < in_spatial_lengths[2])
226  {
227  std::array<ck_tile::index_t, 3> in_spatial = {
228  static_cast<index_t>(di),
229  static_cast<index_t>(hi),
230  static_cast<index_t>(wi)};
231  std::array<ck_tile::index_t, 3> wei_spatial = {z, y, x};
232  ck_tile::long_index_t in_idx = detail::calculate_input_index<3>(
233  n, g, c, in_spatial, in_strides);
234  ck_tile::long_index_t wei_idx =
235  detail::calculate_weight_index<3>(
236  g, k, c, wei_spatial, wei_strides);
237 
238  v_acc += type_convert<float>(p_in[in_idx]) *
239  type_convert<float>(p_wei[wei_idx]);
240  }
241  }
242  }
243  }
244  }
245  }
246 
247  // Convert accumulator to output type and write
248  p_out[ii] = type_convert<OutDataType>(v_acc);
249  }
250  }
251 };
252 
253 // Host-side launcher for naive grouped convolution forward
254 template <ck_tile::index_t NDimSpatial,
255  typename InDataType,
256  typename WeiDataType,
257  typename OutDataType>
258 CK_TILE_HOST float naive_grouped_conv_fwd(const InDataType* p_in_dev,
259  const WeiDataType* p_wei_dev,
260  OutDataType* p_out_dev,
265  std::vector<ck_tile::long_index_t> in_spatial_lengths,
266  std::vector<ck_tile::long_index_t> wei_spatial_lengths,
267  std::vector<ck_tile::long_index_t> out_spatial_lengths,
268  std::vector<ck_tile::long_index_t> conv_strides,
269  std::vector<ck_tile::long_index_t> conv_dilations,
270  std::vector<ck_tile::long_index_t> in_left_pads,
272 {
273  // Convert vectors to arrays (std::array can be passed by value to kernel)
274  auto in_spatial_arr = to_array_with_default<NDimSpatial>(in_spatial_lengths);
275  auto wei_spatial_arr = to_array_with_default<NDimSpatial>(wei_spatial_lengths);
276  auto out_spatial_arr = to_array_with_default<NDimSpatial>(out_spatial_lengths);
277  auto conv_strides_arr = to_array_with_default<NDimSpatial>(conv_strides);
278  auto conv_dilations_arr = to_array_with_default<NDimSpatial>(conv_dilations);
279  auto in_left_pads_arr = to_array_with_default<NDimSpatial>(in_left_pads, 0);
280 
281  // Calculate grid size
282  ck_tile::long_index_t output_length = G * N * K;
283  for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
284  {
285  output_length *= out_spatial_lengths[i];
286  }
287 
288  using KernelType =
289  naive_grouped_conv_fwd_kernel<NDimSpatial, InDataType, WeiDataType, OutDataType>;
290 
291  constexpr ck_tile::index_t block_size = KernelType::kBlockSize;
292  const ck_tile::index_t grid_size = (output_length + block_size - 1) / block_size;
293 
294  // Launch kernel
295  float elapsed_ms = launch_kernel(stream_config,
296  make_kernel(KernelType{},
297  dim3(grid_size),
298  dim3(block_size),
299  0, // dynamic shared memory size
300  p_in_dev,
301  p_wei_dev,
302  p_out_dev,
303  G,
304  N,
305  K,
306  C,
307  in_spatial_arr,
308  wei_spatial_arr,
309  out_spatial_arr,
310  conv_strides_arr,
311  conv_dilations_arr,
312  in_left_pads_arr));
313 
314  return elapsed_ms;
315 }
316 
317 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:44
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
int64_t long_index_t
Definition: integer.hpp:11
CK_TILE_HOST auto make_kernel(KernelImpl, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:60
CK_TILE_HOST float naive_grouped_conv_fwd(const InDataType *p_in_dev, const WeiDataType *p_wei_dev, OutDataType *p_out_dev, ck_tile::index_t G, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t C, std::vector< ck_tile::long_index_t > in_spatial_lengths, std::vector< ck_tile::long_index_t > wei_spatial_lengths, std::vector< ck_tile::long_index_t > out_spatial_lengths, std::vector< ck_tile::long_index_t > conv_strides, std::vector< ck_tile::long_index_t > conv_dilations, std::vector< ck_tile::long_index_t > in_left_pads, ck_tile::stream_config stream_config={})
Definition: naive_grouped_conv_fwd_gpu.hpp:258
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition: kernel_launch.hpp:173
Definition: naive_grouped_conv_fwd_gpu.hpp:27
__device__ void operator()(const InDataType *__restrict__ p_in, const WeiDataType *__restrict__ p_wei, OutDataType *__restrict__ p_out, ck_tile::index_t G, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t C, const std::array< ck_tile::long_index_t, NDimSpatial > &in_spatial_lengths, const std::array< ck_tile::long_index_t, NDimSpatial > &wei_spatial_lengths, const std::array< ck_tile::long_index_t, NDimSpatial > &out_spatial_lengths, const std::array< ck_tile::long_index_t, NDimSpatial > &conv_strides, const std::array< ck_tile::long_index_t, NDimSpatial > &conv_dilations, const std::array< ck_tile::long_index_t, NDimSpatial > &in_left_pads) const
Definition: naive_grouped_conv_fwd_gpu.hpp:31
static constexpr ck_tile::index_t kBlockSize
Definition: naive_grouped_conv_fwd_gpu.hpp:28
Definition: stream_config.hpp:30