/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp Source File
reference_grouped_conv_bwd_weight.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <cstdlib>
7 #include <thread>
8 
9 #include "ck_tile/core.hpp"
11 
12 namespace ck_tile {
13 
14 template <ck_tile::index_t NDimSpatial,
15  typename InDataType,
16  typename WeiDataType,
17  typename OutDataType>
18 CK_TILE_HOST void
21  const HostTensor<OutDataType>& output,
22  std::vector<ck_tile::long_index_t> conv_strides,
23  std::vector<ck_tile::long_index_t> conv_dilations,
24  std::vector<ck_tile::long_index_t> in_left_pads,
25  std::vector<ck_tile::long_index_t>)
26 {
27  if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
28  weight.get_num_of_dimension() == NDimSpatial + 3 &&
29  output.get_num_of_dimension() == NDimSpatial + 3))
30  {
31  throw std::runtime_error("wrong! inconsistent dimension");
32  }
33 
34  if constexpr(NDimSpatial == 1)
35  {
36  auto func = [&](auto g, auto k, auto c, auto x) {
37  float v_acc = 0;
38 
39  for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
40  {
41  for(std::size_t wo = 0; wo < output.get_lengths()[3]; ++wo)
42  {
43  auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
44  static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
45  static_cast<ck_tile::long_index_t>(in_left_pads[0]);
46 
47  if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
48  {
49  InDataType v_in = input(g, n, c, wi);
50  OutDataType v_out = output(g, n, k, wo);
51  v_acc += ck_tile::type_convert<float>(v_out) *
52  ck_tile::type_convert<float>(v_in);
53  }
54  }
55  }
56  OutDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
57  weight(g, k, c, x) = v_acc_converted;
58  };
59 
61  weight.get_lengths()[0],
62  weight.get_lengths()[1],
63  weight.get_lengths()[2],
64  weight.get_lengths()[3])(std::thread::hardware_concurrency());
65  }
66  else if constexpr(NDimSpatial == 2)
67  {
68  auto func = [&](auto g, auto k, auto c, auto y, auto x) {
69  float v_acc = 0;
70 
71  for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
72  {
73  for(std::size_t ho = 0; ho < output.get_lengths()[3]; ++ho)
74  {
75  auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
76  static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
77  static_cast<ck_tile::long_index_t>(in_left_pads[0]);
78 
79  for(std::size_t wo = 0; wo < output.get_lengths()[4]; ++wo)
80  {
81  auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
82  static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
83  static_cast<ck_tile::long_index_t>(in_left_pads[1]);
84 
85  if(hi >= 0 &&
86  ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[3] &&
87  wi >= 0 &&
88  ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[4])
89  {
90  InDataType v_in = input(g, n, c, hi, wi);
91  OutDataType v_out = output(g, n, k, ho, wo);
92 
93  v_acc += ck_tile::type_convert<float>(v_out) *
94  ck_tile::type_convert<float>(v_in);
95  }
96  }
97  }
98  }
99  WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
100  weight(g, k, c, y, x) = v_acc_converted;
101  };
102 
104  weight.get_lengths()[0],
105  weight.get_lengths()[1],
106  weight.get_lengths()[2],
107  weight.get_lengths()[3],
108  weight.get_lengths()[4])(std::thread::hardware_concurrency());
109  }
110  else if constexpr(NDimSpatial == 3)
111  {
112  auto func = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
113  float v_acc = 0;
114 
115  for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
116  {
117  for(std::size_t do_ = 0; do_ < output.get_lengths()[3]; ++do_)
118  {
119  auto di = static_cast<ck_tile::long_index_t>(do_ * conv_strides[0]) +
120  static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
121  static_cast<ck_tile::long_index_t>(in_left_pads[0]);
122  for(std::size_t ho = 0; ho < output.get_lengths()[4]; ++ho)
123  {
124  auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
125  static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
126  static_cast<ck_tile::long_index_t>(in_left_pads[1]);
127  for(std::size_t wo = 0; wo < output.get_lengths()[5]; ++wo)
128  {
129  auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
130  static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
131  static_cast<ck_tile::long_index_t>(in_left_pads[2]);
132  if(di >= 0 &&
133  ck_tile::type_convert<std::size_t>(di) < input.get_lengths()[3] &&
134  hi >= 0 &&
135  ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[4] &&
136  wi >= 0 &&
137  ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[5])
138  {
139  InDataType v_in = input(g, n, c, di, hi, wi);
140  OutDataType v_out = output(g, n, k, do_, ho, wo);
141 
142  v_acc += ck_tile::type_convert<float>(v_out) *
143  ck_tile::type_convert<float>(v_in);
144  }
145  }
146  }
147  }
148  }
149  WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
150  weight(g, k, c, z, y, x) = v_acc_converted;
151  };
152 
154  weight.get_lengths()[0],
155  weight.get_lengths()[1],
156  weight.get_lengths()[2],
157  weight.get_lengths()[3],
158  weight.get_lengths()[4],
159  weight.get_lengths()[5])(std::thread::hardware_concurrency());
160  }
161  else
162  {
163  throw std::runtime_error(
164  "Ref_conv_bwd_weight: number of dimensions must be between 1 and 3.");
165  }
166 }
167 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
int64_t long_index_t
Definition: integer.hpp:11
CK_TILE_HOST void reference_grouped_conv_bwd_weight(const HostTensor< InDataType > &input, HostTensor< WeiDataType > &weight, const HostTensor< OutDataType > &output, std::vector< ck_tile::long_index_t > conv_strides, std::vector< ck_tile::long_index_t > conv_dilations, std::vector< ck_tile::long_index_t > in_left_pads, std::vector< ck_tile::long_index_t >)
Definition: reference_grouped_conv_bwd_weight.hpp:19
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:396