11 #include <hip/hip_runtime.h>
32 const WeiDataType* __restrict__ p_wei,
33 OutDataType* __restrict__ p_out,
40 const std::array<ck_tile::long_index_t, NDimSpatial>& in_spatial_lengths,
42 const std::array<ck_tile::long_index_t, NDimSpatial>& wei_spatial_lengths,
44 const std::array<ck_tile::long_index_t, NDimSpatial>& out_spatial_lengths,
46 const std::array<ck_tile::long_index_t, NDimSpatial>& conv_strides,
47 const std::array<ck_tile::long_index_t, NDimSpatial>& conv_dilations,
48 const std::array<ck_tile::long_index_t, NDimSpatial>& in_left_pads)
const
57 output_length *= out_spatial_lengths[i];
61 std::array<ck_tile::long_index_t, NDimSpatial + 3> out_strides;
63 out_strides[NDimSpatial + 2] = stride;
65 out_strides[NDimSpatial + 1] = stride;
69 out_strides[i + 1] = stride;
70 stride *= out_spatial_lengths[i];
72 out_strides[0] = stride;
75 std::array<ck_tile::long_index_t, NDimSpatial + 3> in_strides;
77 in_strides[NDimSpatial + 2] = stride;
79 in_strides[NDimSpatial + 1] = stride;
83 in_strides[i + 1] = stride;
84 stride *= in_spatial_lengths[i];
86 in_strides[0] = stride;
89 std::array<ck_tile::long_index_t, NDimSpatial + 3> wei_strides;
91 wei_strides[NDimSpatial + 2] = stride;
95 wei_strides[i + 2] = stride;
96 stride *= wei_spatial_lengths[i];
98 wei_strides[1] = stride;
100 wei_strides[0] = stride;
110 tmp -= n * out_strides[0];
116 out_spatial_idx[i] = tmp / out_strides[i + 1];
117 tmp -= out_spatial_idx[i] * out_strides[i + 1];
122 tmp -= g * out_strides[NDimSpatial + 1];
134 if constexpr(NDimSpatial == 1)
146 if(wi >= 0 && wi < in_spatial_lengths[0])
148 std::array<ck_tile::index_t, 1> in_spatial = {
static_cast<index_t>(wi)};
149 std::array<ck_tile::index_t, 1> wei_spatial = {x};
151 detail::calculate_input_index<1>(n, g, c, in_spatial, in_strides);
153 g, k, c, wei_spatial, wei_strides);
155 v_acc += type_convert<float>(p_in[in_idx]) *
156 type_convert<float>(p_wei[wei_idx]);
160 else if constexpr(NDimSpatial == 2)
179 if(hi >= 0 && hi < in_spatial_lengths[0] && wi >= 0 &&
180 wi < in_spatial_lengths[1])
182 std::array<ck_tile::index_t, 2> in_spatial = {
184 std::array<ck_tile::index_t, 2> wei_spatial = {y, x};
186 n, g, c, in_spatial, in_strides);
188 g, k, c, wei_spatial, wei_strides);
190 v_acc += type_convert<float>(p_in[in_idx]) *
191 type_convert<float>(p_wei[wei_idx]);
196 else if constexpr(NDimSpatial == 3)
223 if(di >= 0 && di < in_spatial_lengths[0] && hi >= 0 &&
224 hi < in_spatial_lengths[1] && wi >= 0 &&
225 wi < in_spatial_lengths[2])
227 std::array<ck_tile::index_t, 3> in_spatial = {
231 std::array<ck_tile::index_t, 3> wei_spatial = {z, y, x};
233 n, g, c, in_spatial, in_strides);
235 detail::calculate_weight_index<3>(
236 g, k, c, wei_spatial, wei_strides);
238 v_acc += type_convert<float>(p_in[in_idx]) *
239 type_convert<float>(p_wei[wei_idx]);
248 p_out[ii] = type_convert<OutDataType>(v_acc);
256 typename WeiDataType,
257 typename OutDataType>
259 const WeiDataType* p_wei_dev,
260 OutDataType* p_out_dev,
265 std::vector<ck_tile::long_index_t> in_spatial_lengths,
266 std::vector<ck_tile::long_index_t> wei_spatial_lengths,
267 std::vector<ck_tile::long_index_t> out_spatial_lengths,
268 std::vector<ck_tile::long_index_t> conv_strides,
269 std::vector<ck_tile::long_index_t> conv_dilations,
270 std::vector<ck_tile::long_index_t> in_left_pads,
274 auto in_spatial_arr = to_array_with_default<NDimSpatial>(in_spatial_lengths);
275 auto wei_spatial_arr = to_array_with_default<NDimSpatial>(wei_spatial_lengths);
276 auto out_spatial_arr = to_array_with_default<NDimSpatial>(out_spatial_lengths);
277 auto conv_strides_arr = to_array_with_default<NDimSpatial>(conv_strides);
278 auto conv_dilations_arr = to_array_with_default<NDimSpatial>(conv_dilations);
279 auto in_left_pads_arr = to_array_with_default<NDimSpatial>(in_left_pads, 0);
285 output_length *= out_spatial_lengths[i];
289 naive_grouped_conv_fwd_kernel<NDimSpatial, InDataType, WeiDataType, OutDataType>;
292 const ck_tile::index_t grid_size = (output_length + block_size - 1) / block_size;
#define CK_TILE_HOST
Definition: config.hpp:44
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
int64_t long_index_t
Definition: integer.hpp:11
CK_TILE_HOST auto make_kernel(KernelImpl, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:60
CK_TILE_HOST float naive_grouped_conv_fwd(const InDataType *p_in_dev, const WeiDataType *p_wei_dev, OutDataType *p_out_dev, ck_tile::index_t G, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t C, std::vector< ck_tile::long_index_t > in_spatial_lengths, std::vector< ck_tile::long_index_t > wei_spatial_lengths, std::vector< ck_tile::long_index_t > out_spatial_lengths, std::vector< ck_tile::long_index_t > conv_strides, std::vector< ck_tile::long_index_t > conv_dilations, std::vector< ck_tile::long_index_t > in_left_pads, ck_tile::stream_config stream_config={})
Definition: naive_grouped_conv_fwd_gpu.hpp:258
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition: kernel_launch.hpp:173
Definition: naive_grouped_conv_fwd_gpu.hpp:27
__device__ void operator()(const InDataType *__restrict__ p_in, const WeiDataType *__restrict__ p_wei, OutDataType *__restrict__ p_out, ck_tile::index_t G, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t C, const std::array< ck_tile::long_index_t, NDimSpatial > &in_spatial_lengths, const std::array< ck_tile::long_index_t, NDimSpatial > &wei_spatial_lengths, const std::array< ck_tile::long_index_t, NDimSpatial > &out_spatial_lengths, const std::array< ck_tile::long_index_t, NDimSpatial > &conv_strides, const std::array< ck_tile::long_index_t, NDimSpatial > &conv_dilations, const std::array< ck_tile::long_index_t, NDimSpatial > &in_left_pads) const
Definition: naive_grouped_conv_fwd_gpu.hpp:31
static constexpr ck_tile::index_t kBlockSize
Definition: naive_grouped_conv_fwd_gpu.hpp:28
Definition: stream_config.hpp:30