/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp Source File
device_grouped_conv_bwd_weight.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include <array>
#include <memory>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

10 namespace ck {
11 namespace tensor_operation {
12 namespace device {
13 
14 #define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1
15 
16 template <ck::index_t NDimSpatial,
17  typename InLayout,
18  typename WeiLayout,
19  typename OutLayout,
20  typename InDataType,
21  typename WeiDataType,
22  typename OutDataType,
23  typename InElementwiseOperation,
24  typename WeiElementwiseOperation,
25  typename OutElementwiseOperation,
26  typename ComputeTypeA = InDataType,
27  typename ComputeTypeB = ComputeTypeA>
29 {
30  virtual std::unique_ptr<BaseArgument>
31  MakeArgumentPointer(const void* p_in,
32  void* p_wei,
33  const void* p_out,
34  const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
35  const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
36  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
37  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
38  const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
39  const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
40  const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
41  const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
42  const std::array<ck::index_t, NDimSpatial>& input_left_pads,
43  const std::array<ck::index_t, NDimSpatial>& input_right_pads,
44  InElementwiseOperation in_element_op,
45  WeiElementwiseOperation wei_element_op,
46  OutElementwiseOperation out_element_op,
47  ck::index_t split_k) = 0;
48 
49  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
50 };
51 
52 } // namespace device
53 } // namespace tensor_operation
54 } // namespace ck
Definition: ck.hpp:268
int32_t index_t
Definition: ck.hpp:299
Definition: device_base.hpp:223
Definition: device_grouped_conv_bwd_weight.hpp:29
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in, void *p_wei, const void *p_out, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)=0