/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/common/generic_2d_block_shape.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/common/generic_2d_block_shape.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/common/generic_2d_block_shape.hpp Source File
generic_2d_block_shape.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 namespace ck_tile {
7 
8 /*
9 // clang-format off
10 
11 4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector
12 
13                 Block_N (Warp_N * WarpPerBlock_N * Repeat_N)
14         +<----------------------< Repeat_N(2)>--------------------->+
15         |                                                           |
16         +<--  <WarpPerBlock_N(2)>  -->+
17              Warp_N
18         +--------------+--------------+--------------+--------------+----+----------------+
19  Warp_M |    warp_0    |    warp_1    |                             |    ^                ^
20         +--------------+--------------+                             | <WarpPerBlock_M(2)> |
21         |    warp_2    |    warp_3    |                             |    v                |
22         +--------------+--------------+--------------+--------------+----+             Block_M
23         |                             |                             |                     |
24         +                             +                             |                     |
25         |                             |                             |                     v
26         +--------------+--------------+--------------+--------------+                     +
27 
28         each warp tile (e.g. 16 threads per row)
29 
30           Vector_N (contiguous pixels each thread holds along N, i.e. the vector size)
31         +-----------+-----------+-----------+-----------+-----------+
32         |  thrd_0   |  thrd_1   |  thrd_2   |  thrd_3   |    ...    |  Vector_M
33         +-----------+-----------+-----------+-----------+-----------+
34         |  thrd_16  |  thrd_17  |  thrd_18  |  thrd_19  |    ...    |
35         +-----------+-----------+-----------+-----------+-----------+
36 // clang-format on
37 */
// Compile-time 2-D block-shape descriptor (listing below; leading numbers are the
// rendered source line numbers). It reads three sequences, each via
// ::at(number<0>{}) for M and ::at(number<1>{}) for N:
//   BlockTile_      -> Block_M, Block_N                   (block tile extent)
//   ThreadPerBlock_ -> ThreadPerBlock_M, ThreadPerBlock_N (thread counts in the block)
//   Vector_         -> Vector_M, Vector_N                 (contiguous elements per thread)
// and derives warps per block, threads per warp, repeat counts and a block size.
38 template <typename BlockTile_, // block size, seq<M, N>
39  typename ThreadPerBlock_, // num threads along seq<M, N>
40  typename Vector_> // contiguous pixels (vector size) per thread along seq<M, N>
// NOTE(review): source line 41 (the struct head) is missing from this rendering;
// judging by the file name it is presumably `struct Generic2dBlockShape` - confirm.
42 {
43  // block size
44  static constexpr index_t Block_M = BlockTile_::at(number<0>{});
45  static constexpr index_t Block_N = BlockTile_::at(number<1>{});
46  static constexpr index_t ThreadPerBlock_M = ThreadPerBlock_::at(number<0>{});
47  static constexpr index_t ThreadPerBlock_N = ThreadPerBlock_::at(number<1>{});
48 
49  // vector size along seq<M, N>
50  static constexpr index_t Vector_M = Vector_::at(number<0>{});
51  static constexpr index_t Vector_N = Vector_::at(number<1>{});
52 
53  // num warps along seq<M, N>, within each block
    // Number of warps along M. warp_size is 32 when isHostWave32 is true, otherwise
    // get_warp_size(). The first branch is taken when a row of ThreadPerBlock_N
    // threads fits in one warp (ThreadPerBlock_N <= warp_size).
    // NOTE(review): when the static_asserts hold (and ThreadPerBlock_N % warp_size == 0,
    // which only GetWarpPerBlock_N checks), both branches simplify algebraically to
    // ThreadPerBlock_M, i.e. one warp per row of threads - confirm this is intended.
54  template <bool isHostWave32>
55  static constexpr index_t GetWarpPerBlock_M()
56  {
57  constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size();
58  constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size;
    // the thread block must consist of whole warps
59  static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % warp_size == 0);
60  constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / warp_size;
61 
62  if constexpr(is_warp_per_row)
63  {
64  static_assert(warp_size % ThreadPerBlock_N == 0);
65  return total_warps * (warp_size / ThreadPerBlock_N);
66  }
67  else
68  {
    // NOTE(review): ThreadPerBlock_M_ is not declared in this struct, so this
    // commented-out assert would not compile if re-enabled as written.
69  // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
70  return total_warps / (ThreadPerBlock_N / warp_size);
71  }
    // NOTE(review): the ';' after this member-function body is a redundant empty
    // declaration (flagged by -Wextra-semi).
72  };
73 
74  // num of warps along N, within each block
    // Returns 1 when a row of ThreadPerBlock_N threads fits in one warp, otherwise
    // ThreadPerBlock_N / warp_size (asserted to divide exactly). warp_size is chosen
    // as in GetWarpPerBlock_M.
75  template <bool isHostWave32>
76  static constexpr index_t GetWarpPerBlock_N()
77  {
78  constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size();
79  constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size;
80  if constexpr(is_warp_per_row)
81  {
82  static_assert(warp_size % ThreadPerBlock_N == 0);
83  return 1;
84  }
85  else
86  {
87  static_assert(ThreadPerBlock_N % warp_size == 0);
88  return ThreadPerBlock_N / warp_size;
89  }
90  }
91 
    // Warp counts evaluated with isHostWave32 = false, i.e. using get_warp_size().
    // NOTE(review): get_warp_size() is declared `constexpr __device__`; the
    // isHostWave32 template parameter presumably lets host code avoid it - confirm.
92  static constexpr index_t WarpPerBlock_M = GetWarpPerBlock_M<false>();
93  static constexpr index_t WarpPerBlock_N = GetWarpPerBlock_N<false>();
94 
95  // warp size
    // NOTE(review): source lines 96-98 are missing from this rendering; the doxygen
    // index lists BlockSize (line 96), Warp_M (line 97) and Warp_N (line 98) there.
    // The warp tile must be a whole number of vectors, and the block tile a whole
    // number of (WarpPerBlock x Warp) tiles.
99  static_assert(Warp_M % Vector_M == 0);
100  static_assert(Warp_N % Vector_N == 0);
101  static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
102  static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
103 
104  // repeats of the (WarpPerBlock x Warp) tile needed to cover the block, along seq<M, N>
105  static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
106  static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
107 
108  // num of threads along seq<M, N>, within each warp (warp tile extent / vector size)
109  static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
110  static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
111 
    // Threads per block: warps along M * warps along N * warp_size, where warp_size
    // is 32 when isHostWave32 is true, otherwise get_warp_size().
    // NOTE(review): with GetWarpPerBlock_M() == ThreadPerBlock_M and
    // GetWarpPerBlock_N() == 1 when ThreadPerBlock_N <= warp_size, this returns
    // ThreadPerBlock_M * warp_size, which exceeds ThreadPerBlock_M * ThreadPerBlock_N
    // whenever ThreadPerBlock_N < warp_size - confirm idle lanes are intended.
112  template <bool isHostWave32>
113  static constexpr index_t GetBlockSize()
114  {
115  constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size();
116  return GetWarpPerBlock_M<isHostWave32>() * GetWarpPerBlock_N<isHostWave32>() * warp_size;
117  }
118 };
119 
120 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: generic_2d_block_shape.hpp:42
static constexpr index_t GetWarpPerBlock_N()
Definition: generic_2d_block_shape.hpp:76
static constexpr index_t Repeat_M
Definition: generic_2d_block_shape.hpp:105
static constexpr index_t GetWarpPerBlock_M()
Definition: generic_2d_block_shape.hpp:55
static constexpr index_t ThreadPerWarp_M
Definition: generic_2d_block_shape.hpp:109
static constexpr index_t WarpPerBlock_N
Definition: generic_2d_block_shape.hpp:93
static constexpr index_t GetBlockSize()
Definition: generic_2d_block_shape.hpp:113
static constexpr index_t ThreadPerBlock_M
Definition: generic_2d_block_shape.hpp:46
static constexpr index_t ThreadPerBlock_N
Definition: generic_2d_block_shape.hpp:47
static constexpr index_t ThreadPerWarp_N
Definition: generic_2d_block_shape.hpp:110
static constexpr index_t Block_M
Definition: generic_2d_block_shape.hpp:44
static constexpr index_t Warp_N
Definition: generic_2d_block_shape.hpp:98
static constexpr index_t BlockSize
Definition: generic_2d_block_shape.hpp:96
static constexpr index_t Repeat_N
Definition: generic_2d_block_shape.hpp:106
static constexpr index_t Block_N
Definition: generic_2d_block_shape.hpp:45
static constexpr index_t Warp_M
Definition: generic_2d_block_shape.hpp:97
static constexpr index_t WarpPerBlock_M
Definition: generic_2d_block_shape.hpp:92
static constexpr index_t Vector_N
Definition: generic_2d_block_shape.hpp:51
static constexpr index_t Vector_M
Definition: generic_2d_block_shape.hpp:50
Definition: integral_constant.hpp:13