/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp Source File
streamk_gemm_tile_partitioner.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
8 
9 namespace ck_tile {
10 
21 template <typename BlockGemmShapeType,
24 {
25 
26  static constexpr index_t MPerBlock = BlockGemmShapeType::kM;
27  static constexpr index_t NPerBlock = BlockGemmShapeType::kN;
28  static constexpr index_t KPerBlock = BlockGemmShapeType::kK;
29  static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType;
32  : memory_operation_enum::set;
33 
35 
42  CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept;
43 
50 
51  public:
60  CK_TILE_DEVICE index_t get_start_iter(index_t cta_idx) const noexcept;
61 
74  CK_TILE_DEVICE void
75  get_iter_boundaries(index_t& iter_start, index_t& iter_end, index_t cta_idx) const noexcept;
76 
83  CK_TILE_DEVICE index_t get_tile_index(index_t iter_start) const noexcept;
84 
94  CK_TILE_DEVICE void get_tile_boundaries(index_t& tile_iter_start,
95  index_t& tile_iter_end,
96  index_t tile_idx) const noexcept;
97 
107  CK_TILE_DEVICE static index_t get_local_iter(index_t iter_start,
108  index_t tile_iter_start) noexcept;
109 
119  CK_TILE_DEVICE static index_t
120  get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept;
121 
130  index_t cta_idx) const noexcept;
131 
138  CK_TILE_DEVICE auto
139  get_output_tile_index(index_t tile_idx) const noexcept -> tuple<index_t, index_t>;
140 
147  CK_TILE_HOST_DEVICE index_t get_workspace_size(index_t acc_element_bytes) const noexcept;
148 
152  CK_TILE_HOST_DEVICE index_t get_num_tiles() const noexcept;
153 
158  CK_TILE_HOST_DEVICE index_t get_grid() const noexcept;
159 
164  CK_TILE_HOST_DEVICE index_t get_dp_tiles() const noexcept;
165 
169  CK_TILE_HOST_DEVICE index_t get_sk_tiles() const noexcept;
170 
174  CK_TILE_HOST_DEVICE index_t get_sk_ctas() const noexcept;
175 
180 
186 
192 
199 
204 
208  CK_TILE_HOST_DEVICE index_t get_n() const noexcept;
209 
214 
215  protected:
219 
220  private:
225  index_t full_tiles_ = 1;
226  index_t sk_tiles_;
227  index_t sk_ctas_;
228  index_t total_sk_iters_;
229  index_t iters_per_tile_;
230  index_t iters_per_sk_cta_;
231  index_t extra_iters_;
232  index_t total_dp_iters_;
233  index_t n_;
234 };
235 
249 template <typename BlockGemmShapeType,
250  StreamKReductionStrategy ReductionStrategyType,
251  bool Persistent>
253 
265 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
266 struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>
267  : StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>
268 {
272  ck_tile::index_t grid);
273 
274  public:
275  static constexpr bool PERSISTENT = true;
283  CK_TILE_HOST auto grid_size() const noexcept -> dim3;
284 
288  CK_TILE_HOST_DEVICE index_t get_dp_tiles_per_cta() const noexcept;
289 
294  CK_TILE_HOST_DEVICE index_t get_extra_dp_tiles() const noexcept;
295 
296  protected:
297  index_t dp_tiles_per_cta_;
298  index_t extra_dp_tiles_;
299 };
300 
312 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
313 struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>
314  : StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>
315 {
319  ck_tile::index_t grid);
320 
321  public:
322  static constexpr bool PERSISTENT = false;
330  CK_TILE_HOST auto grid_size() const noexcept -> dim3;
331 
335  CK_TILE_HOST_DEVICE index_t get_dp_ctas() const noexcept;
336 
340  CK_TILE_HOST_DEVICE index_t get_dp_start_block_idx() const noexcept;
341 
345  CK_TILE_HOST_DEVICE index_t get_sk_start_block_idx() const noexcept;
346 
347  protected:
348  index_t dp_ctas_;
349  index_t dp_start_block_idx_;
350  index_t sk_start_block_idx_;
351 };
352 
353 } // namespace ck_tile
354 
355 #include "streamk_gemm_tile_partitioner_impl.hpp"
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
StreamKReductionStrategy
Definition: streamk_common.hpp:10
int32_t index_t
Definition: integer.hpp:9
@ Atomic
Definition: block_to_ctile_map.hpp:1012
__device__ X atomic_add(X *p_dst, const X &x)
Stream-K tile partitioner base class.
Definition: streamk_gemm_tile_partitioner.hpp:24
CK_TILE_HOST_DEVICE index_t get_sk_tiles() const noexcept
Returns the number of tiles in the C tensor that will use the Stream-K approach.
Definition: streamk_gemm_tile_partitioner_impl.hpp:189
CK_TILE_HOST_DEVICE index_t get_total_dp_iters() const noexcept
Returns the total number of DP iterations.
Definition: streamk_gemm_tile_partitioner_impl.hpp:235
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept
Calculates the total space needed for the flags buffer.
Definition: streamk_gemm_tile_partitioner_impl.hpp:58
CK_TILE_HOST_DEVICE index_t get_num_tiles() const noexcept
Returns the number of macro tiles in the C tensor.
Definition: streamk_gemm_tile_partitioner_impl.hpp:167
CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept
Calculates the total space needed for the partials buffer.
Definition: streamk_gemm_tile_partitioner_impl.hpp:50
CK_TILE_DEVICE void get_iter_boundaries(index_t &iter_start, index_t &iter_end, index_t cta_idx) const noexcept
Calculates the start and end iteration given the cta_idx.
Definition: streamk_gemm_tile_partitioner_impl.hpp:78
CK_TILE_DEVICE void get_tile_boundaries(index_t &tile_iter_start, index_t &tile_iter_end, index_t tile_idx) const noexcept
Calculates the starting and ending tile boundaries for the given 1D tile index.
Definition: streamk_gemm_tile_partitioner_impl.hpp:95
CK_TILE_DEVICE auto get_output_tile_index(index_t tile_idx) const noexcept -> tuple< index_t, index_t >
Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index.
Definition: streamk_gemm_tile_partitioner_impl.hpp:138
index_t grid_
Definition: streamk_gemm_tile_partitioner.hpp:217
CK_TILE_HOST_DEVICE index_t get_extra_iters() const noexcept
Returns the remainder resulting from total_sk_iters_ divided by sk_ctas_. When this is non-zero,...
Definition: streamk_gemm_tile_partitioner_impl.hpp:227
static constexpr index_t KPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:28
static CK_TILE_DEVICE index_t get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept
Calculates the workgroup's non-inclusive end iteration that is local to a tile.
Definition: streamk_gemm_tile_partitioner_impl.hpp:112
CK_TILE_HOST_DEVICE index_t get_grid() const noexcept
Returns the maximum number of active workgroups; this is assumed to be the number of CUs * occupancy.
Definition: streamk_gemm_tile_partitioner_impl.hpp:175
static constexpr StreamKReductionStrategy ReductionStrategy
Definition: streamk_gemm_tile_partitioner.hpp:29
CK_TILE_DEVICE index_t get_tile_index(index_t iter_start) const noexcept
Calculates the 1D tile index in the C tensor for a workgroup.
Definition: streamk_gemm_tile_partitioner_impl.hpp:87
CK_TILE_HOST_DEVICE index_t get_total_sk_iters() const noexcept
Returns the total number of Stream-K iterations.
Definition: streamk_gemm_tile_partitioner_impl.hpp:203
CK_TILE_HOST_DEVICE index_t get_dp_tiles() const noexcept
Returns the number of tiles in the C tensor that will use the data-parallel (DP) approach.
Definition: streamk_gemm_tile_partitioner_impl.hpp:182
CK_TILE_HOST_DEVICE index_t get_iters_per_tile() const noexcept
Returns the total number of iterations per tile in the C tensor. In other words, this is the total nu...
Definition: streamk_gemm_tile_partitioner_impl.hpp:211
CK_TILE_HOST_DEVICE index_t get_n() const noexcept
Returns the n dimension for the GEMM problem.
Definition: streamk_gemm_tile_partitioner_impl.hpp:243
static constexpr index_t NPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:27
static constexpr index_t MPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:26
index_t num_tiles_
Definition: streamk_gemm_tile_partitioner.hpp:216
CK_TILE_HOST_DEVICE index_t get_workspace_size(index_t acc_element_bytes) const noexcept
Calculates the total space needed for the partials and flags buffers.
Definition: streamk_gemm_tile_partitioner_impl.hpp:150
static constexpr auto MemoryOperation
Definition: streamk_gemm_tile_partitioner.hpp:30
CK_TILE_DEVICE index_t get_start_iter(index_t cta_idx) const noexcept
Calculates the start iteration for the given cta_idx.
Definition: streamk_gemm_tile_partitioner_impl.hpp:66
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid)
Definition: streamk_gemm_tile_partitioner_impl.hpp:9
CK_TILE_HOST_DEVICE index_t get_iters_per_sk_cta() const noexcept
Returns the total number of Stream-K iterations for each sk_cta. This is the lower bound (i....
Definition: streamk_gemm_tile_partitioner_impl.hpp:219
CK_TILE_HOST_DEVICE index_t get_sk_ctas() const noexcept
Returns the number of workgroups that will participate in Stream-K in the sk_tiles_.
Definition: streamk_gemm_tile_partitioner_impl.hpp:196
static CK_TILE_DEVICE index_t get_local_iter(index_t iter_start, index_t tile_iter_start) noexcept
Calculates the workgroup's starting iteration that is local to a tile.
Definition: streamk_gemm_tile_partitioner_impl.hpp:104
CK_TILE_HOST index_t estimate_num_wgs_per_tile() const noexcept
Returns an estimate of the number of workgroups writing to the same macro tile in C.
Definition: streamk_gemm_tile_partitioner_impl.hpp:250
index_t dp_tiles_
Definition: streamk_gemm_tile_partitioner.hpp:218
CK_TILE_DEVICE index_t get_tile_local_cta_index(index_t tile_iter_start, index_t cta_idx) const noexcept
Calculates the workgroup's local CTA idx within the given tile.
Definition: streamk_gemm_tile_partitioner_impl.hpp:120
Template for the Stream-K tile partitioner derived struct.
Definition: streamk_gemm_tile_partitioner.hpp:252
Definition: tuple.hpp:192