/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp Source File
grouped_convolution_utils.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
17 template <typename InPtr, typename WeiPtr, typename OutPtr>
19 {
22  InPtr in_ptr_,
23  WeiPtr wei_ptr_,
24  const std::vector<const void*> ds_ptr_,
25  OutPtr out_ptr_,
26  index_t k_batch_)
27  : conv::ConvParam(conv_param),
28  in_ptr(in_ptr_),
29  wei_ptr(wei_ptr_),
30  ds_ptr(ds_ptr_),
31  out_ptr(out_ptr_),
32  k_batch(k_batch_)
33  {
34  }
35 
36  InPtr in_ptr;
37  WeiPtr wei_ptr;
38  const std::vector<const void*> ds_ptr;
39  OutPtr out_ptr;
41 };
42 
46 
// Compile-time traits bundle for grouped convolution: re-exports the template
// parameters (spatial rank, specialization, layouts, vector sizes) as static
// members/aliases and derives NumDTensor and ImplicitGemmDsLayout from DsLayout_.
// NOTE(review): this rendering dropped source lines 56 (struct header/name),
// 73, 77-80, 84-85, 89-90, 94-95 and 99 (the TileGemmTraits alias names and
// their remaining template arguments); consult the original header for them.
47 template <index_t NDimSpatial_,
48  ConvolutionSpecialization ConvSpecialization_,
49  typename InLayout_,
50  typename WeiLayout_,
51  typename DsLayout_,
52  typename OutLayout_,
53  index_t VectorSizeA_ = 1,
54  index_t VectorSizeB_ = 1,
55  index_t VectorSizeC_ = 1>
57 {
58  private:
// Builds a tuple with one gemm::RowMajor layout per entry of DsLayout_,
// i.e. every D tensor is treated as row-major in the implicit GEMM view.
59  static constexpr auto generate_implicit_gemm_layout()
60  {
61  return generate_tuple([](auto) { return ck_tile::tensor_layout::gemm::RowMajor{}; },
62  number<DsLayout_::size()>{});
63  }
64 
65  public:
// Fixed at 1 in this traits type (no group merging is configured here).
66  static constexpr index_t NumGroupsToMerge = 1;
67  static constexpr index_t NDimSpatial = NDimSpatial_;
68  static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
69  using InLayout = InLayout_;
70  using WeiLayout = WeiLayout_;
71  using DsLayout = DsLayout_;
72  using OutLayout = OutLayout_;
// NOTE(review): alias name (line 73) and trailing arguments (lines 77-80) of
// this TileGemmTraits instantiation are missing from this rendering.
74  TileGemmTraits<true,
75  true,
76  true,
// NOTE(review): lines 77-80 of the previous alias and line 84-85 of this one
// are missing from this rendering.
81  TileGemmTraits<true,
82  true,
83  true,
86  // TODO: Change to and enable vector load
87  // ck_tile::tensor_layout::gemm::RowMajor,
88  // ck_tile::tensor_layout::gemm::RowMajor,
// NOTE(review): lines 89-90 and 94-95 surrounding this instantiation are
// missing from this rendering.
91  TileGemmTraits<true,
92  true,
93  true,
96  // TODO: Change to and enable vector load
97  // ck_tile::tensor_layout::gemm::ColumnMajor,
98  // ck_tile::tensor_layout::gemm::RowMajor,
// Vector sizes re-exported from the template parameters (each defaults to 1).
100  static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
101  static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
102  static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_;
// Number of D tensors, taken from the size of the DsLayout tuple.
103  static constexpr index_t NumDTensor = DsLayout::size();
// Tuple type of NumDTensor gemm::RowMajor layouts (see generate_implicit_gemm_layout).
104  using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
105 };
106 
107 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:19
index_t k_batch
Definition: grouped_convolution_utils.hpp:40
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:36
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:37
CK_TILE_HOST GroupedConvHostArgs()=delete
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:39
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:38
CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param, InPtr in_ptr_, WeiPtr wei_ptr_, const std::vector< const void * > ds_ptr_, OutPtr out_ptr_, index_t k_batch_)
Definition: grouped_convolution_utils.hpp:21
Definition: grouped_convolution_utils.hpp:57
InLayout_ InLayout
Definition: grouped_convolution_utils.hpp:69
static constexpr index_t NDimSpatial
Definition: grouped_convolution_utils.hpp:67
WeiLayout_ WeiLayout
Definition: grouped_convolution_utils.hpp:70
static constexpr ck_tile::index_t VectorSizeA
Definition: grouped_convolution_utils.hpp:100
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_utils.hpp:68
decltype(generate_implicit_gemm_layout()) ImplicitGemmDsLayout
Definition: grouped_convolution_utils.hpp:104
static constexpr ck_tile::index_t VectorSizeC
Definition: grouped_convolution_utils.hpp:102
static constexpr index_t NumDTensor
Definition: grouped_convolution_utils.hpp:103
static constexpr index_t NumGroupsToMerge
Definition: grouped_convolution_utils.hpp:66
DsLayout_ DsLayout
Definition: grouped_convolution_utils.hpp:71
OutLayout_ OutLayout
Definition: grouped_convolution_utils.hpp:72
static constexpr ck_tile::index_t VectorSizeB
Definition: grouped_convolution_utils.hpp:101
Definition: tile_gemm_traits.hpp:18
Definition: integral_constant.hpp:13
Definition: convolution_parameter.hpp:15
ConvParam(ck_tile::index_t n_dim, ck_tile::index_t group_count, ck_tile::index_t n_batch, ck_tile::index_t n_out_channels, ck_tile::index_t n_in_channels, const std::vector< ck_tile::index_t > &filters_len, const std::vector< ck_tile::index_t > &input_len, const std::vector< ck_tile::index_t > &strides, const std::vector< ck_tile::index_t > &dilations, const std::vector< ck_tile::index_t > &left_pads, const std::vector< ck_tile::index_t > &right_pads)
Definition: convolution_parameter.hpp:16
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17