Composable Kernel — Doxygen source listing for:
include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp
(ReadTheDocs build checkout: advanced-micro-devices-composable-kernel, branch "develop".)
NOTE: this is an HTML-to-text extraction of the documentation page, not the raw header;
hyperlinked identifier lines were dropped during extraction, so some signature lines below are missing.
1 // Copyright © Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 #pragma once
5 namespace ck_tile {
6 
// Base partitioner constructor: splits the output macro-tile space of an (m x n) GEMM
// with reduction length k across `grid` CTAs into a data-parallel (DP) region plus a
// Stream-K (SK) remainder region.
// NOTE(review): Doxygen extraction lost the constructor's signature line (orig. line 8)
// and the line(s) initializing num_tiles_ / full_tiles_ (orig. line 13) — restore from
// the original header before treating this listing as compilable source.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
9  index_t m, index_t n, index_t k, index_t grid)
10  : grid_{grid}, n_{n}
11 {
// Number of KPerBlock-sized iterations required to fully reduce one output tile.
12  iters_per_tile_ = integer_divide_ceil(k, KPerBlock);
14 
// big_enough: more output tiles than CTAs, i.e. every CTA gets at least one full tile.
15  bool big_enough = num_tiles_ > grid_;
16  index_t remainder_tiles = num_tiles_ % grid_;
17 
// A nonzero remainder means the tiles don't divide evenly over the grid, so the last
// wave of tiles is handled with Stream-K; otherwise pure data-parallel suffices.
18  if(remainder_tiles)
19  {
// NOTE(review): full_tiles_ is not defined anywhere in this listing — presumably set on
// the missing line 13 (likely tiles-per-CTA in the last wave); confirm against the header.
20  sk_tiles_ = big_enough ? full_tiles_ * grid_ + (num_tiles_ % grid_) : num_tiles_;
21  sk_tiles_ = min(num_tiles_, sk_tiles_);
// All CTAs in the grid participate in the Stream-K region.
22  sk_ctas_ = grid_;
23  total_sk_iters_ = sk_tiles_ * iters_per_tile_;
24 
25  // If there still isn't enough work to saturate all CUs, then just revert to DP only.
26  if(total_sk_iters_ < grid_)
27  {
28  sk_tiles_ = 0;
29  sk_ctas_ = 0;
30  total_sk_iters_ = 0;
31  }
32  }
33  else // Full DP (i.e., no Stream-K)
34  {
35  sk_tiles_ = 0;
36  sk_ctas_ = 0;
37  total_sk_iters_ = 0;
38  }
39 
// Evenly divide SK iterations across SK CTAs; extra_iters_ CTAs each take one extra.
// Guarded so the DP-only case (sk_ctas_ == 0) does not divide by zero.
40  iters_per_sk_cta_ = sk_ctas_ ? total_sk_iters_ / sk_ctas_ : 0;
41  extra_iters_ = sk_ctas_ ? total_sk_iters_ % sk_ctas_ : 0;
42 
// Whatever is not handled by Stream-K is plain data-parallel work.
43  dp_tiles_ = num_tiles_ - sk_tiles_;
44  total_dp_iters_ = dp_tiles_ * iters_per_tile_;
45 }
46 
// Bytes of workspace needed to hold per-CTA partial accumulator tiles
// (one MPerBlock x NPerBlock tile of `acc_element_bytes`-sized elements per SK CTA).
// NOTE(review): the return-type line (orig. line 48) was lost in extraction —
// presumably `CK_TILE_HOST_DEVICE index_t`; confirm against the header.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
49 StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_partials_buffer_size(
50  index_t acc_element_bytes) const noexcept
51 {
52  return MPerBlock * NPerBlock * acc_element_bytes * sk_ctas_;
53 }
54 
// Bytes of workspace needed for the per-SK-CTA completion/reduction flags
// (one index_t flag per SK CTA).
// NOTE(review): the return-type line (orig. line 56) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
57 StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags_buffer_size()
58  const noexcept
59 {
60  return sizeof(index_t) * sk_ctas_;
61 }
62 
// Computes the half-open global iteration range [iter, iter_end) owned by SK CTA
// `cta_idx`. SK iterations start after all DP iterations (offset total_dp_iters_);
// the first `extra_iters_` CTAs each take one extra iteration to absorb the remainder.
// NOTE(review): the function name line (orig. lines 64-65) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
66  index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
67 {
// Number of "extra" iterations already handed out to lower-indexed CTAs.
68  index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
69  iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
// This CTA gets one extra iteration iff cta_idx < extra_iters_.
70  iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
71 }
72 
// Maps a global iteration index to the macro-tile index that owns it
// (each tile spans iters_per_tile_ consecutive iterations).
// NOTE(review): the signature line (orig. lines 74-75) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
76  index_t iter) const noexcept
77 {
78  return iter / iters_per_tile_;
79 }
80 
// Computes the half-open global iteration range [tile_iter, tile_iter_end)
// covered by macro-tile `tile_idx`.
// NOTE(review): the function name line (orig. line 83) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
82 CK_TILE_DEVICE void
84  index_t& tile_iter, index_t& tile_iter_end, index_t tile_idx) const noexcept
85 {
86  tile_iter = tile_idx * iters_per_tile_;
87  tile_iter_end = tile_iter + iters_per_tile_;
88 }
89 
// Offset of global iteration `iter` relative to the start of its tile
// (`tile_iter` is the tile's first global iteration).
// NOTE(review): the function name line (orig. line 92) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
91 CK_TILE_DEVICE /* static */ index_t
93  index_t iter, index_t tile_iter) noexcept
94 {
95  return iter - tile_iter;
96 }
97 
// Number of iterations a CTA contributes to a tile: from `tile_iter` up to whichever
// comes first of the CTA's range end (`iter_end`) or the tile's end (`tile_iter_end`).
// Presumably the caller passes tile_iter clamped to the CTA's own start — confirm at call sites.
// NOTE(review): the function name line (orig. line 100) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
99 CK_TILE_DEVICE /* static */ index_t
101  index_t tile_iter, index_t iter_end, index_t tile_iter_end) noexcept
102 {
103  return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
104 }
105 
// Converts a linear macro-tile index into (row, col) macro-tile coordinates
// using row-major order over the ceil(n_/NPerBlock) columns.
// NOTE(review): the function name line (orig. line 108) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
107 CK_TILE_DEVICE auto
109  index_t tile_idx) const noexcept -> tuple<index_t, index_t>
110 {
111  const index_t n_macro_tiles = integer_divide_ceil(n_, NPerBlock);
112 
// amd_wave_read_first_lane broadcasts lane 0's value, keeping the coordinates
// wave-uniform (scalar) for downstream addressing.
113  const index_t im = amd_wave_read_first_lane(tile_idx / n_macro_tiles);
114  const index_t in = amd_wave_read_first_lane(tile_idx - im * n_macro_tiles);
115  return make_tuple(im, in);
116 }
117 
// Total device workspace bytes required by the chosen reduction strategy:
// Reduction needs partial-accumulator tiles plus synchronization flags;
// Atomic accumulates directly into C and needs no workspace.
// NOTE(review): the return-type/name line (orig. lines 119-120) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
121  index_t acc_element_bytes) const noexcept
122 {
123  if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
124  {
125 
126  return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();
127  }
128  else // ReductionStrategy is Atomics
129  {
130  return 0;
131  }
132 }
133 
// Getter: total number of output macro-tiles (num_tiles_).
// NOTE(review): the signature line (orig. lines 135-136) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
137  const noexcept
138 {
139  return num_tiles_;
140 }
141 
// Getter: launch grid size (CTA count) this partitioner was constructed with.
// NOTE(review): the signature line (orig. lines 143-144) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
145 {
146  return grid_;
147 }
148 
// Getter: number of tiles processed in the data-parallel region (dp_tiles_).
// NOTE(review): the signature line (orig. lines 150-151) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
152 {
153  return dp_tiles_;
154 }
155 
// Getter: number of tiles processed in the Stream-K region (sk_tiles_).
// NOTE(review): the signature line (orig. lines 157-158) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
159 {
160  return sk_tiles_;
161 }
162 
// Getter: number of CTAs participating in Stream-K work (sk_ctas_; 0 when DP-only).
// NOTE(review): the signature line (orig. lines 164-165) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
166 {
167  return sk_ctas_;
168 }
169 
// Getter: total iterations in the Stream-K region (total_sk_iters_).
// NOTE(review): the signature line (orig. lines 171-172) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
173  const noexcept
174 {
175  return total_sk_iters_;
176 }
177 
// Getter: K iterations needed to fully reduce one tile (iters_per_tile_).
// NOTE(review): the signature line (orig. lines 179-180) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
181  const noexcept
182 {
183  return iters_per_tile_;
184 }
185 
// Getter: base iteration count per Stream-K CTA (iters_per_sk_cta_; excludes extras).
// NOTE(review): the signature line (orig. lines 187-188) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
189  const noexcept
190 {
191  return iters_per_sk_cta_;
192 }
193 
// Getter: leftover iterations distributed one-each to the lowest-indexed SK CTAs.
// NOTE(review): the signature line (orig. lines 195-196) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
197  const noexcept
198 {
199  return extra_iters_;
200 }
201 
// Getter: total iterations in the data-parallel region (total_dp_iters_).
// NOTE(review): the signature line (orig. lines 203-204) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
205  const noexcept
206 {
207  return total_dp_iters_;
208 }
209 
// Getter: the GEMM N dimension this partitioner was constructed with (n_).
// NOTE(review): the signature line (orig. lines 211-212) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
213 {
214  return n_;
215 }
216 
// Estimates how many workgroups may write to a single C macro-tile. Always >= 1;
// greater than 1 only when Stream-K with the Atomic strategy is active, since then
// multiple CTAs accumulate into the same tile.
// NOTE(review): the signature line (orig. lines 218-219) was lost in extraction.
// A free helper `estimate_num_wgs_per_tile(sk_ctas, iters_per_sk_cta, iters_per_tile)`
// exists in streamk_common.hpp (see cross-reference index below) — consider delegating
// to it instead of duplicating the formula here; confirm the two stay in sync.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
220  const noexcept
221 {
222  // In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
223  // writing final results to a given macro tile in C.
224  int num_wgs_per_tile = 1;
225 
226  // Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
227  if(sk_ctas_ > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
228  {
// Guard against iters_per_sk_cta_ == 0 before dividing.
229  ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
230  // Estimate the number of workgroups per macro tile.
// Ceiling division: full quotients plus one if there is a remainder.
231  num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
232  ((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
233  }
234 
235  return std::max(num_wgs_per_tile, 1);
236 }
237 
// NOTE(review): this is the template parameter list of the primary StreamKTilePartitioner
// declaration (the `struct` line itself, orig. line 241, was a hyperlink and is missing
// from this extraction — see "Template for the Stream-K tile partitioner derived struct"
// in the cross-reference index).
template <typename BlockGemmShapeType,
239  StreamKReductionStrategy ReductionStrategyType,
240  bool Persistent>
242 
// child class for Persistent Tile Partitioner
// Persistent-specialization constructor: delegates partitioning to the base class, then
// precomputes how many DP tiles each persistent CTA loops over, plus the remainder.
// NOTE(review): the constructor signature lines (orig. lines 245-248, incl. the m/n/k
// parameters) were lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
249  ck_tile::index_t grid)
250  : StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
251 { // inherit from base constructor
252  dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_;
253  extra_dp_tiles_ = this->dp_tiles_ % this->grid_;
254 }
255 
// Host-side launch-grid computation for the persistent partitioner: when DP tiles divide
// evenly, launch exactly grid_ persistent CTAs; otherwise one CTA per tile is launched.
// NOTE(review): looks asymmetric (num_tiles_ vs grid_) — confirm intent against the
// original header; the name line (orig. line 258) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
257 CK_TILE_HOST auto
259  const noexcept -> dim3
260 {
261  if(extra_dp_tiles_ == 0)
262  {
263  return dim3(this->grid_, 1, 1);
264  }
265  else
266  {
267  return dim3(this->num_tiles_, 1, 1);
268  }
269 }
270 
// Getter: DP tiles each persistent CTA processes (dp_tiles_per_cta_).
// NOTE(review): the signature line (orig. lines 272-273) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
274  const noexcept
275 {
276  return dp_tiles_per_cta_;
277 }
278 
// Getter: remainder DP tiles not evenly divisible across the grid (extra_dp_tiles_).
// NOTE(review): the signature line (orig. lines 280-281) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
282  const noexcept
283 {
284  return extra_dp_tiles_;
285 }
286 
// child class for Non-Persistent Tile Partitioner
// Non-persistent-specialization constructor: one CTA per DP tile; DP CTAs occupy block
// indices [0, dp_tiles_), and SK CTAs start immediately after at block index dp_tiles_.
// NOTE(review): the constructor signature lines (orig. lines 289-292, incl. the m/n/k
// parameters) were lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
293  ck_tile::index_t grid)
294  : StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
295 { // inherit from base constructor
296  dp_ctas_ = this->dp_tiles_;
297  dp_start_block_idx_ = 0;
298  sk_start_block_idx_ = this->dp_tiles_;
299 }
300 
// Host-side launch-grid computation for the non-persistent partitioner:
// one CTA per DP tile plus the Stream-K CTAs.
// NOTE(review): the name line (orig. line 303) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
302 CK_TILE_HOST auto
304  const noexcept -> dim3
305 {
306  return dim3(dp_ctas_ + this->get_sk_ctas(), 1, 1);
307 }
308 
// Getter: number of data-parallel CTAs (dp_ctas_; equals dp_tiles_ per the constructor).
// NOTE(review): the signature line (orig. lines 310-311) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
312  const noexcept
313 {
314  return dp_ctas_;
315 }
316 
// Getter: first block index of the data-parallel region (always 0 per the constructor).
// NOTE(review): the return-type line (orig. lines 318-319) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
320  get_dp_start_block_idx() const noexcept
321 {
322  return dp_start_block_idx_;
323 }
324 
// Getter: first block index of the Stream-K region (dp_tiles_ per the constructor).
// NOTE(review): the return-type line (orig. lines 326-327) was lost in extraction.
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
328  get_sk_start_block_idx() const noexcept
329 {
330  return sk_start_block_idx_;
331 }
332 
333 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
StreamKReductionStrategy
Definition: streamk_common.hpp:10
int32_t index_t
Definition: integer.hpp:9
ck_tile::index_t estimate_num_wgs_per_tile(index_t sk_ctas, index_t iters_per_sk_cta, index_t iters_per_tile)
Estimates the number of Stream-K workgroups per macro tile in the C tensor.
Definition: streamk_common.hpp:27
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
@ Atomic
Definition: block_to_ctile_map.hpp:1012
@ Reduction
Definition: block_to_ctile_map.hpp:1013
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition: amd_wave_read_first_lane.hpp:100
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:299
Template for the Stream-K tile partitioner derived struct.
Definition: streamk_gemm_tile_partitioner.hpp:231
Stream-K tile partitioner base class.
Definition: streamk_gemm_tile_partitioner.hpp:24
index_t grid_
Definition: streamk_gemm_tile_partitioner.hpp:196
static constexpr index_t KPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:29
static constexpr index_t NPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:28
static constexpr index_t MPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:27
index_t num_tiles_
Definition: streamk_gemm_tile_partitioner.hpp:195
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid)
Definition: streamk_gemm_tile_partitioner_impl.hpp:8
index_t dp_tiles_
Definition: streamk_gemm_tile_partitioner.hpp:197
Definition: tuple.hpp:192