Composable Kernel: include/ck_tile/ops/pooling/kernel/pool_kernel.hpp Source File

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include <type_traits>

namespace ck_tile {

/// @brief Host arguments for pooling operations.
template <typename TensorShape, typename WindowShape>
struct PoolHostArgs
{
    CK_TILE_HOST PoolHostArgs(const void* input_ptr_,
                              void* output_ptr_,
                              void* output_index_ptr_,
                              TensorShape input_shape_,
                              TensorShape output_shape_,
                              TensorShape input_strides_,
                              TensorShape output_strides_,
                              WindowShape window_lengths_,
                              WindowShape window_strides_,
                              WindowShape window_dilations_,
                              WindowShape input_left_pads_,
                              WindowShape input_right_pads_)
        : input_ptr(input_ptr_),
          output_ptr(output_ptr_),
          output_index_ptr(output_index_ptr_),
          input_shape(input_shape_),
          output_shape(output_shape_),
          input_strides(input_strides_),
          output_strides(output_strides_),
          window_lengths(window_lengths_),
          window_strides(window_strides_),
          window_dilations(window_dilations_),
          input_left_pads(input_left_pads_),
          input_right_pads(input_right_pads_)
    {
    }

    const void* input_ptr;
    void* output_ptr;
    void* output_index_ptr;

    TensorShape input_shape;
    TensorShape output_shape;
    TensorShape input_strides;
    TensorShape output_strides;
    WindowShape window_lengths;
    WindowShape window_strides;
    WindowShape window_dilations;
    WindowShape input_left_pads;
    WindowShape input_right_pads;
};

/// @brief Kernel arguments for pooling operations.
template <typename TensorShape, typename WindowShape>
struct PoolKernelArgs
{
    const void* input_ptr;
    void* output_ptr;
    void* output_index_ptr;
    TensorShape input_shape;
    TensorShape output_shape;
    TensorShape input_strides;
    TensorShape output_strides;
    WindowShape window_lengths;
    WindowShape window_strides;
    WindowShape window_dilations;
    WindowShape input_left_pads;
    WindowShape input_right_pads;
};

template <typename Problem_, typename Policy_ = PoolDefaultPolicy>
struct PoolKernel
{
    using Problem = ck_tile::remove_cvref_t<Problem_>;
    using Policy  = ck_tile::remove_cvref_t<Policy_>;

    using InDataType      = ck_tile::remove_cvref_t<typename Problem::InDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using OutDataType     = ck_tile::remove_cvref_t<typename Problem::OutDataType>;
    using IndexDataType   = ck_tile::remove_cvref_t<typename Problem::IndexDataType>;

    static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;

    CK_TILE_HOST static constexpr auto BlockSize()
    {
        // Halve the thread count on wave32 targets so each block launches the
        // same number of waves as on wave64.
        return is_wave32() ? kBlockSize / 2 : kBlockSize;
    }

    template <typename TensorShape, typename WindowShape>
    CK_TILE_DEVICE static auto MakeTensorView2D(PoolKernelArgs<TensorShape, WindowShape> kargs)
    {
        using S = typename Problem::BlockShape;

        // Compile-time validation for 2D pooling
        static_assert(TensorShape::size() == 4, "2D pooling requires 4D input tensor (N,H,W,C)");
        static_assert(WindowShape::size() == 2, "2D pooling requires 2D window shape (Y,X)");

        // Extract dimension values
        const index_t N = kargs.input_shape.at(number<0>{});
        const index_t H = kargs.input_shape.at(number<1>{});
        const index_t W = kargs.input_shape.at(number<2>{});
        const index_t C = kargs.input_shape.at(number<3>{});

        const index_t No = kargs.output_shape.at(number<0>{});
        const index_t Ho = kargs.output_shape.at(number<1>{});
        const index_t Wo = kargs.output_shape.at(number<2>{});
        const index_t Co = kargs.output_shape.at(number<3>{});

        const index_t Y = kargs.window_lengths.at(number<0>{});
        const index_t X = kargs.window_lengths.at(number<1>{});

        const index_t WindowStrideH = kargs.window_strides.at(number<0>{});
        const index_t WindowStrideW = kargs.window_strides.at(number<1>{});

        const index_t WindowDilationH = kargs.window_dilations.at(number<0>{});
        const index_t WindowDilationW = kargs.window_dilations.at(number<1>{});

        const index_t InLeftPadH = kargs.input_left_pads.at(number<0>{});
        const index_t InLeftPadW = kargs.input_left_pads.at(number<1>{});

        const index_t InRightPadH = kargs.input_right_pads.at(number<0>{});
        const index_t InRightPadW = kargs.input_right_pads.at(number<1>{});

        // Flatten pooling into an M x K reduction: output positions (M) x window elements (K),
        // each padded up to a multiple of the block tile.
        const index_t MRaw = N * Ho * Wo * C;
        const index_t KRaw = Y * X;
        const index_t MPad = integer_least_multiple(MRaw, S::Block_M) - MRaw;
        const index_t KPad = integer_least_multiple(KRaw, S::Block_N) - KRaw;

        auto reduce_op = typename Problem::ReduceOp{};

        // Create input descriptor with all transformations
        auto in_desc = make_naive_tensor_descriptor(kargs.input_shape, kargs.input_strides);

        // Apply spatial padding to input descriptor
        const auto padded_in_desc = transform_tensor_descriptor(
            in_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(H, InLeftPadH, InRightPadH),
                       make_pad_transform(W, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));

        // Create sliding windows by embedding pooling windows into descriptor
        const auto embed_in_desc = transform_tensor_descriptor(
            padded_in_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
                make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
                make_pass_through_transform(C)),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
            make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));

        // Reshape into 2D matrix: output positions (M) x pooling window elements (K)
        const auto merged_embed_in_desc =
            transform_tensor_descriptor(embed_in_desc,
                                        make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)),
                                                   make_merge_transform(make_tuple(Y, X))),
                                        make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3>{}),
                                        make_tuple(sequence<0>{}, sequence<1>{}));

        // Pad M and K up to multiples of the block tile
        const auto in_desc_padded = transform_tensor_descriptor(
            merged_embed_in_desc,
            make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
            make_tuple(sequence<0>{}, sequence<1>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        // Create output descriptor with transformations
        auto out_desc = make_naive_tensor_descriptor(kargs.output_shape, kargs.output_strides);

        const auto merged_out_desc = transform_tensor_descriptor(
            out_desc,
            make_tuple(make_merge_transform(make_tuple(No, Ho, Wo, Co))),
            make_tuple(sequence<0, 1, 2, 3>{}),
            make_tuple(sequence<0>{}));

        const auto out_desc_padded =
            transform_tensor_descriptor(merged_out_desc,
                                        make_tuple(make_right_pad_transform(MRaw, MPad)),
                                        make_tuple(sequence<0>{}),
                                        make_tuple(sequence<0>{}));

        // Now create buffer views and tensor views with the fully transformed descriptors
        const InDataType in_identity =
            type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
        const OutDataType out_identity =
            type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());

        auto in_buffer_view = make_buffer_view<address_space_enum::global>(
            static_cast<const InDataType*>(kargs.input_ptr),
            in_desc.get_element_space_size(),
            in_identity);
        const auto in_tensor_padded =
            tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
                                                                            in_desc_padded};

        auto out_buffer_view = make_buffer_view<address_space_enum::global>(
            static_cast<OutDataType*>(kargs.output_ptr),
            out_desc.get_element_space_size(),
            out_identity);
        const auto out_tensor_padded =
            tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
                                                                              out_desc_padded};

        if constexpr(Problem::kOutputIndex)
        {
            auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
                static_cast<IndexDataType*>(kargs.output_index_ptr),
                out_desc.get_element_space_size(),
                IndexDataType(-1));
            const auto out_index_tensor_padded =
                tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
                    out_index_buffer_view, out_desc_padded};

            return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
        }
        else
        {
            // Return a dummy tensor for the third element when index output is not needed
            return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
        }
    }

    template <typename TensorShape, typename WindowShape>
    CK_TILE_DEVICE static auto MakeTensorView3D(PoolKernelArgs<TensorShape, WindowShape> kargs)
    {
        using S = typename Problem::BlockShape;

        // Compile-time validation for 3D pooling
        static_assert(TensorShape::size() == 5, "3D pooling requires 5D input tensor (N,D,H,W,C)");
        static_assert(WindowShape::size() == 3, "3D pooling requires 3D window shape (Z,Y,X)");

        // Extract dimension values
        const index_t N = kargs.input_shape.at(number<0>{});
        const index_t D = kargs.input_shape.at(number<1>{});
        const index_t H = kargs.input_shape.at(number<2>{});
        const index_t W = kargs.input_shape.at(number<3>{});
        const index_t C = kargs.input_shape.at(number<4>{});

        const index_t No = kargs.output_shape.at(number<0>{});
        const index_t Do = kargs.output_shape.at(number<1>{});
        const index_t Ho = kargs.output_shape.at(number<2>{});
        const index_t Wo = kargs.output_shape.at(number<3>{});
        const index_t Co = kargs.output_shape.at(number<4>{});

        const index_t Z = kargs.window_lengths.at(number<0>{});
        const index_t Y = kargs.window_lengths.at(number<1>{});
        const index_t X = kargs.window_lengths.at(number<2>{});

        const index_t WindowStrideD = kargs.window_strides.at(number<0>{});
        const index_t WindowStrideH = kargs.window_strides.at(number<1>{});
        const index_t WindowStrideW = kargs.window_strides.at(number<2>{});

        const index_t WindowDilationD = kargs.window_dilations.at(number<0>{});
        const index_t WindowDilationH = kargs.window_dilations.at(number<1>{});
        const index_t WindowDilationW = kargs.window_dilations.at(number<2>{});

        const index_t InLeftPadD = kargs.input_left_pads.at(number<0>{});
        const index_t InLeftPadH = kargs.input_left_pads.at(number<1>{});
        const index_t InLeftPadW = kargs.input_left_pads.at(number<2>{});

        const index_t InRightPadD = kargs.input_right_pads.at(number<0>{});
        const index_t InRightPadH = kargs.input_right_pads.at(number<1>{});
        const index_t InRightPadW = kargs.input_right_pads.at(number<2>{});

        // Flatten pooling into an M x K reduction: output positions (M) x window elements (K),
        // each padded up to a multiple of the block tile.
        const index_t MRaw = N * Do * Ho * Wo * C;
        const index_t KRaw = Z * Y * X;
        const index_t MPad = integer_least_multiple(MRaw, S::Block_M) - MRaw;
        const index_t KPad = integer_least_multiple(KRaw, S::Block_N) - KRaw;

        auto reduce_op = typename Problem::ReduceOp{};

        // Create input descriptor with all transformations
        auto in_desc = make_naive_tensor_descriptor(kargs.input_shape, kargs.input_strides);

        // Apply spatial padding to input descriptor (all 3D dimensions)
        const auto padded_in_desc = transform_tensor_descriptor(
            in_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(D, InLeftPadD, InRightPadD),
                       make_pad_transform(H, InLeftPadH, InRightPadH),
                       make_pad_transform(W, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));

        // Create 3D sliding windows by embedding pooling windows into descriptor
        const auto embed_in_desc = transform_tensor_descriptor(
            padded_in_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)),
                make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
                make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
                make_pass_through_transform(C)),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
            make_tuple(sequence<0>{},
                       sequence<1, 2>{},
                       sequence<3, 4>{},
                       sequence<5, 6>{},
                       sequence<7>{}));

        // Reshape into 2D matrix: output positions (M) x pooling window elements (K)
        const auto merged_embed_in_desc = transform_tensor_descriptor(
            embed_in_desc,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C)),
                       make_merge_transform(make_tuple(Z, Y, X))),
            make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        // Pad M and K up to multiples of the block tile
        const auto in_desc_padded = transform_tensor_descriptor(
            merged_embed_in_desc,
            make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
            make_tuple(sequence<0>{}, sequence<1>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        // Create output descriptor with transformations
        auto out_desc = make_naive_tensor_descriptor(kargs.output_shape, kargs.output_strides);

        const auto merged_out_desc = transform_tensor_descriptor(
            out_desc,
            make_tuple(make_merge_transform(make_tuple(No, Do, Ho, Wo, Co))),
            make_tuple(sequence<0, 1, 2, 3, 4>{}),
            make_tuple(sequence<0>{}));

        const auto out_desc_padded =
            transform_tensor_descriptor(merged_out_desc,
                                        make_tuple(make_right_pad_transform(MRaw, MPad)),
                                        make_tuple(sequence<0>{}),
                                        make_tuple(sequence<0>{}));

        // Now create buffer views and tensor views with the fully transformed descriptors
        const InDataType in_identity =
            type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
        const OutDataType out_identity =
            type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());

        auto in_buffer_view = make_buffer_view<address_space_enum::global>(
            static_cast<const InDataType*>(kargs.input_ptr),
            in_desc.get_element_space_size(),
            in_identity);
        const auto in_tensor_padded =
            tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
                                                                            in_desc_padded};

        auto out_buffer_view = make_buffer_view<address_space_enum::global>(
            static_cast<OutDataType*>(kargs.output_ptr),
            out_desc.get_element_space_size(),
            out_identity);
        const auto out_tensor_padded =
            tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
                                                                              out_desc_padded};

        if constexpr(Problem::kOutputIndex)
        {
            auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
                static_cast<IndexDataType*>(kargs.output_index_ptr),
                out_desc.get_element_space_size(),
                IndexDataType(-1));
            const auto out_index_tensor_padded =
                tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
                    out_index_buffer_view, out_desc_padded};

            return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
        }
        else
        {
            // Return a dummy tensor for the third element when index output is not needed
            return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
        }
    }

    public:
    template <typename TensorShape, typename WindowShape>
    CK_TILE_DEVICE void operator()(PoolKernelArgs<TensorShape, WindowShape> kargs) const
    {
        using S = typename Problem::BlockShape;

        // Compile-time validation for supported window dimensions
        static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
                      "Only 2D and 3D pooling operations are supported");

        const auto iM = get_block_id() * S::Block_M;

        // Get tensors based on dimensionality
        auto [in_tensor_padded, out_tensor_padded, out_index_tensor_padded] = [&]() {
            if constexpr(WindowShape::size() == 2)
                return MakeTensorView2D(kargs);
            else if constexpr(WindowShape::size() == 3)
                return MakeTensorView3D(kargs);
            else
                static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
                              "Unsupported WindowShape rank: only 2D or 3D pooling is supported");
        }();

        auto reduce_op = typename Problem::ReduceOp{};

        auto x_window = make_tile_window(in_tensor_padded,
                                         make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
                                         {iM, 0},
                                         Policy::template MakeXBlockTileDistribution<Problem>());
        auto y_window = make_tile_window(out_tensor_padded, make_tuple(number<S::Block_M>{}), {iM});

        __shared__ char smem[Policy::template GetSmemSize<Problem>()];

        const auto reduce_len =
            in_tensor_padded.get_tensor_descriptor().get_lengths().at(number<1>{});
        index_t num_k_tiles =
            __builtin_amdgcn_readfirstlane(integer_divide_ceil(reduce_len, S::Block_N));

        auto block_reduce2d            = Policy::template GetBlockReduce2d<Problem>();
        auto block_reduce2d_sync       = Policy::template GetBlockReduce2dSync<Problem>();
        auto block_reduce2d_cross_warp = Policy::template GetBlockReduce2dCrossWarpSync<Problem>();

        using XTensorTile = decltype(load_tile(x_window));
        auto y_tile       = block_reduce2d.template MakeYBlockTile<XTensorTile>();
        set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());

        if constexpr(Problem::kOutputIndex)
        {
            auto y_index_window =
                make_tile_window(out_index_tensor_padded, make_tuple(number<S::Block_M>{}), {iM});

            auto y_index_tile =
                block_reduce2d.template MakeYIndexBlockTile<XTensorTile, IndexDataType>();
            set_tile(y_index_tile, IndexDataType(0));

            // Main reduction loop - with index tracking
            for(int k_tile = amd_wave_read_first_lane(0); k_tile < num_k_tiles; ++k_tile)
            {
                const auto x_tile     = load_tile(x_window);
                auto index_calculator = [&](const auto& x_indices) {
                    // Get global coordinates in the 2D matrix space (M, N)
                    const auto global_M = x_indices.at(number<0>{}) + iM;
                    const auto global_N = (k_tile * S::Block_N) + x_indices.at(number<1>{});
                    return in_tensor_padded.get_tensor_descriptor().calculate_offset(
                        make_tuple(global_M, global_N));
                };

                block_reduce2d(x_tile, y_tile, y_index_tile, reduce_op, index_calculator);
                move_tile_window(x_window, {0, S::Block_N});
            }

            block_reduce2d_sync(y_tile, y_index_tile, reduce_op);
            if constexpr(Problem::kNeedCrossWarpSync)
            {
                __shared__ char smem_indices[Policy::template GetIndicesSmemSize<Problem>()];

                block_reduce2d_cross_warp(y_tile, y_index_tile, smem, smem_indices, reduce_op);
            }

            store_tile(y_window, cast_tile<OutDataType>(y_tile));
            store_tile(y_index_window, cast_tile<IndexDataType>(y_index_tile));
        }
        else
        {
            // Main reduction loop - without index tracking
            for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
            {
                const auto x_tile = load_tile(x_window);
                block_reduce2d(x_tile, y_tile, reduce_op);
                move_tile_window(x_window, {0, S::Block_N});
            }

            block_reduce2d_sync(y_tile, reduce_op);
            block_reduce2d_cross_warp(y_tile, smem, reduce_op);

            store_tile(y_window, cast_tile<OutDataType>(y_tile));
        }
    }

    /// @brief Validates if the given arguments are supported by the pooling kernel.
    template <typename TensorShape, typename WindowShape>
    CK_TILE_HOST static bool IsSupportedArgument(PoolKernelArgs<TensorShape, WindowShape> kargs)
    {
        constexpr index_t InputRank  = TensorShape::size();
        constexpr index_t OutputRank = TensorShape::size(); // Same as input rank
        constexpr index_t WindowRank = WindowShape::size();

        // Validate window dimensions (only 2D and 3D supported)
        if constexpr(WindowRank != 2 && WindowRank != 3)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("Only 2D and 3D pooling are supported!");
            }
            return false;
        }

        // Validate that input rank matches expected rank for window dimensions
        if constexpr((WindowRank == 2 && InputRank != 4) || (WindowRank == 3 && InputRank != 5))
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("Input tensor rank doesn't match window dimensions!");
            }
            return false;
        }

        // Check that channel dimension (last dimension) is contiguous for both input and output
        if(kargs.input_strides.at(number<InputRank - 1>{}) != 1)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("Input tensor's channel dimension must have stride 1!");
            }
            return false;
        }

        if(kargs.output_strides.at(number<OutputRank - 1>{}) != 1)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("Output tensor's channel dimension must have stride 1!");
            }
            return false;
        }

        return true;
    }

    /// @brief Calculate the grid size for launching the pooling kernel.
    template <typename TensorShape, typename WindowShape>
    CK_TILE_HOST static constexpr index_t
    CalculateGridSize(PoolKernelArgs<TensorShape, WindowShape> kargs)
    {
        using S = typename Problem::BlockShape;

        // Calculate total output elements (M dimension)
        index_t M = 1;
        static_for<0, TensorShape::size(), 1>{}([&](auto i) { M *= kargs.output_shape.at(i); });

        // Calculate grid size: ceil(M / Block_M)
        return (M + S::Block_M - 1) / S::Block_M;
    }

    /// @brief Create kernel arguments from host arguments.
    template <typename TensorShape, typename WindowShape>
    CK_TILE_HOST static constexpr auto
    MakeKernelArgs(PoolHostArgs<TensorShape, WindowShape>& host_args)
    {
        return PoolKernelArgs<TensorShape, WindowShape>{host_args.input_ptr,
                                                        host_args.output_ptr,
                                                        host_args.output_index_ptr,
                                                        host_args.input_shape,
                                                        host_args.output_shape,
                                                        host_args.input_strides,
                                                        host_args.output_strides,
                                                        host_args.window_lengths,
                                                        host_args.window_strides,
                                                        host_args.window_dilations,
                                                        host_args.input_left_pads,
                                                        host_args.input_right_pads};
    }
};

} // namespace ck_tile
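
The descriptor pipeline above (pad, embed, merge, right-pad) lowers every pooling problem to a dense M x K reduction, where M enumerates output elements (N, Ho, Wo, C) and K enumerates window taps (Y, X); CalculateGridSize then assigns one block per Block_M rows of that matrix. The following standalone sketch reproduces that index math in plain C++ for a hypothetical 2D max-pool configuration; the sizes, the Block_M/Block_N values, and the scalar reference loop are illustrative assumptions and are not part of this header.

// Standalone illustration (plain C++, no ck_tile dependency) of the M x K view
// this kernel builds with its descriptor transforms. All sizes, the tile shape
// (Block_M, Block_N) and the max-pool reference loop are hypothetical.
#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>

int main()
{
    // Hypothetical NHWC max-pool: 1x6x6x4 input, 3x3 window, stride 2, dilation 1, padding 1.
    const int N = 1, H = 6, W = 6, C = 4;
    const int Y = 3, X = 3, StrideH = 2, StrideW = 2, DilH = 1, DilW = 1, PadH = 1, PadW = 1;
    const int Ho = (H + 2 * PadH - DilH * (Y - 1) - 1) / StrideH + 1;
    const int Wo = (W + 2 * PadW - DilW * (X - 1) - 1) / StrideW + 1;

    const int MRaw = N * Ho * Wo * C;     // one row per output element (merge of N, Ho, Wo, C)
    const int KRaw = Y * X;               // one column per window tap (merge of Y, X)
    const int Block_M = 64, Block_N = 16; // hypothetical block tile
    const int grid    = (MRaw + Block_M - 1) / Block_M; // same formula as CalculateGridSize()
    const int k_tiles = (KRaw + Block_N - 1) / Block_N; // num_k_tiles in the device loop

    std::vector<float> in(static_cast<size_t>(N) * H * W * C, 1.0f);
    std::vector<float> out(static_cast<size_t>(MRaw));

    // Reference reduction over the same M x K matrix; out-of-bounds taps read the
    // reduce identity (-inf for max), mirroring what the padded buffer view returns.
    for(int m = 0; m < MRaw; ++m)
    {
        const int c  = m % C;
        const int wo = (m / C) % Wo;
        const int ho = (m / (C * Wo)) % Ho;
        const int n  = m / (C * Wo * Ho);

        float acc = -std::numeric_limits<float>::infinity();
        for(int k = 0; k < KRaw; ++k)
        {
            const int x  = k % X;
            const int y  = k / X;
            const int hi = ho * StrideH + y * DilH - PadH; // pad + embed transform, unrolled
            const int wi = wo * StrideW + x * DilW - PadW;
            if(hi >= 0 && hi < H && wi >= 0 && wi < W)
                acc = std::max(acc, in[((static_cast<size_t>(n) * H + hi) * W + wi) * C + c]);
        }
        out[static_cast<size_t>(m)] = acc;
    }

    std::printf("Ho=%d Wo=%d M=%d K=%d grid=%d k_tiles=%d\n", Ho, Wo, MRaw, KRaw, grid, k_tiles);
    return 0;
}

In the device code this per-element arithmetic is never performed in scalar form; the mapping is encoded once in in_desc_padded, and load_tile/move_tile_window walk it one Block_M x Block_N tile at a time while the block reduce accumulates across the K tiles.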