/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/library/utility/convolution_parameter.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/library/utility/convolution_parameter.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/library/utility/convolution_parameter.hpp Source File
convolution_parameter.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <cstdlib>
7 #include <numeric>
8 #include <iterator>
9 #include <vector>
10 
11 #include "ck/ck.hpp"
12 
14 
15 namespace ck {
16 namespace utils {
17 namespace conv {
18 
19 struct ConvParam
20 {
23  ck::index_t group_count,
24  ck::index_t n_batch,
25  ck::index_t n_out_channels,
26  ck::index_t n_in_channels,
27  const std::vector<ck::index_t>& filters_len,
28  const std::vector<ck::index_t>& input_len,
29  const std::vector<ck::index_t>& strides,
30  const std::vector<ck::index_t>& dilations,
31  const std::vector<ck::index_t>& left_pads,
32  const std::vector<ck::index_t>& right_pads);
33 
35  ck::long_index_t group_count,
36  ck::long_index_t n_batch,
37  ck::long_index_t n_out_channels,
38  ck::long_index_t n_in_channels,
39  const std::vector<ck::long_index_t>& filters_len,
40  const std::vector<ck::long_index_t>& input_len,
41  const std::vector<ck::long_index_t>& strides,
42  const std::vector<ck::long_index_t>& dilations,
43  const std::vector<ck::long_index_t>& left_pads,
44  const std::vector<ck::long_index_t>& right_pads);
45 
51 
52  std::vector<ck::long_index_t> filter_spatial_lengths_;
53  std::vector<ck::long_index_t> input_spatial_lengths_;
54  std::vector<ck::long_index_t> output_spatial_lengths_;
55 
56  std::vector<ck::long_index_t> conv_filter_strides_;
57  std::vector<ck::long_index_t> conv_filter_dilations_;
58 
59  std::vector<ck::long_index_t> input_left_pads_;
60  std::vector<ck::long_index_t> input_right_pads_;
61 
62  std::vector<ck::long_index_t> GetOutputSpatialLengths() const;
63 
64  std::size_t GetFlops() const;
65 
66  template <typename InDataType>
67  std::size_t GetInputByte() const
68  {
69  // sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
70  return sizeof(InDataType) *
71  (G_ * N_ * C_ *
72  ck::accumulate_n<std::size_t>(
73  std::begin(input_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>()));
74  }
75 
76  template <typename WeiDataType>
77  std::size_t GetWeightByte() const
78  {
79  // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
80  return sizeof(WeiDataType) *
81  (G_ * K_ * C_ *
82  ck::accumulate_n<std::size_t>(
83  std::begin(filter_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>()));
84  }
85 
86  template <typename OutDataType>
87  std::size_t GetOutputByte() const
88  {
89  // sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
90  return sizeof(OutDataType) * (G_ * N_ * K_ *
91  std::accumulate(std::begin(output_spatial_lengths_),
92  std::end(output_spatial_lengths_),
93  static_cast<std::size_t>(1),
94  std::multiplies<std::size_t>()));
95  }
96 
97  template <typename InDataType, typename WeiDataType, typename OutDataType>
98  std::size_t GetByte() const
99  {
100  return GetInputByte<InDataType>() + GetWeightByte<WeiDataType>() +
101  GetOutputByte<OutDataType>();
102  }
103 };
104 
106 
107 ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]);
108 
109 } // namespace conv
110 } // namespace utils
111 } // namespace ck
112 
113 std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p);
std::ostream & operator<<(std::ostream &os, const ck::utils::conv::ConvParam &p)
ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char *const argv[])
std::string get_conv_param_parser_helper_msg()
Definition: ck.hpp:267
int64_t long_index_t
Definition: ck.hpp:299
int32_t index_t
Definition: ck.hpp:298
Definition: convolution_parameter.hpp:20
std::size_t GetFlops() const
ck::long_index_t C_
Definition: convolution_parameter.hpp:50
std::vector< ck::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:60
std::vector< ck::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:59
ConvParam(ck::long_index_t n_dim, ck::long_index_t group_count, ck::long_index_t n_batch, ck::long_index_t n_out_channels, ck::long_index_t n_in_channels, const std::vector< ck::long_index_t > &filters_len, const std::vector< ck::long_index_t > &input_len, const std::vector< ck::long_index_t > &strides, const std::vector< ck::long_index_t > &dilations, const std::vector< ck::long_index_t > &left_pads, const std::vector< ck::long_index_t > &right_pads)
std::vector< ck::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:57
ck::long_index_t num_dim_spatial_
Definition: convolution_parameter.hpp:46
std::vector< ck::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:53
std::vector< ck::long_index_t > GetOutputSpatialLengths() const
ConvParam(ck::index_t n_dim, ck::index_t group_count, ck::index_t n_batch, ck::index_t n_out_channels, ck::index_t n_in_channels, const std::vector< ck::index_t > &filters_len, const std::vector< ck::index_t > &input_len, const std::vector< ck::index_t > &strides, const std::vector< ck::index_t > &dilations, const std::vector< ck::index_t > &left_pads, const std::vector< ck::index_t > &right_pads)
std::size_t GetByte() const
Definition: convolution_parameter.hpp:98
std::vector< ck::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:54
std::vector< ck::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:56
std::size_t GetInputByte() const
Definition: convolution_parameter.hpp:67
ck::long_index_t N_
Definition: convolution_parameter.hpp:48
ck::long_index_t G_
Definition: convolution_parameter.hpp:47
std::size_t GetWeightByte() const
Definition: convolution_parameter.hpp:77
ck::long_index_t K_
Definition: convolution_parameter.hpp:49
std::vector< ck::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:52
std::size_t GetOutputByte() const
Definition: convolution_parameter.hpp:87