/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/convolution_parameter.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/convolution_parameter.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/convolution_parameter.hpp Source File
convolution_parameter.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include <cstdlib>
#include <functional>
#include <iterator>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>
10 
11 namespace ck_tile {
12 namespace conv {
13 
14 struct ConvParam
15 {
17  ck_tile::index_t group_count,
18  ck_tile::index_t n_batch,
19  ck_tile::index_t n_out_channels,
20  ck_tile::index_t n_in_channels,
21  const std::vector<ck_tile::index_t>& filters_len,
22  const std::vector<ck_tile::index_t>& input_len,
23  const std::vector<ck_tile::index_t>& strides,
24  const std::vector<ck_tile::index_t>& dilations,
25  const std::vector<ck_tile::index_t>& left_pads,
26  const std::vector<ck_tile::index_t>& right_pads)
27  : num_dim_spatial_(static_cast<ck_tile::long_index_t>(n_dim)),
28  G_(static_cast<ck_tile::long_index_t>(group_count)),
29  N_(static_cast<ck_tile::long_index_t>(n_batch)),
30  K_(static_cast<ck_tile::long_index_t>(n_out_channels)),
31  C_(static_cast<ck_tile::long_index_t>(n_in_channels)),
39  {
40  if(static_cast<ck_tile::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
42  static_cast<ck_tile::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
44  static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
45  static_cast<ck_tile::index_t>(input_right_pads_.size()) != num_dim_spatial_)
46  {
47  throw(std::runtime_error(
48  "ConvParam::ConvParam: "
49  "parameter size is different from number of declared dimensions!"));
50  }
51 
52  for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
53  {
54  filter_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(filters_len[i]);
55  input_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(input_len[i]);
56  conv_filter_strides_[i] = static_cast<ck_tile::long_index_t>(strides[i]);
57  conv_filter_dilations_[i] = static_cast<ck_tile::long_index_t>(dilations[i]);
58  input_left_pads_[i] = static_cast<ck_tile::long_index_t>(left_pads[i]);
59  input_right_pads_[i] = static_cast<ck_tile::long_index_t>(right_pads[i]);
60 
61  // XEff = (X - 1) * conv_dilation_w + 1;
62  // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
63  const ck_tile::long_index_t x_eff =
65 
69  1;
70  }
71  }
72 
74  ck_tile::long_index_t group_count,
75  ck_tile::long_index_t n_batch,
76  ck_tile::long_index_t n_out_channels,
77  ck_tile::long_index_t n_in_channels,
78  const std::vector<ck_tile::long_index_t>& filters_len,
79  const std::vector<ck_tile::long_index_t>& input_len,
80  const std::vector<ck_tile::long_index_t>& strides,
81  const std::vector<ck_tile::long_index_t>& dilations,
82  const std::vector<ck_tile::long_index_t>& left_pads,
83  const std::vector<ck_tile::long_index_t>& right_pads)
84  : num_dim_spatial_(n_dim),
85  G_(group_count),
86  N_(n_batch),
87  K_(n_out_channels),
88  C_(n_in_channels),
89  filter_spatial_lengths_(filters_len),
90  input_spatial_lengths_(input_len),
92  conv_filter_strides_(strides),
93  conv_filter_dilations_(dilations),
94  input_left_pads_(left_pads),
95  input_right_pads_(right_pads)
96  {
97  if(static_cast<ck_tile::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
99  static_cast<ck_tile::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
100  static_cast<ck_tile::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
101  static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
102  static_cast<ck_tile::index_t>(input_right_pads_.size()) != num_dim_spatial_)
103  {
104  throw(std::runtime_error(
105  "ConvParam::ConvParam: "
106  "parameter size is different from number of declared dimensions!"));
107  }
108 
109  for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
110  {
111  // XEff = (X - 1) * conv_dilation_w + 1;
112  // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
113  const ck_tile::long_index_t x_eff =
115 
119  1;
120  }
121  }
122 
128 
129  std::vector<ck_tile::long_index_t> filter_spatial_lengths_;
130  std::vector<ck_tile::long_index_t> input_spatial_lengths_;
131  std::vector<ck_tile::long_index_t> output_spatial_lengths_;
132 
133  std::vector<ck_tile::long_index_t> conv_filter_strides_;
134  std::vector<ck_tile::long_index_t> conv_filter_dilations_;
135 
136  std::vector<ck_tile::long_index_t> input_left_pads_;
137  std::vector<ck_tile::long_index_t> input_right_pads_;
138 
139  std::vector<ck_tile::long_index_t> GetOutputSpatialLengths() const
140  {
142  }
143 
144  std::size_t GetFlops() const
145  {
146  // 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
147  return static_cast<std::size_t>(2) * G_ * N_ * K_ * C_ *
148  std::accumulate(std::begin(output_spatial_lengths_),
149  std::next(std::begin(output_spatial_lengths_), num_dim_spatial_),
150  1,
151  std::multiplies<>()) *
152  std::accumulate(std::begin(filter_spatial_lengths_),
153  std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
154  1,
155  std::multiplies<>());
156  }
157 
158  template <typename InDataType>
159  std::size_t GetInputByte() const
160  {
161  // sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
162  return sizeof(InDataType) *
163  (G_ * N_ * C_ *
164  std::accumulate(std::begin(input_spatial_lengths_),
165  std::next(std::begin(input_spatial_lengths_), num_dim_spatial_),
166  1,
167  std::multiplies<>()));
168  }
169 
170  template <typename WeiDataType>
171  std::size_t GetWeightByte() const
172  {
173  // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
174  return sizeof(WeiDataType) *
175  (G_ * K_ * C_ *
176  std::accumulate(std::begin(filter_spatial_lengths_),
177  std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
178  1,
179  std::multiplies<>()));
180  }
181 
182  template <typename OutDataType>
183  std::size_t GetOutputByte() const
184  {
185  // sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
186  return sizeof(OutDataType) * (G_ * N_ * K_ *
187  std::accumulate(std::begin(output_spatial_lengths_),
188  std::end(output_spatial_lengths_),
189  static_cast<std::size_t>(1),
190  std::multiplies<std::size_t>()));
191  }
192 
193  template <typename InDataType, typename WeiDataType, typename OutDataType>
194  std::size_t GetByte() const
195  {
196  return GetInputByte<InDataType>() + GetWeightByte<WeiDataType>() +
197  GetOutputByte<OutDataType>();
198  }
199 };
200 
202 {
203  std::string msg;
204 
205  msg += "Following arguments (depending on number of spatial dims):\n"
206  " Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)\n"
207  " G, N, K, C, \n"
208  " <filter spatial dimensions>, (ie Y, X for 2D)\n"
209  " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
210  " <strides>, (ie Sy, Sx for 2D)\n"
211  " <dilations>, (ie Dy, Dx for 2D)\n"
212  " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
213  " <right padding>, (ie RightPy, RightPx for 2D)\n";
214 
215  return msg;
216 }
217 
219 parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
220 {
221  const ck_tile::long_index_t G = std::stol(argv[arg_idx++]);
222  const ck_tile::long_index_t N = std::stol(argv[arg_idx++]);
223  const ck_tile::long_index_t K = std::stol(argv[arg_idx++]);
224  const ck_tile::long_index_t C = std::stol(argv[arg_idx++]);
225 
226  std::vector<ck_tile::long_index_t> filter_spatial_lengths(num_dim_spatial);
227  std::vector<ck_tile::long_index_t> input_spatial_lengths(num_dim_spatial);
228  std::vector<ck_tile::long_index_t> conv_filter_strides(num_dim_spatial);
229  std::vector<ck_tile::long_index_t> conv_filter_dilations(num_dim_spatial);
230  std::vector<ck_tile::long_index_t> input_left_pads(num_dim_spatial);
231  std::vector<ck_tile::long_index_t> input_right_pads(num_dim_spatial);
232 
233  for(int i = 0; i < num_dim_spatial; ++i)
234  {
235  filter_spatial_lengths[i] = std::stol(argv[arg_idx++]);
236  }
237 
238  for(int i = 0; i < num_dim_spatial; ++i)
239  {
240  input_spatial_lengths[i] = std::stol(argv[arg_idx++]);
241  }
242 
243  for(int i = 0; i < num_dim_spatial; ++i)
244  {
245  conv_filter_strides[i] = std::stol(argv[arg_idx++]);
246  }
247 
248  for(int i = 0; i < num_dim_spatial; ++i)
249  {
250  conv_filter_dilations[i] = std::stol(argv[arg_idx++]);
251  }
252 
253  for(int i = 0; i < num_dim_spatial; ++i)
254  {
255  input_left_pads[i] = std::stol(argv[arg_idx++]);
256  }
257 
258  for(int i = 0; i < num_dim_spatial; ++i)
259  {
260  input_right_pads[i] = std::stol(argv[arg_idx++]);
261  }
262 
263  return ck_tile::conv::ConvParam{num_dim_spatial,
264  G,
265  N,
266  K,
267  C,
268  filter_spatial_lengths,
269  input_spatial_lengths,
270  conv_filter_strides,
271  conv_filter_dilations,
272  input_left_pads,
273  input_right_pads};
274 }
275 
276 } // namespace conv
277 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
CK_TILE_HOST ck_tile::conv::ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char *const argv[])
Definition: convolution_parameter.hpp:219
CK_TILE_HOST std::string get_conv_param_parser_helper_msg()
Definition: convolution_parameter.hpp:201
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
int64_t long_index_t
Definition: integer.hpp:11
Definition: convolution_parameter.hpp:15
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
std::size_t GetWeightByte() const
Definition: convolution_parameter.hpp:171
ck_tile::long_index_t K_
Definition: convolution_parameter.hpp:126
ck_tile::long_index_t num_dim_spatial_
Definition: convolution_parameter.hpp:123
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition: convolution_parameter.hpp:124
std::size_t GetInputByte() const
Definition: convolution_parameter.hpp:159
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::size_t GetFlops() const
Definition: convolution_parameter.hpp:144
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition: convolution_parameter.hpp:127
std::size_t GetByte() const
Definition: convolution_parameter.hpp:194
ck_tile::long_index_t N_
Definition: convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > GetOutputSpatialLengths() const
Definition: convolution_parameter.hpp:139
ConvParam(ck_tile::long_index_t n_dim, ck_tile::long_index_t group_count, ck_tile::long_index_t n_batch, ck_tile::long_index_t n_out_channels, ck_tile::long_index_t n_in_channels, const std::vector< ck_tile::long_index_t > &filters_len, const std::vector< ck_tile::long_index_t > &input_len, const std::vector< ck_tile::long_index_t > &strides, const std::vector< ck_tile::long_index_t > &dilations, const std::vector< ck_tile::long_index_t > &left_pads, const std::vector< ck_tile::long_index_t > &right_pads)
Definition: convolution_parameter.hpp:73
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
ConvParam(ck_tile::index_t n_dim, ck_tile::index_t group_count, ck_tile::index_t n_batch, ck_tile::index_t n_out_channels, ck_tile::index_t n_in_channels, const std::vector< ck_tile::index_t > &filters_len, const std::vector< ck_tile::index_t > &input_len, const std::vector< ck_tile::index_t > &strides, const std::vector< ck_tile::index_t > &dilations, const std::vector< ck_tile::index_t > &left_pads, const std::vector< ck_tile::index_t > &right_pads)
Definition: convolution_parameter.hpp:16
std::size_t GetOutputByte() const
Definition: convolution_parameter.hpp:183