/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp Source File
reference_grouped_conv_fwd.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <cstdlib>
7 #include <thread>
8 
9 #include "ck_tile/core.hpp"
12 
13 namespace ck_tile {
14 
15 template <ck_tile::index_t NDimSpatial,
16  typename InDataType,
17  typename WeiDataType,
18  typename OutDataType,
19  typename Elfunc = ck_tile::element_wise::PassThrough,
20  typename Tuple = ck_tile::tuple<>>
22  const HostTensor<WeiDataType>& weight,
24  std::vector<ck_tile::long_index_t> conv_strides,
25  std::vector<ck_tile::long_index_t> conv_dilations,
26  std::vector<ck_tile::long_index_t> in_left_pads,
27  std::vector<ck_tile::long_index_t>,
28  Elfunc elfunc = Elfunc{},
29  Tuple ds = {})
30 {
31  if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
32  weight.get_num_of_dimension() == NDimSpatial + 3 &&
33  output.get_num_of_dimension() == NDimSpatial + 3))
34  {
35  throw std::runtime_error("wrong! inconsistent dimension");
36  }
37 
38  if constexpr(NDimSpatial == 1)
39  {
40  auto func = [&](auto g, auto n, auto k, auto wo) {
41  float v_acc = 0;
42 
43  for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
44  {
45  for(std::size_t x = 0; x < weight.get_lengths()[3]; ++x)
46  {
47  auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
48  static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
49  static_cast<ck_tile::long_index_t>(in_left_pads[0]);
50 
51  if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
52  {
53  InDataType v_in = input(g, n, c, wi);
54  WeiDataType v_wei = weight(g, k, c, x);
55  v_acc += ck_tile::type_convert<float>(v_in) *
56  ck_tile::type_convert<float>(v_wei);
57  }
58  }
59  }
60  if constexpr(Tuple::size() > 0)
61  elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, wo));
62  else
63  elfunc(v_acc, v_acc);
64  OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
65  output(g, n, k, wo) = v_acc_out;
66  };
67 
69  output.get_lengths()[0],
70  output.get_lengths()[1],
71  output.get_lengths()[2],
72  output.get_lengths()[3])(std::thread::hardware_concurrency());
73  }
74  else if constexpr(NDimSpatial == 2)
75  {
76  auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
77  float v_acc = 0;
78 
79  for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
80  {
81  for(std::size_t y = 0; y < weight.get_lengths()[3]; ++y)
82  {
83  auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
84  static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
85  static_cast<ck_tile::long_index_t>(in_left_pads[0]);
86 
87  for(std::size_t x = 0; x < weight.get_lengths()[4]; ++x)
88  {
89  auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
90  static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
91  static_cast<ck_tile::long_index_t>(in_left_pads[1]);
92 
93  if(hi >= 0 &&
94  ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[3] &&
95  wi >= 0 &&
96  ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[4])
97  {
98  InDataType v_in = input(g, n, c, hi, wi);
99  WeiDataType v_wei = weight(g, k, c, y, x);
100 
101  v_acc += ck_tile::type_convert<float>(v_in) *
102  ck_tile::type_convert<float>(v_wei);
103  }
104  }
105  }
106  }
107  if constexpr(Tuple::size() > 0)
108  elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, ho, wo));
109  else
110  elfunc(v_acc, v_acc);
111  OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
112  output(g, n, k, ho, wo) = v_acc_out;
113  };
114 
116  output.get_lengths()[0],
117  output.get_lengths()[1],
118  output.get_lengths()[2],
119  output.get_lengths()[3],
120  output.get_lengths()[4])(std::thread::hardware_concurrency());
121  }
122  else if constexpr(NDimSpatial == 3)
123  {
124  auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
125  float v_acc = 0;
126 
127  for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
128  {
129  for(std::size_t z = 0; z < weight.get_lengths()[3]; ++z)
130  {
131  auto di = static_cast<ck_tile::long_index_t>(d_o * conv_strides[0]) +
132  static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
133  static_cast<ck_tile::long_index_t>(in_left_pads[0]);
134  for(std::size_t y = 0; y < weight.get_lengths()[4]; ++y)
135  {
136  auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
137  static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
138  static_cast<ck_tile::long_index_t>(in_left_pads[1]);
139  for(std::size_t x = 0; x < weight.get_lengths()[5]; ++x)
140  {
141  auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
142  static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
143  static_cast<ck_tile::long_index_t>(in_left_pads[2]);
144  if(di >= 0 &&
145  ck_tile::type_convert<std::size_t>(di) < input.get_lengths()[3] &&
146  hi >= 0 &&
147  ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[4] &&
148  wi >= 0 &&
149  ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[5])
150  {
151  InDataType v_in = input(g, n, c, di, hi, wi);
152  WeiDataType v_wei = weight(g, k, c, z, y, x);
153 
154  v_acc += ck_tile::type_convert<float>(v_in) *
155  ck_tile::type_convert<float>(v_wei);
156  }
157  }
158  }
159  }
160  }
161  if constexpr(Tuple::size() > 0)
162  elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, d_o, ho, wo));
163  else
164  elfunc(v_acc, v_acc);
165  OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
166  output(g, n, k, d_o, ho, wo) = v_acc_out;
167  };
168 
170  output.get_lengths()[0],
171  output.get_lengths()[1],
172  output.get_lengths()[2],
173  output.get_lengths()[3],
174  output.get_lengths()[4],
175  output.get_lengths()[5])(std::thread::hardware_concurrency());
176  }
177  else
178  {
179  throw std::runtime_error("Ref_Conv_fwd: number of dimensions must be between 1 and 3.");
180  }
181 }
182 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
constant< v > number
Definition: integral_constant.hpp:37
ck_tile::element_wise::PassThrough PassThrough
Definition: grouped_convolution_utils.hpp:47
int64_t long_index_t
Definition: integer.hpp:11
CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor< InDataType > &input, const HostTensor< WeiDataType > &weight, HostTensor< OutDataType > &output, std::vector< ck_tile::long_index_t > conv_strides, std::vector< ck_tile::long_index_t > conv_dilations, std::vector< ck_tile::long_index_t > in_left_pads, std::vector< ck_tile::long_index_t >, Elfunc elfunc=Elfunc{}, Tuple ds={})
Definition: reference_grouped_conv_fwd.hpp:21
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:396
Definition: tuple.hpp:192