/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp Source File
reduce2d_shape.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 template <typename BlockWarps, // num warps along seq<M, N>
11  typename BlockTile, // block size, seq<M, N>
12  typename WarpTile, // warp size, seq<M, N>
13  typename ThreadTile> // contiguous pixels(vector size) along seq<M, N>
15 {
16  static constexpr index_t Block_M = BlockTile::at(number<0>{});
17  static constexpr index_t Block_N = BlockTile::at(number<1>{});
18 
19  static constexpr index_t Warp_M = WarpTile::at(number<0>{});
20  static constexpr index_t Warp_N = WarpTile::at(number<1>{});
21 
22  static constexpr index_t ThreadTile_M = ThreadTile::at(number<0>{});
23  static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
24 
25  static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
26  static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
27 
30 
31  static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
32  static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
33 
34  static constexpr index_t BlockSize =
35  ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
36 };
37 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition: sequence.hpp:979
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
Definition: reduce2d_shape.hpp:15
static constexpr index_t Warp_N
Definition: reduce2d_shape.hpp:20
static constexpr index_t Block_M
Definition: reduce2d_shape.hpp:16
static constexpr index_t WarpPerBlock_M
Definition: reduce2d_shape.hpp:25
static constexpr index_t ThreadTile_M
Definition: reduce2d_shape.hpp:22
static constexpr index_t ThreadTile_N
Definition: reduce2d_shape.hpp:23
static constexpr index_t Repeat_N
Definition: reduce2d_shape.hpp:32
static constexpr index_t ThreadPerWarp_N
Definition: reduce2d_shape.hpp:29
static constexpr index_t WarpPerBlock_N
Definition: reduce2d_shape.hpp:26
static constexpr index_t Warp_M
Definition: reduce2d_shape.hpp:19
static constexpr index_t BlockSize
Definition: reduce2d_shape.hpp:34
static constexpr index_t ThreadPerWarp_M
Definition: reduce2d_shape.hpp:28
static constexpr index_t Block_N
Definition: reduce2d_shape.hpp:17
static constexpr index_t Repeat_M
Definition: reduce2d_shape.hpp:31
Definition: integral_constant.hpp:13
Definition: math.hpp:98