include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp Source File

include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp Source File#

Composable Kernel: include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp Source File
gemm_tile_partitioner.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
9 #pragma once
10 
11 #include "ck_tile/core.hpp"
12 
13 namespace ck_tile {
14 
19 template <typename BlockGemmShapeType>
21 {
23 
// Tile extents along M, N and K, read from BlockGemmShape (kM, kN, kK).
// GridSize divides the M and N problem sizes by MPerBlock / NPerBlock, and GetLoopNum
// divides K by KPerBlock.
24  static constexpr index_t MPerBlock = BlockGemmShape::kM;
25  static constexpr index_t NPerBlock = BlockGemmShape::kN;
26  static constexpr index_t KPerBlock = BlockGemmShape::kK;
27 
30  [[maybe_unused]] index_t N) noexcept;
31 
39  CK_TILE_HOST static auto
40  GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
41  {
42  const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
43  const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
44  return dim3(GridDimX, GridDimY, 1);
45  }
46 
53  CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
54  {
55  return integer_divide_ceil(K, KPerBlock);
56  }
57 
72  CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept
74  {
75  const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
76  const index_t iN = __builtin_amdgcn_readfirstlane(blockIdy);
77  return make_tuple(iM, iN);
78  }
79 };
80 
86 template <typename BlockGemmShape_>
88 {
90 
// Tile extents along M, N and K, read from BlockGemmShape (kM, kN, kK).
// GridSize and GetOutputTileIndex divide problem sizes by MPerBlock / NPerBlock, and
// GetLoopNum divides K by KPerBlock.
91  static constexpr index_t MPerBlock = BlockGemmShape::kM;
92  static constexpr index_t NPerBlock = BlockGemmShape::kN;
93  static constexpr index_t KPerBlock = BlockGemmShape::kK;
94 
96 
104  {
105  N_ = N;
106  }
107 
115  CK_TILE_HOST static auto
116  GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
117  {
118  const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
119  const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
120  return GridDimX * GridDimY;
121  }
122 
129  CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
130  {
131  return integer_divide_ceil(K, KPerBlock);
132  }
133 
140  CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx) noexcept
141  -> const tuple<index_t, index_t>
142  {
143  const index_t NBlocks = integer_divide_ceil(N_, NPerBlock);
144 
145  const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlocks);
146  const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - iM * NBlocks);
147  return make_tuple(iM, iN);
148  }
149 
150  private:
// Problem extent along N. It is assigned from the N argument (`N_ = N`) earlier in this
// class and read by GetOutputTileIndex to compute the number of tiles along N.
// NOTE(review): the member is static, so it is shared by every object of a given
// specialization; no out-of-class definition is visible in this listing -- confirm one
// exists (or that the toolchain accepts the declaration as is).
151  CK_TILE_DEVICE static index_t N_;
152 };
153 
158 template <typename, typename = void>
160 {
161 };
162 
// Partial specialization of the HasFnOneArgImpl detection trait. It is selected when the
// expression `std::declval<T>().GetOutputTileIndex(1)` is well-formed, i.e. when T provides
// a GetOutputTileIndex that can be called with a single integer argument.
// NOTE(review): the base-class line (listing line 170) is missing from this extract;
// presumably the specialization derives from a true_type so that the trait converts to
// `true` inside the enable_if_t that follows -- confirm against the original header.
168 template <typename T>
169 struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIndex(1))>>
171 {
172 };
173 
180 template <typename TilePartitioner,
181  typename = typename std::enable_if_t<HasFnOneArgImpl<TilePartitioner>{}>>
183 {
191  [[nodiscard]] CK_TILE_DEVICE static auto
192  GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept
193  -> const tuple<index_t, index_t>
194  {
195  const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start);
196  return make_tuple(iM, iN);
197  }
198 };
199 
211 template <typename BlockGemmShapeType, index_t GroupNum, index_t M01>
213 {
215 
// Tile extents along M, N and K, read from BlockGemmShape (kM, kN, kK).
// GridSize and GetOutputTileIndex divide problem sizes by MPerBlock / NPerBlock, and
// GetLoopNum divides K by KPerBlock.
216  static constexpr index_t MPerBlock = BlockGemmShape::kM;
217  static constexpr index_t NPerBlock = BlockGemmShape::kN;
218  static constexpr index_t KPerBlock = BlockGemmShape::kK;
219 
222  : M(M_), N(N_)
223  {
224  }
225 
233  CK_TILE_HOST static auto
234  GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
235  {
236  const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
237  const index_t GridDimY = integer_divide_ceil(N, NPerBlock);
238  return GridDimX * GridDimY;
239  }
240 
247  CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
248  {
249  return integer_divide_ceil(K, KPerBlock);
250  }
251 
258  CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept
259  -> const tuple<index_t, index_t>
260  {
261  const auto M0 = integer_divide_ceil(M, MPerBlock);
262  const auto N0 = integer_divide_ceil(N, NPerBlock);
263 
264  if(M0 == 1)
265  {
266  return make_tuple(0, block_1d_id);
267  }
268  else if(N0 == 1)
269  {
270  return make_tuple(block_1d_id, 0);
271  }
272  // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
273  else
274  {
275  const auto group_size = integer_divide_ceil(M0 * N0, GroupNum);
276  const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
277  const auto group_id_y = block_1d_id / GroupNum;
278  const auto group_id_x = block_1d_id - group_id_y * GroupNum;
279  const auto remap_block_1d_id =
280  group_id_x <= big_group_num
281  ? group_id_x * group_size + group_id_y
282  : group_id_x * group_size + big_group_num - group_id_x + group_id_y;
283 
284  const index_t idx_M0 = remap_block_1d_id / N0;
285  const index_t idx_N0 = remap_block_1d_id - idx_M0 * N0;
286 
287  const index_t M0_tmp = M0 / M01;
288  const index_t M0_mod_M01 = M0 - M0_tmp * M01;
289 
290  const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
291 
292  const index_t idx_M00 = idx_M0 / M01;
293  const index_t idx_M01 = idx_M0 - idx_M00 * M01;
294  const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
295 
340  const index_t N_out = idx_N0_M01_local / M01_adapt;
341  const index_t idx_loc_mod_M01 = idx_N0_M01_local - N_out * M01_adapt;
342 
343  return make_tuple(idx_loc_mod_M01 + idx_M00 * M01, N_out);
344  }
345  }
346 
347  private:
348  index_t M;
349  index_t N;
350 };
351 
352 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
bool_constant< false > false_type
Definition: integral_constant.hpp:55
bool_constant< true > true_type
Definition: integral_constant.hpp:54
Class mapping 1D block index into 2D output tile space.
Definition: gemm_tile_partitioner.hpp:213
static constexpr index_t MPerBlock
Definition: gemm_tile_partitioner.hpp:216
static constexpr index_t KPerBlock
Definition: gemm_tile_partitioner.hpp:218
static CK_TILE_HOST_DEVICE auto GetLoopNum(index_t K) noexcept -> index_t
Calculate number of loop iterations over GEMM's K dimension.
Definition: gemm_tile_partitioner.hpp:247
CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner() noexcept=delete
static CK_TILE_HOST auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock !=0 &&NPerBlock !=0)) -> index_t
Calculates GEMM kernel grid size.
Definition: gemm_tile_partitioner.hpp:234
remove_cvref_t< BlockGemmShapeType > BlockGemmShape
Definition: gemm_tile_partitioner.hpp:214
static constexpr index_t NPerBlock
Definition: gemm_tile_partitioner.hpp:217
CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept -> const tuple< index_t, index_t >
Calculate workgroup 1D index mapping into 2D output C-tile space.
Definition: gemm_tile_partitioner.hpp:258
Class providing 1D WGP index mapping into 2D output C-tile space.
Definition: gemm_tile_partitioner.hpp:88
static CK_TILE_HOST auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock !=0 &&NPerBlock !=0)) -> index_t
Calculates GEMM kernel grid size.
Definition: gemm_tile_partitioner.hpp:116
CK_TILE_HOST_DEVICE GemmTile1DPartitioner() noexcept=delete
static CK_TILE_HOST_DEVICE auto GetLoopNum(index_t K) noexcept -> index_t
Calculate number of loop iterations over GEMM's K dimension.
Definition: gemm_tile_partitioner.hpp:129
static CK_TILE_DEVICE auto GetOutputTileIndex(index_t blockIdx) noexcept -> const tuple< index_t, index_t >
Calculate workgroup 1D index mapping into 2D output C-tile space.
Definition: gemm_tile_partitioner.hpp:140
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_tile_partitioner.hpp:89
static constexpr index_t MPerBlock
Definition: gemm_tile_partitioner.hpp:91
static constexpr index_t NPerBlock
Definition: gemm_tile_partitioner.hpp:92
static constexpr index_t KPerBlock
Definition: gemm_tile_partitioner.hpp:93
Class providing 2D workgroup index mapping into 2D output GEMM C-tile space.
Definition: gemm_tile_partitioner.hpp:21
static CK_TILE_DEVICE auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept -> const tuple< index_t, index_t >
Returns the 2D output tile index (iM, iN) for the given pair of workgroup indices.
Definition: gemm_tile_partitioner.hpp:72
static CK_TILE_HOST auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock !=0 &&NPerBlock !=0)) -> dim3
Calculates GEMM kernel grid size.
Definition: gemm_tile_partitioner.hpp:40
remove_cvref_t< BlockGemmShapeType > BlockGemmShape
Definition: gemm_tile_partitioner.hpp:22
static CK_TILE_HOST_DEVICE auto GetLoopNum(index_t K) noexcept -> index_t
Calculate number of loop iterations over GEMM's K dimension.
Definition: gemm_tile_partitioner.hpp:53
static constexpr index_t NPerBlock
Definition: gemm_tile_partitioner.hpp:25
static constexpr index_t KPerBlock
Definition: gemm_tile_partitioner.hpp:26
static constexpr index_t MPerBlock
Definition: gemm_tile_partitioner.hpp:24
CK_TILE_HOST_DEVICE GemmTile2DPartitioner() noexcept=delete
GemmTile1DPartitioner::GetOutputTileIndex's std::false specialization, checking expression validity i...
Definition: gemm_tile_partitioner.hpp:160
Struct used to calculate offset tile indices.
Definition: gemm_tile_partitioner.hpp:183
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
Subtracts the block's start offset from the raw 1D block index before computing the tile index.
Definition: gemm_tile_partitioner.hpp:192
Definition: tuple.hpp:192