Composable Kernel — gemm_tile_partitioner.hpp (Doxygen "Source File" page)
Path: include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
9 #pragma once
10 
11 #include "ck_tile/core.hpp"
12 #include "ck_tile/ops/common.hpp"
13 
14 namespace ck_tile {
15 
20 template <typename BlockGemmShapeType>
22 {
24 
25  static constexpr index_t MPerBlock = BlockGemmShape::kM;
26  static constexpr index_t NPerBlock = BlockGemmShape::kN;
27  static constexpr index_t KPerBlock = BlockGemmShape::kK;
28 
31  [[maybe_unused]] index_t N) noexcept;
32 
40  CK_TILE_HOST static auto
41  GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
42  {
43  const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
44  const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
45  return dim3(GridDimX, GridDimY, 1);
46  }
47 
54  CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
55  {
56  return integer_divide_ceil(K, KPerBlock);
57  }
58 
73  CK_TILE_DEVICE static auto
74  GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept -> const tuple<index_t, index_t>
75  {
76  const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
77  const index_t iN = __builtin_amdgcn_readfirstlane(blockIdy);
78  return make_tuple(iM, iN);
79  }
80 };
81 
87 template <typename BlockGemmShape_>
89 {
91 
92  static constexpr index_t MPerBlock = BlockGemmShape::kM;
93  static constexpr index_t NPerBlock = BlockGemmShape::kN;
94  static constexpr index_t KPerBlock = BlockGemmShape::kK;
95 
97 
105  {
106  N_ = N;
107  }
108 
116  CK_TILE_HOST_DEVICE static auto
117  GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
118  {
119  const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
120  const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
121  return GridDimX * GridDimY;
122  }
123 
130  CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
131  {
132  return integer_divide_ceil(K, KPerBlock);
133  }
134 
141  CK_TILE_DEVICE static auto
143  {
144  const index_t NBlocks = integer_divide_ceil(N_, NPerBlock);
145 
146  const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlocks);
147  const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - iM * NBlocks);
148  return make_tuple(iM, iN);
149  }
150 
151  private:
152  CK_TILE_DEVICE static index_t N_;
153 };
154 
/// @brief Detection idiom: primary template, matches when
/// `T::GetOutputTileIndex(index_t)` is NOT a valid expression.
/// (Base classes recovered from the page's cross-reference index:
/// false_type/true_type — verify against the original header.)
template <typename, typename = void>
struct HasFnOneArgImpl : std::false_type
{
};

