/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/common/streamk_common.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/common/streamk_common.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/common/streamk_common.hpp Source File
streamk_common.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
10 {
11  Atomic = 0u,
12  Reduction = 1u
13 };
14 
25 template <ck_tile::StreamKReductionStrategy ReductionStrategy>
27 estimate_num_wgs_per_tile(index_t sk_ctas, index_t iters_per_sk_cta, index_t iters_per_tile)
28 {
29  // In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
30  // writing final results to a given macro tile in C.
31  int num_wgs_per_tile = 1;
32 
33  // Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
34  if(sk_ctas > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
35  {
36  // Estimate the number of workgroups per macro tile.
37  num_wgs_per_tile =
38  (iters_per_tile / iters_per_sk_cta) + ((iters_per_tile % iters_per_sk_cta) != 0);
39  }
40 
41  return std::max(num_wgs_per_tile, 1);
42 }
43 } // namespace ck_tile
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: cluster_descriptor.hpp:13
StreamKReductionStrategy
Definition: streamk_common.hpp:10
@ Atomic
Definition: streamk_common.hpp:11
@ Reduction
Definition: streamk_common.hpp:12
int32_t index_t
Definition: integer.hpp:9
ck_tile::index_t estimate_num_wgs_per_tile(index_t sk_ctas, index_t iters_per_sk_cta, index_t iters_per_tile)
Estimates the number of Stream-K workgroups per macro tile in the C tensor.
Definition: streamk_common.hpp:27
unsigned int uint32_t
Definition: stdint.h:126