include/ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp Source File

include/ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp Source File

Composable Kernel: include/ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp Source File
block_reduce2d_default_policy.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
13 {
14  template <typename Problem>
16  {
17  using S = typename Problem::BlockShape;
20  sequence<>,
27  }
28 
29  template <typename Problem>
30  CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
31  {
32  using P_ = BlockReduce2dProblem<typename Problem::XDataType,
33  typename Problem::ComputeDataType,
34  typename Problem::BlockShape>;
35  return BlockReduce2d<P_>{};
36  }
37 
38  template <typename Problem>
39  CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
40  {
41  using P_ = BlockReduce2dProblem<typename Problem::XDataType,
42  typename Problem::ComputeDataType,
43  typename Problem::BlockShape>;
44  return BlockReduce2dSync<P_>{};
45  }
46 
47  template <typename Problem>
49  {
50  using P_ = BlockReduce2dProblem<typename Problem::XDataType,
51  typename Problem::ComputeDataType,
52  typename Problem::BlockShape>;
54  }
55 
56  template <typename Problem>
58  {
59  if constexpr(Problem::kNeedCrossWarpSync)
60  {
61  using P_ = BlockReduce2dProblem<typename Problem::XDataType,
62  typename Problem::ComputeDataType,
63  typename Problem::BlockShape>;
64 
65  using block_reduce2d = BlockReduce2d<P_>;
66  using x_block_tile =
67  decltype(make_static_distributed_tensor<typename Problem::XDataType>(
68  MakeXBlockTileDistribution<Problem>()));
69  using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
70 
71  return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
72  }
73  else
74  {
75  return 1; // zero size arrays are an extension
76  }
77  }
78 };
79 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
Definition: block_reduce2d.hpp:162
Definition: block_reduce2d_default_policy.hpp:13
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: block_reduce2d_default_policy.hpp:57
static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2dSync()
Definition: block_reduce2d_default_policy.hpp:39
static constexpr CK_TILE_DEVICE auto MakeXBlockTileDistribution()
Definition: block_reduce2d_default_policy.hpp:15
static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2dCrossWarpSync()
Definition: block_reduce2d_default_policy.hpp:48
static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2d()
Definition: block_reduce2d_default_policy.hpp:30
Definition: block_reduce2d.hpp:12
Definition: block_reduce2d_problem.hpp:12
Definition: block_reduce2d.hpp:96
Definition: sequence.hpp:52
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192