Composable Kernel — include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp (source listing)
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck {
10 namespace tensor_operation {
11 namespace device {
12 
13 // 1d
14 template <typename InLayout, typename WeiLayout, typename OutLayout>
15 constexpr bool is_NWGC_GKXC_NWGK()
16 {
17  return is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
18  is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
19  is_same_v<OutLayout, tensor_layout::convolution::NWGK>;
20 }
21 
22 template <typename InLayout, typename WeiLayout, typename OutLayout>
23 constexpr bool is_GNWC_GKXC_GNWK()
24 {
25  return is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
26  is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
27  is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
28 }
29 
30 template <typename InLayout, typename WeiLayout, typename OutLayout>
31 constexpr bool is_NGCW_GKXC_NGKW()
32 {
33  return is_same_v<InLayout, tensor_layout::convolution::NGCW> &&
34  is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
35  is_same_v<OutLayout, tensor_layout::convolution::NGKW>;
36 }
37 
38 // 2d
39 template <typename InLayout, typename WeiLayout, typename OutLayout>
40 constexpr bool is_NHWGC_GKYXC_NHWGK()
41 {
42  return is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
43  is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
44  is_same_v<OutLayout, tensor_layout::convolution::NHWGK>;
45 }
46 
47 template <typename InLayout, typename WeiLayout, typename OutLayout>
48 constexpr bool is_GNHWC_GKYXC_GNHWK()
49 {
50  return is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
51  is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
52  is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
53 }
54 
55 template <typename InLayout, typename WeiLayout, typename OutLayout>
56 constexpr bool is_NGCHW_GKYXC_NGKHW()
57 {
58  return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
59  is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
60  is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
61 }
62 
63 template <typename InLayout, typename WeiLayout, typename OutLayout>
64 constexpr bool is_NGCHW_GKCYX_NGKHW()
65 {
66  return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
67  is_same_v<WeiLayout, tensor_layout::convolution::GKCYX> &&
68  is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
69 }
70 
71 template <typename InLayout, typename WeiLayout, typename OutLayout>
72 constexpr bool is_NGCHW_NGKHW()
73 {
74  return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
75  is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
76 }
77 
78 // 3d
79 template <typename InLayout, typename WeiLayout, typename OutLayout>
80 constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
81 {
82  return is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
83  is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
84  is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
85 }
86 
87 template <typename InLayout, typename WeiLayout, typename OutLayout>
88 constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
89 {
90  return is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
91  is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
92  is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
93 }
94 
95 template <typename InLayout, typename WeiLayout, typename OutLayout>
96 constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
97 {
98  return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
99  is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
100  is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
101 }
102 
103 template <typename InLayout, typename WeiLayout, typename OutLayout>
104 constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
105 {
106  return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
107  is_same_v<WeiLayout, tensor_layout::convolution::GKCZYX> &&
108  is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
109 }
110 
111 template <typename InLayout, typename WeiLayout, typename OutLayout>
112 constexpr bool is_NGCDHW_NGKDHW()
113 {
114  return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
115  is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
116 }
117 
118 template <typename InLayout, typename WeiLayout, typename OutLayout>
120 {
121  return is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
122  is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
123  is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>();
124 }
125 
126 template <typename InLayout, typename WeiLayout, typename OutLayout>
128 {
129  return is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>() ||
130  is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>() ||
131  is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
132 }
133 
134 template <typename InLayout, typename WeiLayout, typename OutLayout>
136 {
137  return is_NGCW_GKXC_NGKW<InLayout, WeiLayout, OutLayout>() ||
138  is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
139  is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>();
140 }
141 
142 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
143 struct ComputePtrOffsetOfStridedBatch
144 {
145 };
146 
147 template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
148 struct ComputePtrOffsetOfStridedBatch<NumATensor,
149  NumBTensor,
150  NumDTensor,
151  enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
152 {
154 
156  Array<long_index_t, NumBTensor>& BatchStrideBs,
157  Array<long_index_t, NumDTensor>& BatchStrideDs,
158  long_index_t BatchStrideE)
159  : BatchStrideA_(BatchStrideAs),
160  BatchStrideB_(BatchStrideBs),
161  BatchStrideDs_(BatchStrideDs),
162  BatchStrideE_(BatchStrideE)
163  {
164  }
165 
166  __host__ __device__ constexpr auto GetAsPtrOffset(index_t g_idx) const
167  {
170  [&](auto i) { as_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideA_[i]; });
171  return as_offset;
172  }
173 
174  __host__ __device__ constexpr auto GetBsPtrOffset(index_t g_idx) const
175  {
178  [&](auto i) { bs_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideB_[i]; });
179  return bs_offset;
180  }
181 
182  __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
183  {
186  [&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
187  return ds_offset;
188  }
189 
190  [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
191  {
192  return static_cast<long_index_t>(g_idx) * BatchStrideE_;
193  }
194 
195  // alias for kernels without multiple D
196  [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
197  {
198  return static_cast<long_index_t>(g_idx) * BatchStrideE_;
199  }
200 
205  long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
206 };
207 
208 template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
209 struct ComputePtrOffsetOfStridedBatch<NumATensor,
210  NumBTensor,
211  NumDTensor,
212  enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
213 {
215 
217  long_index_t BatchStrideB,
218  Array<long_index_t, NumDTensor> BatchStrideDs,
219  long_index_t BatchStrideE)
220  : BatchStrideA_(BatchStrideA),
221  BatchStrideB_(BatchStrideB),
222  BatchStrideDs_(BatchStrideDs),
223  BatchStrideE_(BatchStrideE)
224  {
225  }
226 
227  __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
228  {
229  return static_cast<long_index_t>(g_idx) * BatchStrideA_;
230  }
231 
232  __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
233  {
234  return static_cast<long_index_t>(g_idx) * BatchStrideB_;
235  }
236 
237  __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
238  {
241  [&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
242  return ds_offset;
243  }
244 
245  [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
246  {
247  return static_cast<long_index_t>(g_idx) * BatchStrideE_;
248  }
249 
250  // alias for kernels without multiple D
251  [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
252  {
253  return static_cast<long_index_t>(g_idx) * BatchStrideE_;
254  }
255 
260  long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
261 };
262 
263 template <bool isTuple, typename Tensors>
264 constexpr static auto GetNumABTensors()
265 {
266  if constexpr(isTuple)
267  {
268  return Number<Tensors::Size()>{};
269  }
270  else
271  {
272  return Number<1>{};
273  }
274 }
275 
276 template <bool isTuple, typename GridwiseGemm, typename DataType>
277 constexpr static auto GetAGridPointer()
278 {
279  if constexpr(isTuple)
280  {
281  return typename GridwiseGemm::AsGridPointer{};
282  }
283  else
284  {
285  return Tuple<const DataType*>{};
286  }
287 }
288 
289 template <bool isTuple, typename GridwiseGemm, typename DataType>
290 constexpr static auto GetBGridPointer()
291 {
292  if constexpr(isTuple)
293  {
294  return typename GridwiseGemm::BsGridPointer{};
295  }
296  else
297  {
298  return Tuple<const DataType*>{};
299  }
300 }
301 
302 template <bool isTuple, typename Id, typename Type>
303 constexpr static auto UnpackDataType()
304 {
305  if constexpr(isTuple)
306  {
307  // unpack if tuple
308  return tuple_element_t<Id{}, Type>{};
309  }
310  else
311  {
312  // if no, return Type
313  return Type{};
314  }
315 }
316 
317 } // namespace device
318 } // namespace tensor_operation
319 } // namespace ck
index_t BatchStrideC_
Definition: device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:56
index_t BatchStrideB_
Definition: device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:55
index_t BatchStrideA_
Definition: device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:54
Array< ck::index_t, NumDTensor > BatchStrideDs_
Definition: device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:84
index_t BatchStrideE_
Definition: device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:85
constexpr bool is_NWGC_GKXC_NWGK()
Definition: device_grouped_conv_utils.hpp:15
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
Definition: device_grouped_conv_utils.hpp:119
constexpr bool is_GNWC_GKXC_GNWK()
Definition: device_grouped_conv_utils.hpp:23
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
Definition: device_grouped_conv_utils.hpp:88
constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial()
Definition: device_grouped_conv_utils.hpp:135
constexpr bool is_NHWGC_GKYXC_NHWGK()
Definition: device_grouped_conv_utils.hpp:40
constexpr bool is_NGCHW_GKYXC_NGKHW()
Definition: device_grouped_conv_utils.hpp:56
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
Definition: device_grouped_conv_utils.hpp:80
constexpr bool is_NGCDHW_NGKDHW()
Definition: device_grouped_conv_utils.hpp:112
constexpr bool is_NGCW_GKXC_NGKW()
Definition: device_grouped_conv_utils.hpp:31
constexpr bool is_NGCHW_GKCYX_NGKHW()
Definition: device_grouped_conv_utils.hpp:64
constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
Definition: device_grouped_conv_utils.hpp:127
constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
Definition: device_grouped_conv_utils.hpp:96
constexpr bool is_GNHWC_GKYXC_GNHWK()
Definition: device_grouped_conv_utils.hpp:48
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
Definition: device_grouped_conv_utils.hpp:104
constexpr bool is_NGCHW_NGKHW()
Definition: device_grouped_conv_utils.hpp:72
Definition: ck.hpp:267
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
int64_t long_index_t
Definition: ck.hpp:299
int32_t index_t
Definition: ck.hpp:298
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
Type
Type of JSON value.
Definition: rapidjson.h:729
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
__host__ constexpr __device__ long_index_t GetEPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:245
__host__ constexpr __device__ long_index_t GetBPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:232
__host__ constexpr __device__ long_index_t GetAPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:227
__host__ constexpr __device__ long_index_t GetCPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:251
ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA, long_index_t BatchStrideB, Array< long_index_t, NumDTensor > BatchStrideDs, long_index_t BatchStrideE)
Definition: device_grouped_conv_utils.hpp:216
__host__ constexpr __device__ auto GetDsPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:237
ComputePtrOffsetOfStridedBatch(Array< long_index_t, NumATensor > &BatchStrideAs, Array< long_index_t, NumBTensor > &BatchStrideBs, Array< long_index_t, NumDTensor > &BatchStrideDs, long_index_t BatchStrideE)
Definition: device_grouped_conv_utils.hpp:155
__host__ constexpr __device__ long_index_t GetEPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:190
__host__ constexpr __device__ long_index_t GetCPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:196
__host__ constexpr __device__ auto GetAsPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:166
__host__ constexpr __device__ auto GetBsPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:174
__host__ constexpr __device__ auto GetDsPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:182