14 template <
typename TensorShape,
typename WindowShape>
20 void* output_index_ptr_,
21 TensorShape input_shape_,
22 TensorShape output_shape_,
23 TensorShape input_strides_,
24 TensorShape output_strides_,
25 WindowShape window_lengths_,
26 WindowShape window_strides_,
27 WindowShape window_dilations_,
28 WindowShape input_left_pads_,
29 WindowShape input_right_pads_)
61 template <
typename TensorShape,
typename WindowShape>
78 template <
typename Problem_,
typename Policy_ = PoolDefaultPolicy>
96 template <
typename TensorShape,
typename WindowShape>
99 using S =
typename Problem::BlockShape;
102 static_assert(TensorShape::size() == 4,
"2D pooling requires 4D input tensor (N,H,W,C)");
103 static_assert(WindowShape::size() == 2,
"2D pooling requires 2D window shape (Y,X)");
131 const index_t MRaw = N * Ho * Wo * C;
136 auto reduce_op =
typename Problem::ReduceOp{};
163 const auto merged_embed_in_desc =
171 merged_embed_in_desc,
185 const auto out_desc_padded =
193 type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
195 type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
197 auto in_buffer_view = make_buffer_view<address_space_enum::global>(
199 in_desc.get_element_space_size(),
201 const auto in_tensor_padded =
202 tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
205 auto out_buffer_view = make_buffer_view<address_space_enum::global>(
207 out_desc.get_element_space_size(),
209 const auto out_tensor_padded =
210 tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
213 if constexpr(Problem::kOutputIndex)
215 auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
217 out_desc.get_element_space_size(),
219 const auto out_index_tensor_padded =
220 tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
221 out_index_buffer_view, out_desc_padded};
223 return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
232 template <
typename TensorShape,
typename WindowShape>
235 using S =
typename Problem::BlockShape;
238 static_assert(TensorShape::size() == 5,
"3D pooling requires 5D input tensor (N,D,H,W,C)");
239 static_assert(WindowShape::size() == 3,
"3D pooling requires 3D window shape (Z,Y,X)");
274 const index_t MRaw = N * Do * Ho * Wo * C;
275 const index_t KRaw = Z * Y * X;
279 auto reduce_op =
typename Problem::ReduceOp{};
320 merged_embed_in_desc,
334 const auto out_desc_padded =
342 type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
344 type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
346 auto in_buffer_view = make_buffer_view<address_space_enum::global>(
348 in_desc.get_element_space_size(),
350 const auto in_tensor_padded =
351 tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
354 auto out_buffer_view = make_buffer_view<address_space_enum::global>(
356 out_desc.get_element_space_size(),
358 const auto out_tensor_padded =
359 tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
362 if constexpr(Problem::kOutputIndex)
364 auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
366 out_desc.get_element_space_size(),
368 const auto out_index_tensor_padded =
369 tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
370 out_index_buffer_view, out_desc_padded};
372 return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
382 template <
typename TensorShape,
typename WindowShape>
385 using S =
typename Problem::BlockShape;
388 static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
389 "Only 2D and 3D pooling operations are supported");
391 const auto iM = get_block_id() * S::Block_M;
394 auto [in_tensor_padded, out_tensor_padded, out_index_tensor_padded] = [&]() {
395 if constexpr(WindowShape::size() == 2)
397 else if constexpr(WindowShape::size() == 3)
400 static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
401 "Unsupported WindowShape rank: only 2D or 3D pooling is supported");
404 auto reduce_op =
typename Problem::ReduceOp{};
409 Policy::template MakeXBlockTileDistribution<Problem>());
412 __shared__
char smem[Policy::template GetSmemSize<Problem>()];
414 const auto reduce_len =
415 in_tensor_padded.get_tensor_descriptor().get_lengths().at(
number<1>{});
419 auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
420 auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
421 auto block_reduce2d_cross_warp = Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
423 using XTensorTile = decltype(
load_tile(x_window));
424 auto y_tile = block_reduce2d.template MakeYBlockTile<XTensorTile>();
425 set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());
427 if constexpr(Problem::kOutputIndex)
429 auto y_index_window =
433 block_reduce2d.template MakeYIndexBlockTile<XTensorTile, IndexDataType>();
440 auto index_calculator = [&](
const auto& x_indices) {
442 const auto global_M = x_indices.at(
number<0>{}) + iM;
443 const auto global_N = (k_tile * S::Block_N) + x_indices.at(
number<1>{});
444 return in_tensor_padded.get_tensor_descriptor().calculate_offset(
448 block_reduce2d(x_tile, y_tile, y_index_tile, reduce_op, index_calculator);
452 block_reduce2d_sync(y_tile, y_index_tile, reduce_op);
453 if constexpr(Problem::kNeedCrossWarpSync)
455 __shared__
char smem_indices[Policy::template GetIndicesSmemSize<Problem>()];
457 block_reduce2d_cross_warp(y_tile, y_index_tile, smem, smem_indices, reduce_op);
460 store_tile(y_window, cast_tile<OutDataType>(y_tile));
461 store_tile(y_index_window, cast_tile<IndexDataType>(y_index_tile));
466 for(
int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
469 block_reduce2d(x_tile, y_tile, reduce_op);
473 block_reduce2d_sync(y_tile, reduce_op);
474 block_reduce2d_cross_warp(y_tile, smem, reduce_op);
476 store_tile(y_window, cast_tile<OutDataType>(y_tile));
490 template <
typename TensorShape,
typename WindowShape>
493 constexpr
index_t InputRank = TensorShape::size();
494 constexpr
index_t OutputRank = TensorShape::size();
495 constexpr
index_t WindowRank = WindowShape::size();
498 if constexpr(WindowRank != 2 && WindowRank != 3)
508 if constexpr((WindowRank == 2 && InputRank != 4) || (WindowRank == 3 && InputRank != 5))
512 CK_TILE_ERROR(
"Input tensor rank doesn't match window dimensions!");
522 CK_TILE_ERROR(
"Input tensor's channel dimension must have stride 1!");
531 CK_TILE_ERROR(
"Output tensor's channel dimension must have stride 1!");
541 template <
typename TensorShape,
typename WindowShape>
545 using S =
typename Problem::BlockShape;
552 return (M + S::Block_M - 1) / S::Block_M;
556 template <
typename TensorShape,
typename WindowShape>
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad_, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1584
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:274
constexpr CK_TILE_HOST_DEVICE auto integer_least_multiple(X x, Y y)
Definition: math.hpp:155
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition: tile_elementwise.hpp:95
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1565
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:203
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
constexpr CK_TILE_HOST_DEVICE auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: coordinate_transform.hpp:1594
Host arguments for pooling operations.
Definition: pool_kernel.hpp:16
TensorShape input_strides
Definition: pool_kernel.hpp:51
void * output_ptr
Definition: pool_kernel.hpp:46
WindowShape input_left_pads
Definition: pool_kernel.hpp:56
const void * input_ptr
Definition: pool_kernel.hpp:45
WindowShape window_lengths
Definition: pool_kernel.hpp:53
WindowShape window_strides
Definition: pool_kernel.hpp:54
TensorShape input_shape
Definition: pool_kernel.hpp:49
TensorShape output_strides
Definition: pool_kernel.hpp:52
CK_TILE_HOST PoolHostArgs(const void *input_ptr_, void *output_ptr_, void *output_index_ptr_, TensorShape input_shape_, TensorShape output_shape_, TensorShape input_strides_, TensorShape output_strides_, WindowShape window_lengths_, WindowShape window_strides_, WindowShape window_dilations_, WindowShape input_left_pads_, WindowShape input_right_pads_)
Definition: pool_kernel.hpp:18
TensorShape output_shape
Definition: pool_kernel.hpp:50
WindowShape input_right_pads
Definition: pool_kernel.hpp:57
WindowShape window_dilations
Definition: pool_kernel.hpp:55
void * output_index_ptr
Definition: pool_kernel.hpp:47
Kernel arguments for pooling operations.
Definition: pool_kernel.hpp:63
TensorShape output_shape
Definition: pool_kernel.hpp:68
WindowShape input_right_pads
Definition: pool_kernel.hpp:75
WindowShape window_lengths
Definition: pool_kernel.hpp:71
WindowShape window_dilations
Definition: pool_kernel.hpp:73
TensorShape input_strides
Definition: pool_kernel.hpp:69
const void * input_ptr
Definition: pool_kernel.hpp:64
WindowShape input_left_pads
Definition: pool_kernel.hpp:74
TensorShape input_shape
Definition: pool_kernel.hpp:67
WindowShape window_strides
Definition: pool_kernel.hpp:72
void * output_ptr
Definition: pool_kernel.hpp:65
TensorShape output_strides
Definition: pool_kernel.hpp:70
void * output_index_ptr
Definition: pool_kernel.hpp:66
Definition: pool_kernel.hpp:80
ck_tile::remove_cvref_t< Policy_ > Policy
Definition: pool_kernel.hpp:82
ck_tile::remove_cvref_t< typename Problem::OutDataType > OutDataType
Definition: pool_kernel.hpp:86
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: pool_kernel.hpp:85
static constexpr CK_TILE_HOST auto BlockSize()
Definition: pool_kernel.hpp:91
static constexpr CK_TILE_HOST index_t CalculateGridSize(PoolKernelArgs< TensorShape, WindowShape > kargs)
Definition: pool_kernel.hpp:543
static constexpr index_t kBlockSize
Definition: pool_kernel.hpp:89
static CK_TILE_HOST bool IsSupportedArgument(PoolKernelArgs< TensorShape, WindowShape > kargs)
Validates if the given arguments are supported by the pooling kernel.
Definition: pool_kernel.hpp:491
static CK_TILE_DEVICE auto MakeTensorView2D(PoolKernelArgs< TensorShape, WindowShape > kargs)
Definition: pool_kernel.hpp:97
static CK_TILE_DEVICE auto MakeTensorView3D(PoolKernelArgs< TensorShape, WindowShape > kargs)
Definition: pool_kernel.hpp:233
static constexpr CK_TILE_HOST auto MakeKernelArgs(PoolHostArgs< TensorShape, WindowShape > &host_args)
Create kernel arguments from host arguments.
Definition: pool_kernel.hpp:558
ck_tile::remove_cvref_t< typename Problem::InDataType > InDataType
Definition: pool_kernel.hpp:84
ck_tile::remove_cvref_t< typename Problem::IndexDataType > IndexDataType
Definition: pool_kernel.hpp:87
CK_TILE_DEVICE void operator()(PoolKernelArgs< TensorShape, WindowShape > kargs) const
Definition: pool_kernel.hpp:383
ck_tile::remove_cvref_t< Problem_ > Problem
Definition: pool_kernel.hpp:81
Definition: integral_constant.hpp:13
Definition: null_tensor.hpp:9
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
#define CK_TILE_ENV(name)
Definition: env.hpp:145