/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp Source File
reference_grouped_conv_fwd.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <cstdlib>
7 #include <thread>
8 
9 #include "ck_tile/core.hpp"
11 
12 namespace ck_tile {
13 
14 template <ck_tile::index_t NDimSpatial,
15  typename InDataType,
16  typename WeiDataType,
17  typename OutDataType>
19  const HostTensor<WeiDataType>& weight,
21  std::vector<ck_tile::long_index_t> conv_strides,
22  std::vector<ck_tile::long_index_t> conv_dilations,
23  std::vector<ck_tile::long_index_t> in_left_pads,
24  std::vector<ck_tile::long_index_t>)
25 {
26  if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
27  weight.get_num_of_dimension() == NDimSpatial + 3 &&
28  output.get_num_of_dimension() == NDimSpatial + 3))
29  {
30  throw std::runtime_error("wrong! inconsistent dimension");
31  }
32 
33  if constexpr(NDimSpatial == 1)
34  {
35  auto func = [&](auto g, auto n, auto k, auto wo) {
36  float v_acc = 0;
37 
38  for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
39  {
40  for(std::size_t x = 0; x < weight.get_lengths()[3]; ++x)
41  {
42  auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
43  static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
44  static_cast<ck_tile::long_index_t>(in_left_pads[0]);
45 
46  if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
47  {
48  InDataType v_in = input(g, n, c, wi);
49  WeiDataType v_wei = weight(g, k, c, x);
50  v_acc += ck_tile::type_convert<float>(v_in) *
51  ck_tile::type_convert<float>(v_wei);
52  }
53  }
54  }
55  OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
56  output(g, n, k, wo) = v_acc_converted;
57  };
58 
60  output.get_lengths()[0],
61  output.get_lengths()[1],
62  output.get_lengths()[2],
63  output.get_lengths()[3])(std::thread::hardware_concurrency());
64  }
65  else if constexpr(NDimSpatial == 2)
66  {
67  auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
68  float v_acc = 0;
69 
70  for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
71  {
72  for(std::size_t y = 0; y < weight.get_lengths()[3]; ++y)
73  {
74  auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
75  static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
76  static_cast<ck_tile::long_index_t>(in_left_pads[0]);
77 
78  for(std::size_t x = 0; x < weight.get_lengths()[4]; ++x)
79  {
80  auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
81  static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
82  static_cast<ck_tile::long_index_t>(in_left_pads[1]);
83 
84  if(hi >= 0 &&
85  ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[3] &&
86  wi >= 0 &&
87  ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[4])
88  {
89  InDataType v_in = input(g, n, c, hi, wi);
90  WeiDataType v_wei = weight(g, k, c, y, x);
91 
92  v_acc += ck_tile::type_convert<float>(v_in) *
93  ck_tile::type_convert<float>(v_wei);
94  }
95  }
96  }
97  }
98  OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
99  output(g, n, k, ho, wo) = v_acc_converted;
100  };
101 
103  output.get_lengths()[0],
104  output.get_lengths()[1],
105  output.get_lengths()[2],
106  output.get_lengths()[3],
107  output.get_lengths()[4])(std::thread::hardware_concurrency());
108  }
109  else if constexpr(NDimSpatial == 3)
110  {
111  auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
112  float v_acc = 0;
113 
114  for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
115  {
116  for(std::size_t z = 0; z < weight.get_lengths()[3]; ++z)
117  {
118  auto di = static_cast<ck_tile::long_index_t>(d_o * conv_strides[0]) +
119  static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
120  static_cast<ck_tile::long_index_t>(in_left_pads[0]);
121  for(std::size_t y = 0; y < weight.get_lengths()[4]; ++y)
122  {
123  auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
124  static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
125  static_cast<ck_tile::long_index_t>(in_left_pads[1]);
126  for(std::size_t x = 0; x < weight.get_lengths()[5]; ++x)
127  {
128  auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
129  static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
130  static_cast<ck_tile::long_index_t>(in_left_pads[2]);
131  if(di >= 0 &&
132  ck_tile::type_convert<std::size_t>(di) < input.get_lengths()[3] &&
133  hi >= 0 &&
134  ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[4] &&
135  wi >= 0 &&
136  ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[5])
137  {
138  InDataType v_in = input(g, n, c, di, hi, wi);
139  WeiDataType v_wei = weight(g, k, c, z, y, x);
140 
141  v_acc += ck_tile::type_convert<float>(v_in) *
142  ck_tile::type_convert<float>(v_wei);
143  }
144  }
145  }
146  }
147  }
148  OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
149  output(g, n, k, d_o, ho, wo) = v_acc_converted;
150  };
151 
153  output.get_lengths()[0],
154  output.get_lengths()[1],
155  output.get_lengths()[2],
156  output.get_lengths()[3],
157  output.get_lengths()[4],
158  output.get_lengths()[5])(std::thread::hardware_concurrency());
159  }
160  else
161  {
162  throw std::runtime_error("Ref_Conv_fwd: number of dimensions must be between 1 and 3.");
163  }
164 }
165 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
int64_t long_index_t
Definition: integer.hpp:11
CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor< InDataType > &input, const HostTensor< WeiDataType > &weight, HostTensor< OutDataType > &output, std::vector< ck_tile::long_index_t > conv_strides, std::vector< ck_tile::long_index_t > conv_dilations, std::vector< ck_tile::long_index_t > in_left_pads, std::vector< ck_tile::long_index_t >)
Definition: reference_grouped_conv_fwd.hpp:18
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:396