/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp Source File
pool_default_policy.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
13 {
// MakeXBlockTileDistribution<Problem>() -- fragment only.
// NOTE(review): listing lines 15, 18-19 and 22-27 are absent from this extract, so the
// signature and most of the body cannot be read here. The member index at the end of this
// page gives the declaration at listing line 15 as
// `static constexpr CK_TILE_DEVICE auto MakeXBlockTileDistribution()`.
14  template <typename Problem>
16  {
// S is shorthand for the block shape type supplied by Problem.
17  using S = typename Problem::BlockShape;
// NOTE(review): `sequence<>` and `tuple<` look like arguments of a tile distribution
// encoding (the page index references tile_distribution_encoding.hpp and
// make_static_tile_distribution) -- confirm against the full header.
20  sequence<>,
21  tuple<
28  }
29 
// Returns a `BlockReduce2d<...>{}` object specialised for this pooling Problem.
// The reduction problem type it is specialised on is assembled from Problem's
// InDataType, ComputeDataType, BlockShape and kOutputIndex.
30  template <typename Problem>
31  CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
32  {
// Reduction problem descriptor derived from the pooling Problem.
33  using ReduceProblem = BlockReduce2dProblem<typename Problem::InDataType,
34  typename Problem::ComputeDataType,
35  typename Problem::BlockShape,
36  Problem::kOutputIndex>;
37  return BlockReduce2d<ReduceProblem>{};
38  }
39 
// Returns a `BlockReduce2dSync<...>{}` object specialised for this pooling Problem.
// The reduction problem type it is specialised on is assembled from Problem's
// InDataType, ComputeDataType, BlockShape and kOutputIndex.
40  template <typename Problem>
41  CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
42  {
// Reduction problem descriptor derived from the pooling Problem.
43  using ReduceProblem = BlockReduce2dProblem<typename Problem::InDataType,
44  typename Problem::ComputeDataType,
45  typename Problem::BlockShape,
46  Problem::kOutputIndex>;
47  return BlockReduce2dSync<ReduceProblem>{};
48  }
49 
// GetBlockReduce2dCrossWarpSync<Problem>() -- fragment only.
// NOTE(review): the signature (listing line 51) and listing line 57 are absent from this
// extract. The member index at the end of this page gives the declaration as
// `static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2dCrossWarpSync()`.
// Other members of this policy call GetSmemSize / GetIndicesSmemSize on the object
// this function returns.
50  template <typename Problem>
52  {
// Reduction problem type assembled from Problem's InDataType, ComputeDataType,
// BlockShape and kOutputIndex.
53  using P_ = BlockReduce2dProblem<typename Problem::InDataType,
54  typename Problem::ComputeDataType,
55  typename Problem::BlockShape,
56  Problem::kOutputIndex>;
// NOTE(review): listing line 57 is missing here; presumably it returns an object
// specialised on P_, as the neighbouring getters do -- confirm against the full header.
58  }
59 
// GetSmemSize<Problem>() -- size value used for the cross-warp reduction buffer.
// NOTE(review): the signature (listing line 61) is absent from this extract; the member
// index at the end of this page gives it as
// `static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()`.
// Returns the size reported by the cross-warp-sync reducer for the reduced (Y) tile type
// when Problem::kNeedCrossWarpSync is true, and 1 otherwise.
60  template <typename Problem>
62  {
// Compile-time branch: the tile types below are only formed when Problem requests
// cross-warp synchronisation.
63  if constexpr(Problem::kNeedCrossWarpSync)
64  {
// Reduction problem type assembled from Problem's InDataType, ComputeDataType,
// BlockShape and kOutputIndex.
65  using P_ = BlockReduce2dProblem<typename Problem::InDataType,
66  typename Problem::ComputeDataType,
67  typename Problem::BlockShape,
68  Problem::kOutputIndex>;
69 
70  using block_reduce2d = BlockReduce2d<P_>;
// Type of a static distributed tensor of InDataType laid out by this policy's
// X block tile distribution; only the type is needed, hence decltype.
71  using x_block_tile =
72  decltype(make_static_distributed_tensor<typename Problem::InDataType>(
73  MakeXBlockTileDistribution<Problem>()));
// Type that BlockReduce2d produces as the Y tile for that X tile type.
74  using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
75 
// Delegate the actual size computation to the cross-warp-sync reducer.
76  return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
77  }
78  else
79  {
// Returns 1 instead of 0 so that an array sized by this value is never zero-length.
80  return 1; // zero size arrays are an extension
81  }
82  }
83 
// GetIndicesSmemSize<Problem>() -- size value for the index buffer of the cross-warp reduction.
// NOTE(review): the signature (listing line 85) is absent from this extract; the member
// index at the end of this page gives it as
// `static constexpr CK_TILE_HOST_DEVICE index_t GetIndicesSmemSize()`.
// Returns the size reported by the cross-warp-sync reducer's GetIndicesSmemSize for the
// Y index tile type. Unlike GetSmemSize as listed on this page, this body has no
// Problem::kNeedCrossWarpSync branch.
84  template <typename Problem>
86  {
87 
// Reduction problem type assembled from Problem's InDataType, ComputeDataType,
// BlockShape and kOutputIndex.
88  using P_ = BlockReduce2dProblem<typename Problem::InDataType,
89  typename Problem::ComputeDataType,
90  typename Problem::BlockShape,
91  Problem::kOutputIndex>;
92 
93  using block_reduce2d = BlockReduce2d<P_>;
// Type of a static distributed tensor of InDataType laid out by this policy's
// X block tile distribution; only the type is needed, hence decltype.
94  using x_block_tile = decltype(make_static_distributed_tensor<typename Problem::InDataType>(
95  MakeXBlockTileDistribution<Problem>()));
// Type that BlockReduce2d produces as the Y index tile for that X tile type, with
// Problem::IndexDataType passed as the second template argument.
96  using y_index_block_tile = decltype(block_reduce2d::template MakeYIndexBlockTile<
97  x_block_tile,
98  typename Problem::IndexDataType>());
99 
// Delegate the actual size computation to the cross-warp-sync reducer.
100  return GetBlockReduce2dCrossWarpSync<Problem>()
101  .template GetIndicesSmemSize<y_index_block_tile>();
102  }
103 };
104 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
Definition: block_reduce2d.hpp:334
Definition: block_reduce2d.hpp:46
Definition: block_reduce2d_problem.hpp:15
Definition: block_reduce2d.hpp:224
Definition: pool_default_policy.hpp:13
static constexpr CK_TILE_HOST_DEVICE index_t GetIndicesSmemSize()
Definition: pool_default_policy.hpp:85
static constexpr CK_TILE_DEVICE auto MakeXBlockTileDistribution()
Definition: pool_default_policy.hpp:15
static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2dSync()
Definition: pool_default_policy.hpp:41
static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2dCrossWarpSync()
Definition: pool_default_policy.hpp:51
static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2d()
Definition: pool_default_policy.hpp:31
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: pool_default_policy.hpp:61
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192