/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp Source File
grouped_convolution_utils.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
13 {
14  FORWARD,
17 };
18 
25 template <typename InPtr, typename WeiPtr, typename OutPtr, typename CDElementwise>
27 {
30  InPtr in_ptr_,
31  WeiPtr wei_ptr_,
32  const std::vector<const void*> ds_ptr_,
33  OutPtr out_ptr_,
34  index_t k_batch_,
35  CDElementwise elfunc_ = CDElementwise{})
36  : conv::ConvParam(conv_param),
37  in_ptr(in_ptr_),
38  wei_ptr(wei_ptr_),
39  ds_ptr(ds_ptr_),
40  out_ptr(out_ptr_),
41  k_batch(k_batch_),
42  elfunc(elfunc_)
43  {
44  }
45 
46  InPtr in_ptr;
47  WeiPtr wei_ptr;
48  const std::vector<const void*> ds_ptr;
49  OutPtr out_ptr;
51  const CDElementwise elfunc;
52 };
53 
55 
56 template <typename CDElementwise = PassThrough>
62 
63 template <index_t NDimSpatial_,
64  ConvolutionSpecialization ConvSpecialization_,
65  typename InLayout_,
66  typename WeiLayout_,
67  typename DsLayout_,
68  typename OutLayout_,
69  index_t VectorSizeA_ = 1,
70  index_t VectorSizeB_ = 1,
71  index_t VectorSizeC_ = 1,
72  index_t NumGroupsToMerge_ = 1,
73  bool EnableSplitImage_ = false,
74  bool ExplicitGemm_ = false>
76 {
77  private:
78  static constexpr auto generate_implicit_gemm_layout()
79  {
80  return generate_tuple([](auto) { return ck_tile::tensor_layout::gemm::RowMajor{}; },
81  number<DsLayout_::size()>{});
82  }
83 
84  public:
85  // Fixed values for Implicit GEMM
87  {
89  static constexpr ck_tile::index_t TilePartitionerM01 = 4;
90  static constexpr bool kPadM = true;
91  static constexpr bool kPadN = true;
92  static constexpr bool kPadK = true;
93  static constexpr bool TransposeC = false;
94  static constexpr bool FixedVectorSize = true;
95  static constexpr bool UseStructuredSparsity = false;
96  static constexpr bool Persistent = false;
98  };
99  // Compile time parameters
100  static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
101  static constexpr bool EnableSplitImage = EnableSplitImage_;
102  static constexpr bool ExplicitGemm = ExplicitGemm_;
103  static constexpr index_t NDimSpatial = NDimSpatial_;
104  static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
105  using InLayout = InLayout_;
106  using WeiLayout = WeiLayout_;
107  using DsLayout = DsLayout_;
108  using OutLayout = OutLayout_;
109 
110  // Forward Gemm Layouts
114  // Backward Data Gemm Layouts
118  // Backward Weight Gemm Layouts
122 
123  template <GroupedConvDirection Direction>
124  struct GemmLayouts
125  {
126  static_assert(false, "Unsupported direction.");
127  };
128 
129  template <>
131  {
135  };
136 
137  template <>
139  {
143  };
144 
145  template <>
147  {
151  };
152 
153  template <ck_tile::index_t NumWaveGroups = 1>
156  template <ck_tile::index_t NumWaveGroups = 1>
158  true,
159  true,
163  NumWaveGroups>;
164  template <ck_tile::index_t NumWaveGroups = 1>
166  true,
167  true,
171  NumWaveGroups>;
172  static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
173  static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
174  static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_;
175  static constexpr ck_tile::index_t NumDTensor = DsLayout::size();
176  using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
177 };
178 
185 {
189 };
190 
215 template <typename TilePartitioner>
217  ck_tile::index_t num_d_pieces,
218  ck_tile::index_t num_h_pieces,
219  ck_tile::index_t num_w_pieces,
220  ck_tile::index_t base_piece_d,
221  ck_tile::index_t base_piece_h,
222  ck_tile::index_t base_piece_w,
223  ck_tile::index_t total_d,
224  ck_tile::index_t total_h,
225  ck_tile::index_t total_w,
228  ck_tile::index_t total_blocks)
229 {
230  // Unflatten piece index into 3D coordinates (W-major, then H, then D)
231  const ck_tile::index_t w_idx = piece_idx % num_w_pieces;
232  const ck_tile::index_t h_idx = (piece_idx / num_w_pieces) % num_h_pieces;
233  const ck_tile::index_t d_idx = piece_idx / (num_w_pieces * num_h_pieces);
234 
235  // Calculate spatial start positions
236  const ck_tile::index_t w_start = w_idx * base_piece_w;
237  const ck_tile::index_t h_start = h_idx * base_piece_h;
238  const ck_tile::index_t d_start = d_idx * base_piece_d;
239 
240  // Calculate piece sizes (last piece may be larger to cover remainder)
241  const ck_tile::index_t w_size =
242  (w_idx == num_w_pieces - 1) ? (total_w - w_start) : base_piece_w;
243  const ck_tile::index_t h_size =
244  (h_idx == num_h_pieces - 1) ? (total_h - h_start) : base_piece_h;
245  const ck_tile::index_t d_size =
246  (d_idx == num_d_pieces - 1) ? (total_d - d_start) : base_piece_d;
247 
248  // Calculate GEMM dimensions for this piece
249  const ck_tile::index_t piece_gemm_m = N * d_size * h_size * w_size;
250  const ck_tile::index_t piece_gemm_n = K;
251 
252  // Calculate GPU grid size for this piece
253  const ck_tile::index_t piece_grid =
254  ((piece_gemm_m + TilePartitioner::MPerBlock - 1) / TilePartitioner::MPerBlock) *
255  ((piece_gemm_n + TilePartitioner::NPerBlock - 1) / TilePartitioner::NPerBlock);
256 
257  return {
258  total_blocks, total_blocks + piece_grid, d_start, h_start, w_start, d_size, h_size, w_size};
259 }
260 
261 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:44
Definition: cluster_descriptor.hpp:13
GroupedConvDirection
Definition: grouped_convolution_utils.hpp:13
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
ck_tile::element_wise::PassThrough PassThrough
Definition: grouped_convolution_utils.hpp:54
CK_TILE_HOST SplitImagePieceInfo calculate_spatial_piece(ck_tile::index_t piece_idx, ck_tile::index_t num_d_pieces, ck_tile::index_t num_h_pieces, ck_tile::index_t num_w_pieces, ck_tile::index_t base_piece_d, ck_tile::index_t base_piece_h, ck_tile::index_t base_piece_w, ck_tile::index_t total_d, ck_tile::index_t total_h, ck_tile::index_t total_w, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t total_blocks)
Calculate piece information for split-image convolution.
Definition: grouped_convolution_utils.hpp:216
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:27
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:46
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:49
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:47
index_t k_batch
Definition: grouped_convolution_utils.hpp:50
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:48
const CDElementwise elfunc
Definition: grouped_convolution_utils.hpp:51
CK_TILE_HOST GroupedConvHostArgs()=delete
CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param, InPtr in_ptr_, WeiPtr wei_ptr_, const std::vector< const void * > ds_ptr_, OutPtr out_ptr_, index_t k_batch_, CDElementwise elfunc_=CDElementwise{})
Definition: grouped_convolution_utils.hpp:29
Definition: grouped_convolution_utils.hpp:87
static constexpr bool Persistent
Definition: grouped_convolution_utils.hpp:96
static constexpr bool UseStructuredSparsity
Definition: grouped_convolution_utils.hpp:95
static constexpr bool kPadM
Definition: grouped_convolution_utils.hpp:90
static constexpr ck_tile::index_t TilePartitionerGroupNum
Definition: grouped_convolution_utils.hpp:88
static constexpr bool kPadK
Definition: grouped_convolution_utils.hpp:92
static constexpr bool kPadN
Definition: grouped_convolution_utils.hpp:91
static constexpr bool FixedVectorSize
Definition: grouped_convolution_utils.hpp:94
static constexpr bool TransposeC
Definition: grouped_convolution_utils.hpp:93
static constexpr ck_tile::index_t TilePartitionerM01
Definition: grouped_convolution_utils.hpp:89
Definition: grouped_convolution_utils.hpp:125
Definition: grouped_convolution_utils.hpp:76
ck_tile::tensor_layout::gemm::RowMajor AsLayoutFwd
Definition: grouped_convolution_utils.hpp:111
OutLayout_ OutLayout
Definition: grouped_convolution_utils.hpp:108
DsLayout_ DsLayout
Definition: grouped_convolution_utils.hpp:107
ck_tile::tensor_layout::gemm::ColumnMajor BsLayoutFwd
Definition: grouped_convolution_utils.hpp:112
static constexpr ck_tile::index_t NumDTensor
Definition: grouped_convolution_utils.hpp:175
static constexpr ck_tile::index_t VectorSizeC
Definition: grouped_convolution_utils.hpp:174
WeiLayout_ WeiLayout
Definition: grouped_convolution_utils.hpp:106
static constexpr index_t NDimSpatial
Definition: grouped_convolution_utils.hpp:103
InLayout_ InLayout
Definition: grouped_convolution_utils.hpp:105
ck_tile::tensor_layout::gemm::RowMajor CLayoutFwd
Definition: grouped_convolution_utils.hpp:113
static constexpr index_t NumGroupsToMerge
Definition: grouped_convolution_utils.hpp:100
ck_tile::tensor_layout::gemm::RowMajor BsLayoutBwdData
Definition: grouped_convolution_utils.hpp:116
ck_tile::tensor_layout::gemm::RowMajor AsLayoutBwdData
Definition: grouped_convolution_utils.hpp:115
static constexpr ck_tile::index_t VectorSizeB
Definition: grouped_convolution_utils.hpp:173
static constexpr bool EnableSplitImage
Definition: grouped_convolution_utils.hpp:101
ck_tile::tensor_layout::gemm::RowMajor CLayoutBwdData
Definition: grouped_convolution_utils.hpp:117
decltype(generate_implicit_gemm_layout()) ImplicitGemmDsLayout
Definition: grouped_convolution_utils.hpp:176
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_utils.hpp:104
static constexpr ck_tile::index_t VectorSizeA
Definition: grouped_convolution_utils.hpp:172
ck_tile::tensor_layout::gemm::RowMajor BsLayoutBwdWeight
Definition: grouped_convolution_utils.hpp:120
ck_tile::tensor_layout::gemm::RowMajor CLayoutBwdWeight
Definition: grouped_convolution_utils.hpp:121
static constexpr bool ExplicitGemm
Definition: grouped_convolution_utils.hpp:102
ck_tile::tensor_layout::gemm::ColumnMajor AsLayoutBwdWeight
Definition: grouped_convolution_utils.hpp:119
Helper struct for split-image piece information.
Definition: grouped_convolution_utils.hpp:185
ck_tile::index_t block_end
GPU block range for this piece.
Definition: grouped_convolution_utils.hpp:186
ck_tile::index_t d_size
Definition: grouped_convolution_utils.hpp:188
ck_tile::index_t d_start
Definition: grouped_convolution_utils.hpp:187
ck_tile::index_t w_start
Spatial start coordinates (output space)
Definition: grouped_convolution_utils.hpp:187
ck_tile::index_t h_size
Definition: grouped_convolution_utils.hpp:188
ck_tile::index_t h_start
Definition: grouped_convolution_utils.hpp:187
ck_tile::index_t w_size
Spatial dimensions of this piece.
Definition: grouped_convolution_utils.hpp:188
ck_tile::index_t block_start
Definition: grouped_convolution_utils.hpp:186
Definition: tile_gemm_traits.hpp:18
Definition: integral_constant.hpp:13
Definition: convolution_parameter.hpp:15
ConvParam(ck_tile::index_t n_dim, ck_tile::index_t group_count, ck_tile::index_t n_batch, ck_tile::index_t n_out_channels, ck_tile::index_t n_in_channels, const std::vector< ck_tile::index_t > &filters_len, const std::vector< ck_tile::index_t > &input_len, const std::vector< ck_tile::index_t > &strides, const std::vector< ck_tile::index_t > &dilations, const std::vector< ck_tile::index_t > &left_pads, const std::vector< ck_tile::index_t > &right_pads)
Definition: convolution_parameter.hpp:16
Definition: unary_element_wise_operation.hpp:449
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17