/// @brief Specialization chosen (via std::void_t SFINAE) when the partitioner
/// exposes a single-argument GetOutputTileIndex overload.
template <typename T>
struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIndex(1))>>
    : std::true_type
{
};
174 
181 template <typename TilePartitioner,
182  typename = typename std::enable_if_t<HasFnOneArgImpl<TilePartitioner>{}>>
184 {
192  [[nodiscard]] CK_TILE_DEVICE static auto GetOffsetedTileIndex(
193  index_t block_start, index_t M, index_t N) noexcept -> const tuple<index_t, index_t>
194  {
195  const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start);
196  return make_tuple(iM, iN);
197  }
198 
207  [[nodiscard]] CK_TILE_DEVICE static auto
208  GetOffsetedTileIndex(index_t block_start, index_t M, index_t N, index_t block_idx) noexcept
209  -> const tuple<index_t, index_t>
210  {
211  const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(block_idx - block_start);
212  return make_tuple(iM, iN);
213  }
214 };
215 
227 template <typename BlockGemmShapeType, index_t GroupNum, index_t M01>
229 {
231 
232  static constexpr index_t MPerBlock = BlockGemmShape::kM;
233  static constexpr index_t NPerBlock = BlockGemmShape::kN;
234  static constexpr index_t KPerBlock = BlockGemmShape::kK;
235 
238  : M(M_), N(N_)
239  {
240  }
241 
249  CK_TILE_HOST_DEVICE static auto
250  GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
251  {
252  const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
253  const index_t GridDimY = integer_divide_ceil(N, NPerBlock);
254  return GridDimX * GridDimY;
255  }
256 
263  CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
264  {
265  return integer_divide_ceil(K, KPerBlock);
266  }
267 
274  CK_TILE_DEVICE auto
275  GetOutputTileIndex(index_t block_1d_id) noexcept -> const tuple<index_t, index_t>
276  {
277  const auto M0 = integer_divide_ceil(M, MPerBlock);
278  const auto N0 = integer_divide_ceil(N, NPerBlock);
279 
280  if(M0 == 1)
281  {
282  return make_tuple(0, block_1d_id);
283  }
284  else if(N0 == 1)
285  {
286  return make_tuple(block_1d_id, 0);
287  }
288  // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
289  else
290  {
291  const auto group_size = integer_divide_ceil(M0 * N0, GroupNum);
292  const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
293  const auto group_id_y = block_1d_id / GroupNum;
294  const auto group_id_x = block_1d_id - group_id_y * GroupNum;
295  const auto remap_block_1d_id =
296  group_id_x <= big_group_num
297  ? group_id_x * group_size + group_id_y
298  : group_id_x * group_size + big_group_num - group_id_x + group_id_y;
299 
300  const index_t idx_M0 = remap_block_1d_id / N0;
301  const index_t idx_N0 = remap_block_1d_id - idx_M0 * N0;
302 
303  const index_t M0_tmp = M0 / M01;
304  const index_t M0_mod_M01 = M0 - M0_tmp * M01;
305 
306  const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
307 
308  const index_t idx_M00 = idx_M0 / M01;
309  const index_t idx_M01 = idx_M0 - idx_M00 * M01;
310  const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
311 
356  const index_t N_out = idx_N0_M01_local / M01_adapt;
357  const index_t idx_loc_mod_M01 = idx_N0_M01_local - N_out * M01_adapt;
358 
359  return make_tuple(idx_loc_mod_M01 + idx_M00 * M01, N_out);
360  }
361  }
362 
363  private:
364  index_t M;
365  index_t N;
366 };
367 
384 template <typename BlockGemmShapeType,
386  uint32_t TileSwizzleSubM = 8>
388 {
389  using BlockGemmShape = BlockGemmShapeType;
390 
391  static constexpr uint32_t MPerBlock = BlockGemmShape::kM;
392  static constexpr uint32_t NPerBlock = BlockGemmShape::kN;
393  static constexpr uint32_t KPerBlock = BlockGemmShape::kK;
394 
396 
401  uint32_t N,
402  uint32_t K,
403  uint32_t num_cu,
404  uint32_t occupancy,
405  uint32_t sk_blocks = 0xffffffff) noexcept
406  : M_(M), N_(N), K_(K)
407  {
408  num_tile_m_ = integer_divide_ceil(M, MPerBlock);
409  num_tile_n_ = integer_divide_ceil(N, NPerBlock);
410  num_tile_k_ = integer_divide_ceil(K, KPerBlock);
411 
412  constexpr uint32_t min_k_iters_per_sk_block = 2;
413  uint32_t num_tiles = num_tile_m_ * num_tile_n_;
414  k_iters_per_tile = mdiv(num_tile_k_);
415 
416  // one cu can hold one wg at one time, from the whole cZ's point of view
417  // if number of wg is same as num_cu, we call it 1 dispatch
418  // if number of wg is 2x num_cu, we call it 2 dispatches.
419  // one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial
420  // dispatch)
421  //
422  const uint32_t full_dispatches = num_tiles / num_cu;
423  const uint32_t full_dispatch_tiles = full_dispatches * num_cu;
424  const uint32_t partial_dispatch_tiles = num_tiles - full_dispatch_tiles;
425 
426  uint32_t sk_occupancy = occupancy;
427  uint32_t dp_tiles = full_dispatch_tiles;
428  uint32_t sk_tiles = partial_dispatch_tiles;
429 
430  if(full_dispatches < occupancy)
431  {
432  // in this case, we allocate all blocks as sk blocks
433  // sk_occupancy = occupancy - full_dispatches;
434  sk_occupancy = 1;
435  dp_tiles = full_dispatch_tiles;
436  sk_tiles = partial_dispatch_tiles;
437  }
438  else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
439  {
440  // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
441  // occupancy = 3, full_dispatches = 5, 8, 11 ...
442  // occupancy = 4, full_dispatches = 7, 11 ...
443  sk_occupancy = 1; // left 1 slot for sk occupancy
444  dp_tiles = full_dispatch_tiles;
445  sk_tiles = partial_dispatch_tiles;
446  }
447  else
448  {
449  // otherwise, we reduce 1 dispatch from dp, together with partial dispatch,
450  // to construct sk dispatch
451  sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
452  dp_tiles = full_dispatch_tiles - num_cu;
453  sk_tiles = partial_dispatch_tiles + num_cu;
454  }
455 
456  // uint32_t dp_iters_per_block = k_iters_per_tile.get();
457  uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
458  uint32_t dp_num_blocks = 0;
459 
460  {
461  const uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
462  const uint32_t max_sk_tiles =
463  (sk_tiles >= num_cu) ? num_cu * sk_occupancy
464  : min(num_cu, sk_total_iters / min_k_iters_per_sk_block);
465 
466  // if use dp for sk-block, how many iters do we need
467  const uint32_t dp_for_sk_iters = k_iters_per_tile.get();
468 
469  uint32_t best_sk_score =
470  std::numeric_limits<int>::max(); // we need to find the smallest sk iters
471  for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
472  tentative_sk_blocks++)
473  {
474  const uint32_t tentative_sk_iters_per_block =
475  (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
476  const uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
477  const uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;
478 
479  // the more sk_blocks_per_tile, the worse the overhead
480  uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
481  if(tentative_sk_blocks % sk_tiles != 0)
482  {
483  // penalty for uneven divide
484  cross_sk_blocks_overhead +=
485  sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
486  }
487 
488  const uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;
489 
490  if(tentative_sk_score < best_sk_score)
491  {
492  best_sk_score = tentative_sk_score;
493  sk_num_blocks = tentative_sk_blocks;
494  }
495  }
496 
497  if(best_sk_score >= dp_for_sk_iters)
498  {
499  sk_num_blocks = 0;
500  }
501 
502  // give a chance to control num of sk blocks
503  sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
504 
505  if(sk_num_blocks == 0)
506  {
507  sk_num_big_blocks = 0;
509 
510  dp_num_blocks = num_tiles; // all tile to be dp block
511  dp_start_block_idx = 0;
512  sk_total_iters = 0; // clear this tiles
513  }
514  else
515  {
516  // k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
517  // we need to decide how many iters for each sk block
518  // let m = k_iters_per_sk_block
519  // some of the sk block (little) will cover m iters, some (big) will cover m+1
520  // we have
521  // 1) l + b = sk_blocks
522  // 2) l * m + b * (m + 1) = sk_total_iters
523  // => (l + b) * m + b = sk_total_iters
524  // => sk_blocks * m + b = sk_total_iters
525  // => b = sk_total_iters - m * sk_blocks
526  // NOTE: big could be zero
527  const uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
528  sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
529  k_iters_per_big_block = k_iters_per_sk_block + 1;
530 
531  dp_num_blocks = dp_tiles;
532  dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
533  }
534  }
535  n_tiles = mdiv2(num_tile_n_);
537 
538  if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
539  {
540  const uint32_t upper_big = lcm(k_iters_per_big_block, k_iters_per_tile.get());
541  const uint32_t upper_little = lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
542  equiv_tiles_big = mdiv(upper_big / k_iters_per_tile.get());
543  equiv_tiles_little = mdiv(upper_little / k_iters_per_tile.get());
544  }
545  }
546 
550  CK_TILE_HOST auto GridSize() const noexcept -> dim3
551  {
552  if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
553  {
554  return dim3(reduction_start_block_idx + GetSkTiles(), 1, 1);
555  }
556  else
557  return dim3(reduction_start_block_idx, 1, 1);
558  }
559 
563  CK_TILE_HOST_DEVICE static auto GetLoopNum(uint32_t K) noexcept -> uint32_t
564  {
565  return integer_divide_ceil(K, KPerBlock); // Stream-K processes one K-slice at a time
566  }
567 
571  CK_TILE_DEVICE auto
573  {
574  uint32_t m_tile_idx, n_tile_idx;
575  n_tiles.divmod(tile_idx, num_tile_n_, m_tile_idx, n_tile_idx);
576 
577  // swizzle tile
578 
579  uint32_t tile_swizzle_sub_m_rem = num_tile_m_ % TileSwizzleSubM;
580 
581  const auto sub_m_adapt = (m_tile_idx < (num_tile_m_ - tile_swizzle_sub_m_rem))
582  ? TileSwizzleSubM
583  : tile_swizzle_sub_m_rem;
584 
585  uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
586  m_tile_idx_sub0 = m_tile_idx / TileSwizzleSubM;
587  m_tile_idx_sub1 = m_tile_idx % TileSwizzleSubM;
588 
589  uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * num_tile_n_;
590 
591  uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
592 
593  n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
594  m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
595  return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * TileSwizzleSubM,
596  n_tile_idx_with_adapt);
597  }
598 
602  CK_TILE_DEVICE void
603  GetBlockItr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const noexcept
604  {
605  if(block_idx < sk_num_big_blocks)
606  {
607  iter_start = block_idx * k_iters_per_big_block;
608  iter_end = iter_start + k_iters_per_big_block;
609  }
610  else if(block_idx < sk_num_blocks)
611  {
612  iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
613  (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
614  iter_end = iter_start + (k_iters_per_big_block - 1);
615  }
616  else if(block_idx >= dp_start_block_idx)
617  {
618  uint32_t sk_total_iters = GetSkTotalIters();
619  uint32_t dp_iters_per_block = k_iters_per_tile.get();
620  iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
621  iter_end = iter_start + dp_iters_per_block;
622  }
623  }
624 
629  {
632  return sk_total_iters;
633  }
634 
639  {
640  // tiles for sk
641  uint32_t sk_total_iters = GetSkTotalIters();
642  return k_iters_per_tile.div(sk_total_iters);
643  }
644 
649  uint32_t iter_end,
650  uint32_t total_iter_length) const noexcept
651  {
652  uint32_t iter_length_mod, iter_length_quo /*unused*/;
653  k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
654  uint32_t total_iter_length_val = static_cast<uint32_t>(total_iter_length);
655  uint32_t current_iter_length =
656  min(iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod,
657  total_iter_length_val);
658  return current_iter_length;
659  }
660 
665  {
666  return k_iters_per_tile.div(iter);
667  }
668 
672  CK_TILE_DEVICE void
673  GetTileIdxWithOffset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const noexcept
674  {
675  uint32_t tile_idx_val = static_cast<uint32_t>(tile_idx);
676  uint32_t iter_offset_val = static_cast<uint32_t>(iter_offset);
677  k_iters_per_tile.divmod(iter, tile_idx_val, iter_offset_val);
678  }
679 
684  {
685  static constexpr uint32_t alignment = 128;
686  uint32_t acc_buffer_bytes =
687  MPerBlock * NPerBlock * GetTotalAccBuffers() * acc_element_bytes;
688  return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
689  }
690 
695  {
696  return GetSkTiles() * sizeof(uint32_t);
697  }
698 
702  CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSize(uint32_t acc_element_bytes) const noexcept
703  {
704  return GetWorkSpaceSizeForAcc(acc_element_bytes) + GetWorkSpaceSizeForSemaphore();
705  }
706 
711  const mdiv& equiv_tiles_) const noexcept
712  {
713  uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
714  uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1;
715  uint32_t quo_, rem_;
716  equiv_tiles_.divmod(tile_idx_, quo_, rem_);
717  return quo_ * max_equiv_tiles_ + rem_;
718  }
719 
724  uint32_t iters_per_sk_block_) const noexcept
725  {
726  return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
727  1);
728  }
729 
734  {
735  uint32_t tiles_cover_big_blocks =
737  uint32_t tiles_cover_little_blocks =
739 
740  uint32_t total_intersec_big = GetTileIntersections(tiles_cover_big_blocks, equiv_tiles_big);
741  uint32_t total_intersec_little =
742  GetTileIntersections(tiles_cover_little_blocks, equiv_tiles_little);
743 
744  return sk_num_blocks + total_intersec_big + total_intersec_little;
745  }
746 
751  {
752  uint32_t tiles_cover_big_blocks =
754  if(tile_idx_ < tiles_cover_big_blocks)
755  {
756  uint32_t touched_sk_blocks =
757  (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
759  uint32_t current_intersec = GetTileIntersections(tile_idx_, equiv_tiles_big);
760  return touched_sk_blocks + current_intersec;
761  }
762  else
763  {
764  uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
765  uint32_t tile_idx_little_reverse = GetSkTiles() - tile_idx_;
766  uint32_t touched_sk_blocks =
767  (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
768  iters_per_little_sk_block;
769  uint32_t current_intersec =
770  GetTileIntersections(tile_idx_little_reverse, equiv_tiles_little);
771  return GetTotalAccBuffers() - (touched_sk_blocks + current_intersec);
772  }
773  }
774 
779  {
780  uint32_t iters_per_big_sk_block = k_iters_per_big_block;
781  uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
782  if(block_idx_ < sk_num_big_blocks)
783  {
784  uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
785  k_iters_per_tile.get() - 1);
786  uint32_t current_intersec = GetTileIntersections(touched_tiles, equiv_tiles_big);
787  return block_idx_ + current_intersec;
788  }
789  else
790  {
791  uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
792  uint32_t touched_tiles = k_iters_per_tile.div(
793  block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
794  uint32_t current_intersec = GetTileIntersections(touched_tiles, equiv_tiles_little);
795  return GetTotalAccBuffers() - (block_idx_little_reverse + current_intersec);
796  }
797  }
798 
799  // Getters for problem dimensions
800  CK_TILE_HOST_DEVICE uint32_t GetNumTileM() const noexcept { return num_tile_m_; }
801  CK_TILE_HOST_DEVICE uint32_t GetNumTileN() const noexcept { return num_tile_n_; }
802  CK_TILE_HOST_DEVICE uint32_t GetNumTileK() const noexcept { return num_tile_k_; }
803 
811  mdiv equiv_tiles_big; // for reduction
812  mdiv equiv_tiles_little; // for reduction
813 
814  private:
815  uint32_t M_, N_, K_;
816  uint32_t num_tile_m_, num_tile_n_, num_tile_k_;
817 };
818 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
StreamKReductionStrategy
Definition: streamk_common.hpp:10
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
Y constexpr CK_TILE_HOST_DEVICE auto lcm(X x, Y y)
Definition: math.hpp:314
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
@ Atomic
Definition: block_to_ctile_map.hpp:1011
@ Reduction
Definition: block_to_ctile_map.hpp:1012
bool_constant< false > false_type
Definition: integral_constant.hpp:63
bool_constant< true > true_type
Definition: integral_constant.hpp:62
unsigned int uint32_t
Definition: stdint.h:126
Class mapping 1D block index into 2D output tile space.
Definition: gemm_tile_partitioner.hpp:229
static constexpr index_t MPerBlock
Definition: gemm_tile_partitioner.hpp:232
static CK_TILE_HOST_DEVICE auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock !=0 &&NPerBlock !=0)) -> index_t
Calculates GEMM kernel grid size.
Definition: gemm_tile_partitioner.hpp:250
static constexpr index_t KPerBlock
Definition: gemm_tile_partitioner.hpp:234
static CK_TILE_HOST_DEVICE auto GetLoopNum(index_t K) noexcept -> index_t
Calculate number of loop iterations over GEMM's K dimension.
Definition: gemm_tile_partitioner.hpp:263
CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner() noexcept=delete
remove_cvref_t< BlockGemmShapeType > BlockGemmShape
Definition: gemm_tile_partitioner.hpp:230
static constexpr index_t NPerBlock
Definition: gemm_tile_partitioner.hpp:233
CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept -> const tuple< index_t, index_t >
Calculate workgroup 1D index mapping into 2D output C-tile space.
Definition: gemm_tile_partitioner.hpp:275
Class providing 1D WGP index mapping into 2D output C-tile space.
Definition: gemm_tile_partitioner.hpp:89
CK_TILE_HOST_DEVICE GemmTile1DPartitioner() noexcept=delete
static CK_TILE_HOST_DEVICE auto GetLoopNum(index_t K) noexcept -> index_t
Calculate number of loop iterations over GEMM's K dimension.
Definition: gemm_tile_partitioner.hpp:130
static CK_TILE_DEVICE auto GetOutputTileIndex(index_t blockIdx) noexcept -> const tuple< index_t, index_t >
Calculate workgroup 1D index mapping into 2D output C-tile space.
Definition: gemm_tile_partitioner.hpp:142
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_tile_partitioner.hpp:90
static constexpr index_t MPerBlock
Definition: gemm_tile_partitioner.hpp:92
static constexpr index_t NPerBlock
Definition: gemm_tile_partitioner.hpp:93
static CK_TILE_HOST_DEVICE auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock !=0 &&NPerBlock !=0)) -> index_t
Calculates GEMM kernel grid size.
Definition: gemm_tile_partitioner.hpp:117
static constexpr index_t KPerBlock
Definition: gemm_tile_partitioner.hpp:94
Class providing 2D workgroup index mapping into 2D output GEMM C-tile space.
Definition: gemm_tile_partitioner.hpp:22
static CK_TILE_DEVICE auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept -> const tuple< index_t, index_t >
The function returns 2D output tile space.
Definition: gemm_tile_partitioner.hpp:74
static CK_TILE_HOST auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock !=0 &&NPerBlock !=0)) -> dim3
Calculates GEMM kernel grid size.
Definition: gemm_tile_partitioner.hpp:41
remove_cvref_t< BlockGemmShapeType > BlockGemmShape
Definition: gemm_tile_partitioner.hpp:23
static CK_TILE_HOST_DEVICE auto GetLoopNum(index_t K) noexcept -> index_t
Calculate number of loop iterations over GEMM's K dimension.
Definition: gemm_tile_partitioner.hpp:54
static constexpr index_t NPerBlock
Definition: gemm_tile_partitioner.hpp:26
static constexpr index_t KPerBlock
Definition: gemm_tile_partitioner.hpp:27
static constexpr index_t MPerBlock
Definition: gemm_tile_partitioner.hpp:25
CK_TILE_HOST_DEVICE GemmTile2DPartitioner() noexcept=delete
GemmTile1DPartitioner::GetOutputTileIndex's std::false specialization, checking expression validity i...
Definition: gemm_tile_partitioner.hpp:161
Struct used to calculate offseted tile indexes.
Definition: gemm_tile_partitioner.hpp:184
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from 1D raw-indexes.
Definition: gemm_tile_partitioner.hpp:192
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N, index_t block_idx) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from a given block index.
Definition: gemm_tile_partitioner.hpp:208
Stream-K tile partitioner that dynamically balances work across workgroups.
Definition: gemm_tile_partitioner.hpp:388
CK_TILE_HOST_DEVICE uint32_t GetTileIntersections(uint32_t tiles_, const mdiv &equiv_tiles_) const noexcept
Get location of intersection of tiles for reduction.
Definition: gemm_tile_partitioner.hpp:710
CK_TILE_HOST_DEVICE uint32_t GetNumTileK() const noexcept
Definition: gemm_tile_partitioner.hpp:802
uint32_t k_iters_per_big_block
Definition: gemm_tile_partitioner.hpp:808
CK_TILE_HOST_DEVICE uint32_t GetSkTotalIters() const noexcept
Get total number of iterations for sk tiles.
Definition: gemm_tile_partitioner.hpp:628
CK_TILE_HOST_DEVICE StreamKTilePartitioner() noexcept=delete
static constexpr uint32_t MPerBlock
Definition: gemm_tile_partitioner.hpp:391
CK_TILE_HOST_DEVICE uint32_t GetNumTileM() const noexcept
Definition: gemm_tile_partitioner.hpp:800
CK_TILE_DEVICE uint32_t GetAccBufferOffsetFromBlock(uint32_t block_idx_) const noexcept
Calculate offset based on block_idx index for big/little streamk blocks.
Definition: gemm_tile_partitioner.hpp:778
CK_TILE_DEVICE void GetTileIdxWithOffset(uint32_t iter, uint32_t &tile_idx, uint32_t &iter_offset) const noexcept
Get index of tile during a specified iteration.
Definition: gemm_tile_partitioner.hpp:673
uint32_t sk_num_blocks
Definition: gemm_tile_partitioner.hpp:804
mdiv equiv_tiles_little
Definition: gemm_tile_partitioner.hpp:812
CK_TILE_DEVICE uint32_t GetAccBufferOffsetFromTile(uint32_t tile_idx_) const noexcept
Calculate offset based on tile index for big/little tiles.
Definition: gemm_tile_partitioner.hpp:750
mdiv2 n_tiles
Definition: gemm_tile_partitioner.hpp:809
CK_TILE_HOST_DEVICE uint32_t GetNumTileN() const noexcept
Definition: gemm_tile_partitioner.hpp:801
CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSize(uint32_t acc_element_bytes) const noexcept
Calculates the total buffer space needed for accumulation and the semaphore.
Definition: gemm_tile_partitioner.hpp:702
static constexpr uint32_t NPerBlock
Definition: gemm_tile_partitioner.hpp:392
static constexpr uint32_t KPerBlock
Definition: gemm_tile_partitioner.hpp:393
CK_TILE_HOST_DEVICE uint32_t GetTilesCoverSkBlock(uint32_t num_sk_blocks_, uint32_t iters_per_sk_block_) const noexcept
Calculate the number of tiles needed for the number of sk blocks.
Definition: gemm_tile_partitioner.hpp:723
static CK_TILE_HOST_DEVICE auto GetLoopNum(uint32_t K) noexcept -> uint32_t
Calculate number of loop iterations over K dimension for given work unit.
Definition: gemm_tile_partitioner.hpp:563
mdiv equiv_tiles_big
Definition: gemm_tile_partitioner.hpp:811
CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSizeForSemaphore() const noexcept
Calculates the buffer space needed for the semaphore.
Definition: gemm_tile_partitioner.hpp:694
CK_TILE_HOST auto GridSize() const noexcept -> dim3
Calculate optimal grid size for Stream-K.
Definition: gemm_tile_partitioner.hpp:550
CK_TILE_HOST_DEVICE uint32_t GetSkTiles() const noexcept
Get total number of sk tiles.
Definition: gemm_tile_partitioner.hpp:638
CK_TILE_DEVICE auto GetOutputTileIndex(uint32_t tile_idx) const noexcept -> tuple< uint32_t, uint32_t >
Get output tile index for standard 2D mapping (compatibility)
Definition: gemm_tile_partitioner.hpp:572
uint32_t sk_num_big_blocks
Definition: gemm_tile_partitioner.hpp:805
uint32_t dp_start_block_idx
Definition: gemm_tile_partitioner.hpp:806
CK_TILE_DEVICE void GetBlockItr(uint32_t block_idx, uint32_t &iter_start, uint32_t &iter_end) const noexcept
Get work range for a given block ID.
Definition: gemm_tile_partitioner.hpp:603
CK_TILE_DEVICE uint32_t GetCurrentIterLength(uint32_t iter_start, uint32_t iter_end, uint32_t total_iter_length) const noexcept
Get length of loop iterations for stream-k loop.
Definition: gemm_tile_partitioner.hpp:648
mdiv k_iters_per_tile
Definition: gemm_tile_partitioner.hpp:810
CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSizeForAcc(uint32_t acc_element_bytes) const noexcept
Calculates the buffer space needed for accumulation.
Definition: gemm_tile_partitioner.hpp:683
CK_TILE_HOST_DEVICE uint32_t GetTotalAccBuffers() const noexcept
Calculate the amount of total accumulation buffers required for stream-k.
Definition: gemm_tile_partitioner.hpp:733
BlockGemmShapeType BlockGemmShape
Definition: gemm_tile_partitioner.hpp:389
CK_TILE_DEVICE uint32_t GetTileIdx(uint32_t iter) const noexcept
Get index of tile during a specified iteration.
Definition: gemm_tile_partitioner.hpp:664
uint32_t reduction_start_block_idx
Definition: gemm_tile_partitioner.hpp:807
Definition: magic_div.hpp:228
CK_TILE_HOST_DEVICE void divmod(uint32_t dividend_, uint32_t divisor_, uint32_t &quotient_, uint32_t &remainder_) const
Definition: magic_div.hpp:250
Definition: magic_div.hpp:186
CK_TILE_HOST_DEVICE uint32_t get() const
Definition: magic_div.hpp:224
CK_TILE_HOST_DEVICE void divmod(uint32_t dividend_, uint32_t &quotient_, uint32_t &remainder_) const
Definition: magic_div.hpp:218
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
Definition: magic_div.hpp:212
Definition: tuple.hpp:192