/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp Source File
streamk_gemm_tile_partitioner_impl.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
6 namespace ck_tile {
7 
// Constructor StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid).
// NOTE(review): listing lines 9 and 14 are absent from this extract. Line 9 presumably held
// the qualified constructor name; line 14 presumably assigns num_tiles_, which is read below
// but never assigned in the visible lines (`m` is likewise unused here). Confirm against the
// original header.
//
// Stores grid and n, sets iters_per_tile_ = ceil(k / KPerBlock), then splits the output tiles
// between Stream-K (SK) and data-parallel (DP) work:
//  - num_tiles_ is a multiple of grid_: no SK work, everything is DP.
//  - otherwise: SK takes full_tiles_ * grid_ + (num_tiles_ % grid_) tiles when
//    num_tiles_ > grid_, or every tile when it is not, capped at num_tiles_, and all grid_
//    CTAs become SK CTAs. If that yields fewer SK iterations than CTAs, the split is dropped
//    and everything is DP again.
// Finally the SK iterations are divided over the SK CTAs (quotient in iters_per_sk_cta_,
// remainder in extra_iters_) and the DP tile and iteration totals are derived.
// NOTE(review): `num_tiles_ % grid_` has no visible guard against grid == 0.
8 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
10  index_t m, index_t n, index_t k, index_t grid)
11  : grid_{grid}, n_{n}
12 {
13  iters_per_tile_ = integer_divide_ceil(k, KPerBlock); // KPerBlock-sized steps needed to cover k
15 
16  bool big_enough = num_tiles_ > grid_; // more tiles than CTAs
17  index_t remainder_tiles = num_tiles_ % grid_; // tiles left over after an even split
18 
19  if(remainder_tiles)
20  {
21  sk_tiles_ = big_enough ? full_tiles_ * grid_ + (num_tiles_ % grid_) : num_tiles_;
22  sk_tiles_ = min(num_tiles_, sk_tiles_); // never more SK tiles than tiles overall
23  sk_ctas_ = grid_;
24  total_sk_iters_ = sk_tiles_ * iters_per_tile_;
25 
26  // If there still isn't enough work to saturate all CUs, then just revert to DP only.
27  if(total_sk_iters_ < grid_)
28  {
29  sk_tiles_ = 0;
30  sk_ctas_ = 0;
31  total_sk_iters_ = 0;
32  }
33  }
34  else // Full DP (i.e., no Stream-K)
35  {
36  sk_tiles_ = 0;
37  sk_ctas_ = 0;
38  total_sk_iters_ = 0;
39  }
40 
    // Both values are forced to 0 when there are no SK CTAs, which also avoids dividing by zero.
41  iters_per_sk_cta_ = sk_ctas_ ? total_sk_iters_ / sk_ctas_ : 0;
42  extra_iters_ = sk_ctas_ ? total_sk_iters_ % sk_ctas_ : 0;
43 
44  dp_tiles_ = num_tiles_ - sk_tiles_; // whatever SK does not take is DP
45  total_dp_iters_ = dp_tiles_ * iters_per_tile_;
46 }
47 
// get_partials_buffer_size(acc_element_bytes): space needed for the partials buffer.
// Returns MPerBlock * NPerBlock * acc_element_bytes * sk_ctas_, i.e. one MPerBlock x NPerBlock
// block of acc_element_bytes-sized elements per SK CTA; 0 when there are no SK CTAs.
// NOTE(review): the operands are index_t (int32_t in the symbol index), so the product is
// presumably computed in 32 bits; confirm it cannot overflow for large grids or tiles.
48 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
51  index_t acc_element_bytes) const noexcept
52 {
53  return MPerBlock * NPerBlock * acc_element_bytes * sk_ctas_;
54 }
55 
// get_flags_buffer_size(): space needed for the flags buffer.
// Returns sizeof(index_t) * sk_ctas_, i.e. one index_t-sized slot per SK CTA.
56 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
59  const noexcept
60 {
61  return sizeof(index_t) * sk_ctas_;
62 }
63 
// Returns the first global iteration index owned by SK CTA `cta_idx`.
// NOTE(review): the line naming this function is missing from the extract; presumably this is
// get_start_iter, which is called with a single cta_idx argument further down.
// SK iterations are numbered after all DP iterations (offset total_dp_iters_). Each CTA before
// this one contributes iters_per_sk_cta_ iterations, plus one more for each of the first
// extra_iters_ CTAs.
64 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
67  index_t cta_idx) const noexcept
68 {
69  // Compute the number of extra iterations done before this CTA. If the cta_idx is less than
70  // extra_iters, the number of extra iterations before the CTA is exactly the cta_idx. Otherwise,
71  // it is extra_iters.
72  index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
73  return total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
74 }
75 
// Writes the iteration range of SK CTA `cta_idx` into the two reference parameters.
// (The line naming this function is missing from the extract.)
//   iter     - set to get_start_iter(cta_idx)
//   iter_end - set to iter + iters_per_sk_cta_, plus one more when cta_idx < extra_iters_.
//              This equals the start of the next CTA, so the range is half-open.
76 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
77 CK_TILE_DEVICE void
79  index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
80 {
81  iter = get_start_iter(cta_idx);
82  iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_); // bool adds 0 or 1
83 }
84 
// Maps a global iteration index to the index of the tile it belongs to: iter / iters_per_tile_
// (integer division). (The line naming this function is missing from the extract.)
85 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
88  index_t iter) const noexcept
89 {
90  return iter / iters_per_tile_;
91 }
92 
// Writes the global iteration range of tile `tile_idx` into the two reference parameters.
// (The line naming this function is missing from the extract.)
//   tile_iter     - set to tile_idx * iters_per_tile_
//   tile_iter_end - set to tile_iter + iters_per_tile_
93 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
96  index_t& tile_iter, index_t& tile_iter_end, index_t tile_idx) const noexcept
97 {
98  tile_iter = tile_idx * iters_per_tile_;
99  tile_iter_end = tile_iter + iters_per_tile_;
100 }
101 
// Converts a global iteration index into an offset relative to the start of its tile:
// returns iter - tile_iter. Reads no member state (see the "static" marker in the signature).
// (The line naming this function is missing from the extract.)
102 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
103 CK_TILE_DEVICE /* static */ index_t
105  index_t iter, index_t tile_iter) noexcept
106 {
107  return iter - tile_iter;
108 }
109 
// Returns the tile-relative end offset of the work done inside one tile: the smaller of
// iter_end and tile_iter_end, minus tile_iter. Reads no member state.
// (The line naming this function is missing from the extract.)
110 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
111 CK_TILE_DEVICE /* static */ index_t
113  index_t tile_iter, index_t iter_end, index_t tile_iter_end) noexcept
114 {
115  return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
116 }
117 
// Returns cta_idx relative to the SK CTA whose iteration range contains the start of a tile.
// (The line naming this function is missing from the extract.)
//   tile_iter_start - global iteration index at which the tile starts
//   cta_idx         - SK CTA index to convert
// The first extra_iters_ CTAs each run iters_per_sk_cta_ + 1 iterations and the rest run
// iters_per_sk_cta_, so the owning CTA is found with one of two divisions depending on which
// of the two groups the tile start falls into.
// NOTE(review): the division by iters_per_sk_cta_ on listing line 129 is evaluated
// unconditionally, and the constructor sets iters_per_sk_cta_ to 0 when there are no SK CTAs.
// Presumably this is only called when SK CTAs exist; confirm at the call site.
118 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
121  index_t tile_iter_start, index_t cta_idx) const noexcept
122 {
    // Rebase onto the SK region by removing the DP iterations that precede it.
123  tile_iter_start = tile_iter_start - (dp_tiles_ * iters_per_tile_);
124 
125  // Compute how many WGs fit before this tile starts assuming each WG does an
126  // extra_iter
127  const index_t num_extra_iter_ctas = tile_iter_start / (iters_per_sk_cta_ + 1);
128  // Compute how many WGs fit before this tile starts excluding extra iters
129  const index_t num_non_extra_iter_ctas = (tile_iter_start - extra_iters_) / iters_per_sk_cta_;
130  // Compute the CTA idx for the CTA that starts this tile
131  const index_t coop_group_start =
132  num_extra_iter_ctas < extra_iters_ ? num_extra_iter_ctas : num_non_extra_iter_ctas;
133  return cta_idx - coop_group_start;
134 }
135 
// Splits a linear tile index into a pair (im, in) such that
// tile_idx == im * n_macro_tiles + in, where n_macro_tiles = ceil(n_ / NPerBlock).
// Returns the pair as tuple<index_t, index_t>.
// (The line naming this function is missing from the extract.)
// Both values are passed through amd_wave_read_first_lane, presumably to make them
// wave-uniform; confirm against that helper's documentation.
136 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
137 CK_TILE_DEVICE auto
139  index_t tile_idx) const noexcept -> tuple<index_t, index_t>
140 {
141  const index_t n_macro_tiles = integer_divide_ceil(n_, NPerBlock);
142 
143  const index_t im = amd_wave_read_first_lane(tile_idx / n_macro_tiles);
144  const index_t in = amd_wave_read_first_lane(tile_idx - im * n_macro_tiles); // remainder
145  return make_tuple(im, in);
146 }
147 
// Returns the workspace size for the configured reduction strategy.
// (The line naming this function is missing from the extract.)
//   Reduction or TreeReduction: get_partials_buffer_size(acc_element_bytes) +
//                               get_flags_buffer_size()
//   any other strategy:         0
// The branch is chosen at compile time via `if constexpr`.
148 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
151  index_t acc_element_bytes) const noexcept
152 {
153  if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction ||
154  ReductionStrategy == StreamKReductionStrategy::TreeReduction)
155  {
156 
157  return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();
158  }
159  else // ReductionStrategy is Atomics
160  {
161  return 0;
162  }
163 }
164 
// Accessor: returns num_tiles_. (The line naming this function is missing from the extract.)
165 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
168  const noexcept
169 {
170  return num_tiles_;
171 }
172 
// Accessor: returns grid_, the grid value given to the constructor.
// (The line naming this function is missing from the extract.)
173 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
176 {
177  return grid_;
178 }
179 
// Accessor: returns dp_tiles_, the number of tiles left to data-parallel work.
// (The line naming this function is missing from the extract.)
180 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
183 {
184  return dp_tiles_;
185 }
186 
// Accessor: returns sk_tiles_, the number of tiles assigned to Stream-K work.
// (The line naming this function is missing from the extract.)
187 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
190 {
191  return sk_tiles_;
192 }
193 
// Accessor: returns sk_ctas_ (either grid_ or 0, as set by the base constructor).
// NOTE(review): the line naming this function is missing from the extract; presumably this is
// get_sk_ctas(), which the non-persistent grid-size function calls.
194 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
197 {
198  return sk_ctas_;
199 }
200 
// Accessor: returns total_sk_iters_ (sk_tiles_ * iters_per_tile_, or 0 without SK work).
// (The line naming this function is missing from the extract.)
201 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
204  const noexcept
205 {
206  return total_sk_iters_;
207 }
208 
// Accessor: returns iters_per_tile_ (ceil(k / KPerBlock), set by the base constructor).
// (The line naming this function is missing from the extract.)
209 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
212  const noexcept
213 {
214  return iters_per_tile_;
215 }
216 
// Accessor: returns iters_per_sk_cta_ (total_sk_iters_ / sk_ctas_, or 0 without SK CTAs).
// (The line naming this function is missing from the extract.)
217 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
220  const noexcept
221 {
222  return iters_per_sk_cta_;
223 }
224 
// Accessor: returns extra_iters_ (total_sk_iters_ % sk_ctas_, or 0 without SK CTAs).
// (The line naming this function is missing from the extract.)
225 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
228  const noexcept
229 {
230  return extra_iters_;
231 }
232 
// Accessor: returns total_dp_iters_ (dp_tiles_ * iters_per_tile_).
// (The line naming this function is missing from the extract.)
233 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
236  const noexcept
237 {
238  return total_dp_iters_;
239 }
240 
// Accessor: returns n_, the n value given to the constructor.
// (The line naming this function is missing from the extract.)
241 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
244 {
245  return n_;
246 }
247 
// Estimates how many workgroups write to the same macro tile of C; the result is at least 1.
// NOTE(review): the line naming this function is missing from the extract; the identifier used
// for this block is inferred from the body.
//   - no SK CTAs, or a strategy other than Atomic: 1
//   - Atomic with both DP and SK tiles:            2
//   - Atomic with SK tiles only:                   ceil(iters_per_tile_ / max(iters_per_sk_cta_, 1))
// NOTE(review): std::max needs <algorithm>, and no #include is visible in this extract;
// ck_tile::max is used everywhere else in the file.
248 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
251  const noexcept
252 {
253  // In the case of non-atomic reduction or data-parallel (DP) only, there will always be 1
254  // workgroup writing final results to a given macro tile in C.
255  int num_wgs_per_tile = 1;
256 
257  // Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
258  if(sk_ctas_ > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
259  {
260  // If we have DP and SK tiles, this is DP+2TSK which guarantees at most 2 workgroups per
261  // tile. We only need to check that dp_tiles is greater than zero since we know we have SK
262  // workgroups.
263  if(dp_tiles_ > 0)
264  {
265  num_wgs_per_tile = 2;
266  }
267  else
268  {
    // Clamp the divisor to at least 1 so the division and modulo below cannot divide by zero.
269  ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
270  // Estimate the number of workgroups per macro tile.
271  num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
272  ((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
273  }
274  }
275 
276  return std::max(num_wgs_per_tile, 1);
277 }
278 
// Template header with three parameters (block GEMM shape, reduction strategy, and a bool
// Persistent) for a declaration whose final line (listing line 282) is missing from this
// extract.
// NOTE(review): presumably the primary template of the derived Stream-K tile partitioner,
// specialised below for the persistent and non-persistent cases; confirm against the header.
279 template <typename BlockGemmShapeType,
280  StreamKReductionStrategy ReductionStrategyType,
281  bool Persistent>
283 
284 // child class for Persistent Tile Partitioner
// Constructor of the persistent partitioner. Forwards (m, n, k, grid) to the
// StreamKTilePartitionerBase constructor, then spreads the DP tiles over the grid:
// dp_tiles_per_cta_ is the quotient and extra_dp_tiles_ the remainder of dp_tiles_ / grid_.
// (The lines carrying the constructor name and parameter list are missing from the extract.)
// NOTE(review): no visible guard against grid_ == 0 in the division and modulo below.
285 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
288  : StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
289 { // the base-class constructor has already run via the initializer list above
290  dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_;
291  extra_dp_tiles_ = this->dp_tiles_ % this->grid_;
292 }
293 
// Host-side launch dimensions for the persistent partitioner, returned as a 1-D dim3.
// (The line naming this function is missing from the extract.)
//   extra_dp_tiles_ == 0: dim3(grid_, 1, 1)
//   otherwise:            dim3(num_tiles_, 1, 1)
// NOTE(review): why an uneven DP split launches one block per tile is not visible in this
// file; confirm the intent with the kernel that consumes this value.
294 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
295 CK_TILE_HOST auto
297  -> dim3
298 {
299  if(extra_dp_tiles_ == 0)
300  {
301  return dim3(this->grid_, 1, 1);
302  }
303  else
304  {
305  return dim3(this->num_tiles_, 1, 1);
306  }
307 }
308 
// Accessor: returns dp_tiles_per_cta_ (dp_tiles_ / grid_, set by the persistent constructor).
// (The line naming this function is missing from the extract.)
309 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
312  const noexcept
313 {
314  return dp_tiles_per_cta_;
315 }
316 
// Accessor: returns extra_dp_tiles_ (dp_tiles_ % grid_, set by the persistent constructor).
// (The line naming this function is missing from the extract.)
317 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
320  const noexcept
321 {
322  return extra_dp_tiles_;
323 }
324 
325 // child class for Non-Persistent Tile Partitioner
// Constructor of the non-persistent partitioner. Forwards (m, n, k, grid) to the
// StreamKTilePartitionerBase constructor, then lays out the block indices: one DP CTA per DP
// tile starting at block index 0, with the SK block indices starting at dp_tiles_.
// (The lines carrying the constructor name and parameter list are missing from the extract.)
326 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
329  : StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
330 { // the base-class constructor has already run via the initializer list above
331  dp_ctas_ = this->dp_tiles_;
332  dp_start_block_idx_ = 0;
333  sk_start_block_idx_ = this->dp_tiles_;
334 }
335 
// Host-side launch dimensions for the non-persistent partitioner: a 1-D dim3 with
// dp_ctas_ + get_sk_ctas() blocks, i.e. the DP CTAs plus the SK CTAs.
// (The line naming this function is missing from the extract.)
336 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
337 CK_TILE_HOST auto
339  -> dim3
340 {
341  return dim3(dp_ctas_ + this->get_sk_ctas(), 1, 1);
342 }
343 
// Accessor: returns dp_ctas_ (equal to dp_tiles_, set by the non-persistent constructor).
// (The line naming this function is missing from the extract.)
344 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
347  const noexcept
348 {
349  return dp_ctas_;
350 }
351 
// Accessor: returns dp_start_block_idx_ (set to 0 by the non-persistent constructor).
// (The line naming this function is missing from the extract.)
352 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
355  const noexcept
356 {
357  return dp_start_block_idx_;
358 }
359 
// Accessor: returns sk_start_block_idx_ (set to dp_tiles_ by the non-persistent constructor).
// (The line naming this function is missing from the extract.)
360 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
363  const noexcept
364 {
365  return sk_start_block_idx_;
366 }
367 
368 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:145
StreamKReductionStrategy
Definition: streamk_common.hpp:10
@ TreeReduction
Definition: streamk_common.hpp:13
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:206
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
@ Atomic
Definition: block_to_ctile_map.hpp:1012
@ Reduction
Definition: block_to_ctile_map.hpp:1013
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition: amd_wave_read_first_lane.hpp:100
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:301
Stream-K tile partitioner base class.
Definition: streamk_gemm_tile_partitioner.hpp:24
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept
Calculates the total space needed for the flags buffer.
Definition: streamk_gemm_tile_partitioner_impl.hpp:58
CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept
Calculates the total space needed for the partials buffer.
Definition: streamk_gemm_tile_partitioner_impl.hpp:50
index_t grid_
Definition: streamk_gemm_tile_partitioner.hpp:217
static constexpr index_t KPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:28
static constexpr index_t NPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:27
static constexpr index_t MPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:26
index_t num_tiles_
Definition: streamk_gemm_tile_partitioner.hpp:216
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid)
Definition: streamk_gemm_tile_partitioner_impl.hpp:9
index_t dp_tiles_
Definition: streamk_gemm_tile_partitioner.hpp:218
Template for the Stream-K tile partitioner derived struct.
Definition: streamk_gemm_tile_partitioner.hpp:252
Definition: tuple.hpp:192