/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp Source File
grouped_convolution_utils.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
18 template <typename InPtr, typename WeiPtr, typename OutPtr, typename CDElementwise>
20 {
23  InPtr in_ptr_,
24  WeiPtr wei_ptr_,
25  const std::vector<const void*> ds_ptr_,
26  OutPtr out_ptr_,
27  index_t k_batch_,
28  CDElementwise elfunc_ = CDElementwise{})
29  : conv::ConvParam(conv_param),
30  in_ptr(in_ptr_),
31  wei_ptr(wei_ptr_),
32  ds_ptr(ds_ptr_),
33  out_ptr(out_ptr_),
34  k_batch(k_batch_),
35  elfunc(elfunc_)
36  {
37  }
38 
39  InPtr in_ptr;
40  WeiPtr wei_ptr;
41  const std::vector<const void*> ds_ptr;
42  OutPtr out_ptr;
44  const CDElementwise elfunc;
45 };
46 
48 
49 template <typename CDElementwise = PassThrough>
55 
56 template <index_t NDimSpatial_,
57  ConvolutionSpecialization ConvSpecialization_,
58  typename InLayout_,
59  typename WeiLayout_,
60  typename DsLayout_,
61  typename OutLayout_,
62  index_t VectorSizeA_ = 1,
63  index_t VectorSizeB_ = 1,
64  index_t VectorSizeC_ = 1,
65  index_t NumGroupsToMerge_ = 1,
66  typename CDElementwise_ = PassThrough>
68 {
69  private:
70  static constexpr auto generate_implicit_gemm_layout()
71  {
72  return generate_tuple([](auto) { return ck_tile::tensor_layout::gemm::RowMajor{}; },
73  number<DsLayout_::size()>{});
74  }
75 
76  public:
77  static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
78  static constexpr index_t NDimSpatial = NDimSpatial_;
79  static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
80  using InLayout = InLayout_;
81  using WeiLayout = WeiLayout_;
82  using DsLayout = DsLayout_;
83  using OutLayout = OutLayout_;
84  using CDElementwise = CDElementwise_;
86  TileGemmTraits<true,
87  true,
88  true,
93  TileGemmTraits<true,
94  true,
95  true,
100  TileGemmTraits<true,
101  true,
102  true,
106  static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
107  static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
108  static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_;
109  static constexpr index_t NumDTensor = DsLayout::size();
110  using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
111 };
112 
119 {
123 };
124 
149 template <typename TilePartitioner>
151  ck_tile::index_t num_d_pieces,
152  ck_tile::index_t num_h_pieces,
153  ck_tile::index_t num_w_pieces,
154  ck_tile::index_t base_piece_d,
155  ck_tile::index_t base_piece_h,
156  ck_tile::index_t base_piece_w,
157  ck_tile::index_t total_d,
158  ck_tile::index_t total_h,
159  ck_tile::index_t total_w,
162  ck_tile::index_t total_blocks)
163 {
164  // Unflatten piece index into 3D coordinates (W-major, then H, then D)
165  const ck_tile::index_t w_idx = piece_idx % num_w_pieces;
166  const ck_tile::index_t h_idx = (piece_idx / num_w_pieces) % num_h_pieces;
167  const ck_tile::index_t d_idx = piece_idx / (num_w_pieces * num_h_pieces);
168 
169  // Calculate spatial start positions
170  const ck_tile::index_t w_start = w_idx * base_piece_w;
171  const ck_tile::index_t h_start = h_idx * base_piece_h;
172  const ck_tile::index_t d_start = d_idx * base_piece_d;
173 
174  // Calculate piece sizes (last piece may be larger to cover remainder)
175  const ck_tile::index_t w_size =
176  (w_idx == num_w_pieces - 1) ? (total_w - w_start) : base_piece_w;
177  const ck_tile::index_t h_size =
178  (h_idx == num_h_pieces - 1) ? (total_h - h_start) : base_piece_h;
179  const ck_tile::index_t d_size =
180  (d_idx == num_d_pieces - 1) ? (total_d - d_start) : base_piece_d;
181 
182  // Calculate GEMM dimensions for this piece
183  const ck_tile::index_t piece_gemm_m = N * d_size * h_size * w_size;
184  const ck_tile::index_t piece_gemm_n = K;
185 
186  // Calculate GPU grid size for this piece
187  const ck_tile::index_t piece_grid =
188  ((piece_gemm_m + TilePartitioner::MPerBlock - 1) / TilePartitioner::MPerBlock) *
189  ((piece_gemm_n + TilePartitioner::NPerBlock - 1) / TilePartitioner::NPerBlock);
190 
191  return {
192  total_blocks, total_blocks + piece_grid, d_start, h_start, w_start, d_size, h_size, w_size};
193 }
194 
195 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
ck_tile::element_wise::PassThrough PassThrough
Definition: grouped_convolution_utils.hpp:47
CK_TILE_HOST SplitImagePieceInfo calculate_spatial_piece(ck_tile::index_t piece_idx, ck_tile::index_t num_d_pieces, ck_tile::index_t num_h_pieces, ck_tile::index_t num_w_pieces, ck_tile::index_t base_piece_d, ck_tile::index_t base_piece_h, ck_tile::index_t base_piece_w, ck_tile::index_t total_d, ck_tile::index_t total_h, ck_tile::index_t total_w, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t total_blocks)
Calculate piece information for split-image convolution.
Definition: grouped_convolution_utils.hpp:150
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:20
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:39
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:42
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:40
index_t k_batch
Definition: grouped_convolution_utils.hpp:43
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:41
const CDElementwise elfunc
Definition: grouped_convolution_utils.hpp:44
CK_TILE_HOST GroupedConvHostArgs()=delete
CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param, InPtr in_ptr_, WeiPtr wei_ptr_, const std::vector< const void * > ds_ptr_, OutPtr out_ptr_, index_t k_batch_, CDElementwise elfunc_=CDElementwise{})
Definition: grouped_convolution_utils.hpp:22
Definition: grouped_convolution_utils.hpp:68
static constexpr index_t NumDTensor
Definition: grouped_convolution_utils.hpp:109
static constexpr ck_tile::index_t VectorSizeB
Definition: grouped_convolution_utils.hpp:107
OutLayout_ OutLayout
Definition: grouped_convolution_utils.hpp:83
static constexpr index_t NumGroupsToMerge
Definition: grouped_convolution_utils.hpp:77
static constexpr ck_tile::index_t VectorSizeC
Definition: grouped_convolution_utils.hpp:108
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_utils.hpp:79
DsLayout_ DsLayout
Definition: grouped_convolution_utils.hpp:82
static constexpr index_t NDimSpatial
Definition: grouped_convolution_utils.hpp:78
decltype(generate_implicit_gemm_layout()) ImplicitGemmDsLayout
Definition: grouped_convolution_utils.hpp:110
WeiLayout_ WeiLayout
Definition: grouped_convolution_utils.hpp:81
InLayout_ InLayout
Definition: grouped_convolution_utils.hpp:80
static constexpr ck_tile::index_t VectorSizeA
Definition: grouped_convolution_utils.hpp:106
CDElementwise_ CDElementwise
Definition: grouped_convolution_utils.hpp:84
Helper struct for split-image piece information.
Definition: grouped_convolution_utils.hpp:119
ck_tile::index_t block_end
GPU block range for this piece.
Definition: grouped_convolution_utils.hpp:120
ck_tile::index_t d_size
Definition: grouped_convolution_utils.hpp:122
ck_tile::index_t d_start
Definition: grouped_convolution_utils.hpp:121
ck_tile::index_t w_start
Spatial start coordinates (output space)
Definition: grouped_convolution_utils.hpp:121
ck_tile::index_t h_size
Definition: grouped_convolution_utils.hpp:122
ck_tile::index_t h_start
Definition: grouped_convolution_utils.hpp:121
ck_tile::index_t w_size
Spatial dimensions of this piece.
Definition: grouped_convolution_utils.hpp:122
ck_tile::index_t block_start
Definition: grouped_convolution_utils.hpp:120
Definition: tile_gemm_traits.hpp:18
Definition: integral_constant.hpp:13
Definition: convolution_parameter.hpp:15
ConvParam(ck_tile::index_t n_dim, ck_tile::index_t group_count, ck_tile::index_t n_batch, ck_tile::index_t n_out_channels, ck_tile::index_t n_in_channels, const std::vector< ck_tile::index_t > &filters_len, const std::vector< ck_tile::index_t > &input_len, const std::vector< ck_tile::index_t > &strides, const std::vector< ck_tile::index_t > &dilations, const std::vector< ck_tile::index_t > &left_pads, const std::vector< ck_tile::index_t > &right_pads)
Definition: convolution_parameter.hpp:16
Definition: unary_element_wise_operation.hpp:437
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17