/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp Source File
streamk_gemm_tile_partitioner.hpp
  
Go to the documentation of this file.
Definition: cluster_descriptor.hpp:13
Template for the Stream-K tile partitioner derived struct.
Definition: streamk_gemm_tile_partitioner.hpp:231
Stream-K tile partitioner base class.
Definition: streamk_gemm_tile_partitioner.hpp:24
CK_TILE_HOST_DEVICE index_t get_sk_tiles() const noexcept
Returns the number of tiles in the C tensor that will use the Stream-K approach.
Definition: streamk_gemm_tile_partitioner_impl.hpp:158
CK_TILE_HOST_DEVICE index_t get_total_dp_iters() const noexcept
Returns the total number of DP iterations.
Definition: streamk_gemm_tile_partitioner_impl.hpp:204
CK_TILE_HOST_DEVICE index_t get_num_tiles() const noexcept
Returns the number of macro tiles in the C tensor.
Definition: streamk_gemm_tile_partitioner_impl.hpp:136
CK_TILE_HOST index_t get_workspace_size(index_t acc_element_bytes) const noexcept
Calculates the total space needed for the partials and flags buffers.
Definition: streamk_gemm_tile_partitioner_impl.hpp:120
CK_TILE_DEVICE void get_iter_boundaries(index_t &iter_start, index_t &iter_end, index_t cta_idx) const noexcept
Calculates the start and end iteration given the cta_idx.
Definition: streamk_gemm_tile_partitioner_impl.hpp:65
CK_TILE_DEVICE void get_tile_boundaries(index_t &tile_iter_start, index_t &tile_iter_end, index_t tile_idx) const noexcept
Calculates the starting and ending tile boundaries for the given 1D tile index.
Definition: streamk_gemm_tile_partitioner_impl.hpp:83
CK_TILE_DEVICE auto get_output_tile_index(index_t tile_idx) const noexcept -> tuple< index_t, index_t >
Calculates the workgroup's 2D tile index in the C tensor from the given 1D tile index.
Definition: streamk_gemm_tile_partitioner_impl.hpp:108
index_t grid_
Definition: streamk_gemm_tile_partitioner.hpp:196
CK_TILE_HOST_DEVICE index_t get_extra_iters() const noexcept
Returns the remainder resulting from total_sk_iters_ divided by sk_ctas_. When this is non-zero,...
Definition: streamk_gemm_tile_partitioner_impl.hpp:196
static constexpr index_t KPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:29
static CK_TILE_DEVICE index_t get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept
Calculates the workgroup's non-inclusive end iteration that is local to a tile.
Definition: streamk_gemm_tile_partitioner_impl.hpp:100
CK_TILE_HOST_DEVICE index_t get_grid() const noexcept
Returns the maximum number of active workgroups; this is assumed to be the number of CUs multiplied by the occupancy.
Definition: streamk_gemm_tile_partitioner_impl.hpp:144
static constexpr StreamKReductionStrategy ReductionStrategy
Definition: streamk_gemm_tile_partitioner.hpp:30
CK_TILE_DEVICE index_t get_tile_index(index_t iter_start) const noexcept
Calculates the 1D tile index in the C tensor for a workgroup.
Definition: streamk_gemm_tile_partitioner_impl.hpp:75
CK_TILE_HOST_DEVICE index_t get_total_sk_iters() const noexcept
Returns the total number of Stream-K iterations.
Definition: streamk_gemm_tile_partitioner_impl.hpp:172
CK_TILE_HOST_DEVICE index_t get_dp_tiles() const noexcept
Returns the number of tiles in the C tensor that will use the data-parallel (DP) approach.
Definition: streamk_gemm_tile_partitioner_impl.hpp:151
BlockGemmShapeType BlockGemmShape
Definition: streamk_gemm_tile_partitioner.hpp:25
CK_TILE_HOST_DEVICE index_t get_iters_per_tile() const noexcept
Returns the total number of iterations per tile in the C tensor. In other words, this is the total nu...
Definition: streamk_gemm_tile_partitioner_impl.hpp:180
CK_TILE_HOST_DEVICE index_t get_n() const noexcept
Returns the n dimension for the GEMM problem.
Definition: streamk_gemm_tile_partitioner_impl.hpp:212
static constexpr index_t NPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:28
static constexpr index_t MPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:27
index_t num_tiles_
Definition: streamk_gemm_tile_partitioner.hpp:195
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid)
Definition: streamk_gemm_tile_partitioner_impl.hpp:8
CK_TILE_HOST_DEVICE index_t get_iters_per_sk_cta() const noexcept
Returns the total number of Stream-K iterations for each sk_cta. This is the lower bound (i....
Definition: streamk_gemm_tile_partitioner_impl.hpp:188
CK_TILE_HOST_DEVICE index_t get_sk_ctas() const noexcept
Returns the number of workgroups that will participate in Stream-K processing of the sk_tiles_.
Definition: streamk_gemm_tile_partitioner_impl.hpp:165
static CK_TILE_DEVICE index_t get_local_iter(index_t iter_start, index_t tile_iter_start) noexcept
Calculates the workgroup's starting iteration that is local to a tile.
Definition: streamk_gemm_tile_partitioner_impl.hpp:92
CK_TILE_HOST index_t estimate_num_wgs_per_tile() const noexcept
Returns an estimate of the number of workgroups writing to the same macro tile in C.
Definition: streamk_gemm_tile_partitioner_impl.hpp:219
index_t dp_tiles_
Definition: streamk_gemm_tile_partitioner.hpp:197
Definition: tuple.hpp:192