// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/math.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
// utilities for MDiv/MDiv2 and the tensor adaptors used below (assumed include paths)
#include "ck/utility/magic_division.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
#include <limits>
#include <stdlib.h>
#endif

namespace ck {

// Rows of column-vectors
template <index_t MPerBlock,
          index_t NPerBlock,
          typename CGridDesc_M_N,
          bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_M00_N0_M01
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01() = default;

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
                                                             index_t M01 = 1)
        : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01))
    {
    }

    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01_);

        const index_t grid_size = M00 * M01_ * N0;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        return underlying_map_.CalculateBottomIndex(idx_top);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                                       const CTileDim& c_tile_dim) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
        else
            return true;
    }

    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return true; // validity check moved to kernel

        const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        if(M0 % M01_ == 0)
        {
            return true;
        }
        else
        {
            return false;
        }
    }

    private:
    __host__ __device__ static constexpr auto
    GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01)
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01);

        const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
                       make_unmerge_transform(make_tuple(M00, M01)),
                       make_pass_through_transform(N0)),
            make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));

        const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(1, M00, N0, M01))),
            make_tuple(Sequence<0, 1, 2, 3>{}),
            make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_n0_m01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    index_t M01_;
    using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1));
    UnderlyingMap underlying_map_;
};
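
// Usage sketch (illustrative, not part of the original header): how a GEMM driver might
// employ this map. `CGridDescriptor_M_N` and `c_grid_desc_m_n` are hypothetical; any 2D
// descriptor whose GetLength(I0)/GetLength(I1) return M and N would do.
#if 0
using CTileMap = BlockToCTileMap_M00_N0_M01<128, 128, CGridDescriptor_M_N>;

const auto block_2_ctile_map = CTileMap(c_grid_desc_m_n, /*M01=*/4);
const index_t grid_size      = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n);

// inside the kernel: recover the (m0, n0) C-tile coordinate of this workgroup
const auto c_tile_idx =
    block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
#endif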

// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_M00_N0_M01Adapt;

template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
        const BlockToCTileMap_M00_N0_M01Adapt&) = default;
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
        BlockToCTileMap_M00_N0_M01Adapt&&) = default;
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
    operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
    operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M,
                                                                  index_t N,
                                                                  index_t M01 = 8)
        : M_(M), N_(N), M01_(M01)
    {
#if 0
        if(get_thread_global_1d_id() == 0)
        {
            printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
        }
#endif
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
        const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
        : BlockToCTileMap_M00_N0_M01Adapt(
              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
    {
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    template <typename CGridDesc_M_N>
    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
    }

    template <typename CGridDesc_M_N>
    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

        index_t idx_N0 = block_1d_id % N0;
        index_t idx_M0 = block_1d_id / N0;

        // shrink M01 for the tail row-panel so the swizzle stays within range
        const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

        index_t idx_M00          = idx_M0 / M01_;
        index_t idx_M01          = idx_M0 % M01_;
        index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        // remap so consecutive blocks walk down a column of M01 C-tiles before
        // stepping to the next column (better data reuse across the panel)
        return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                          idx_N0_M01_local / M01_adapt);
    }
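
    // Worked example (added for clarity): M0 = 5, N0 = 4, M01_ = 2. Blocks 8..15 cover
    // the row-panel idx_M0 in {2, 3}; M01_adapt stays 2 since 2, 3 < M0 - M0 % M01_ = 4.
    //   block  8 -> (m0, n0) = (2, 0)      block  9 -> (3, 0)
    //   block 10 -> (2, 1)                 block 11 -> (3, 1)
    //   block 12 -> (2, 2)                 block 13 -> (3, 2)
    // i.e. consecutive blocks zig-zag down columns of M01 = 2 tiles; the last (shorter)
    // panel falls back to M01_adapt = M0 % M01_ = 1.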

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                                       const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    private:
    index_t M_;
    index_t N_;
    index_t M01_;
};

// keep the redundant type argument for backward compatibility
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt
    : BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{
    using BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>::
        BlockToCTileMap_M00_N0_M01Adapt;
};

// Grouped rows of column-vectors WGP mapping
// Optimized for gfx94x-like multiple-die chips

template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;
    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M,
                                                                index_t N,
                                                                index_t M01 = 8)
        : M_(M), N_(N), M01_(M01)
    {
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    template <typename CGridDesc_M_N>
    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);

        if(M0 == 1)
        {
            return make_tuple(0, block_1d_id);
        }
        else if(N0 == 1)
        {
            return make_tuple(block_1d_id, 0);
        }
        // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
        else
        {
            // spread consecutive block ids across GroupNum groups, then apply the usual
            // M01 column-vector swizzle to the remapped id
            const auto group_size    = math::integer_divide_ceil(M0 * N0, GroupNum);
            const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
            auto group_id_x          = block_1d_id % GroupNum;
            auto group_id_y          = block_1d_id / GroupNum;
            auto remap_block_1d_id =
                group_id_x <= big_group_num
                    ? group_id_x * group_size + group_id_y
                    : group_id_x * group_size + big_group_num - group_id_x + group_id_y;

            index_t idx_N0 = remap_block_1d_id % N0;
            index_t idx_M0 = remap_block_1d_id / N0;

            const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

            index_t idx_M00          = idx_M0 / M01_;
            index_t idx_M01          = idx_M0 % M01_;
            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

            return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                              idx_N0_M01_local / M01_adapt);
        }
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    private:
    index_t M_;
    index_t N_;
    index_t M01_;
};
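
// Worked example (added for clarity): M0 = N0 = 4 (16 tiles), GroupNum = 8 gives
// group_size = 2 and big_group_num = 8, so blocks 0..7 land in distinct groups:
// block 5 -> group_id_x = 5, group_id_y = 0 -> remap_block_1d_id = 5 * 2 + 0 = 10.
// Consecutive block ids are thus spread across groups (e.g. across dies) before the
// M01 swizzle is applied to the remapped id.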

// columns of row-vectors
// This C-tile map dynamically adjusts N01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_N00_M0_N01Adapt;

template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt() = default;

    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt&) =
        default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt&&) =
        default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
    operator=(const BlockToCTileMap_N00_M0_N01Adapt&) = default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
    operator=(BlockToCTileMap_N00_M0_N01Adapt&&) = default;

    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01 = 8)
        : M_(M), N_(N), N01_(N01)
    {
#if 0
        if(get_thread_global_1d_id() == 0)
        {
            printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_);
        }
#endif
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
                                                        index_t N01 = 8)
        : BlockToCTileMap_N00_M0_N01Adapt(
              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), N01)
    {
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    template <typename CGridDesc_M_N>
    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
    }

    template <typename CGridDesc_M_N>
    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

        index_t idx_M0 = block_1d_id % M0;
        index_t idx_N0 = block_1d_id / M0;

        // shrink N01 for the tail column-panel so the swizzle stays within range
        const auto N01_adapt = (idx_N0 < N0 - N0 % N01_) ? N01_ : N0 % N01_;

        index_t idx_N00          = idx_N0 / N01_;
        index_t idx_N01          = idx_N0 % N01_;
        index_t idx_M0_N01_local = idx_M0 + idx_N01 * M0;

        // mirror of BlockToCTileMap_M00_N0_M01Adapt with the roles of M and N swapped
        return make_tuple(idx_M0_N01_local / N01_adapt,
                          idx_M0_N01_local % N01_adapt + idx_N00 * N01_);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    private:
    index_t M_;
    index_t N_;
    index_t N01_;
};

// 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt() = default;

    __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
                                                               index_t M01    = 8,
                                                               index_t KSplit = 1)
        : M01_(M01), KSplit_(KSplit), c_grid_desc_m_n_(c_grid_desc_m_n)
    {
    }

    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const index_t grid_size = M0 * N0 * KSplit_;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0 * KSplit_); // hide groups

        const index_t idx_ksplit = block_1d_id / (M0 * N0);
        block_1d_id              = block_1d_id % (M0 * N0);

        index_t idx_N0 = block_1d_id % N0;
        index_t idx_M0 = block_1d_id / N0;

        const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

        index_t idx_M00          = idx_M0 / M01_;
        index_t idx_M01          = idx_M0 % M01_;
        index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        return make_tuple(idx_ksplit,
                          idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                          idx_N0_M01_local / M01_adapt);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    private:
    index_t M01_;
    index_t KSplit_;
    CGridDesc_M_N c_grid_desc_m_n_;
};
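
// Worked example (added for clarity): M0 = N0 = 4, KSplit = 2 -> grid size 32.
// Block 19: idx_ksplit = 19 / 16 = 1, remainder 3 -> idx_N0 = 3, idx_M0 = 0; with
// M01_ = 8 the whole panel uses M01_adapt = M0 % M01_ = 4, so the map returns
// (ksplit, m0, n0) = (1, 3, 0).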

// Blocks of row-vectors
template <index_t MPerBlock,
          index_t NPerBlock,
          typename CGridDesc_M_N,
          bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_M00_N00_M01_N01
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ BlockToCTileMap_M00_N00_M01_N01() = default;

    __host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n,
                                                        index_t M01 = 1,
                                                        index_t N01 = 1)
        : M01_(M01), N01_(N01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01))
    {
    }

    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01_);
        const auto N00 = math::integer_divide_ceil(N0, N01_);

        const index_t grid_size = M00 * M01_ * N00 * N01_;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        return underlying_map_.CalculateBottomIndex(idx_top);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                             const CTileDim& c_tile_dim) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
        else
            return true;
    }

    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return true; // validity check moved to kernel

        const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
        if(M0 % M01_ == 0 && N0 % N01_ == 0)
        {
            return true;
        }
        else
        {
            return false;
        }
    }

    private:
    __host__ __device__ static constexpr auto
    GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01);
        const auto N00 = math::integer_divide_ceil(N0, N01);

        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(1, M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    index_t M01_, N01_;
    using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1));
    UnderlyingMap underlying_map_;
};
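
// Note (added for clarity): because the merge order is (1, M00, N00, M01, N01),
// consecutive block ids vary N01 fastest, then M01; each group of M01 * N01 blocks
// therefore sweeps an M01 x N01 rectangle of C tiles before moving to the next one.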

// 2D slices of row-vectors in 3D space
template <index_t MPerBlock,
          index_t NPerBlock,
          typename CGridDesc_M_N,
          bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_KSplit_M00_N00_M01_N01
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ BlockToCTileMap_KSplit_M00_N00_M01_N01() = default;

    __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n,
                                                    index_t M01    = 1,
                                                    index_t N01    = 1,
                                                    index_t KSplit = 1)
        : c_grid_desc_m_n_(c_grid_desc_m_n),
          M01_(M01),
          N01_(N01),
          KSplit_(KSplit),
          underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit))
    {
    }

    __host__ __device__ constexpr index_t
    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01_);
        const auto N00 = math::integer_divide_ceil(N0, N01_);

        const index_t grid_size = M00 * M01_ * N00 * N01_ * KSplit_;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        static_assert(TopIdx::Size() == 1);

        return underlying_map_.CalculateBottomIndex(
            make_multi_index(idx_top[I0] % CalculateGridSize()));
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                             const CTileDim& c_tile_dim) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
        else
            return true;
    }

    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return true; // validity check moved to kernel

        const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
        if(M0 % M01_ == 0 && N0 % N01_ == 0)
        {
            return true;
        }
        else
        {
            return false;
        }
    }

    private:
    __device__ constexpr index_t CalculateGridSize() const
    {
        return CalculateGridSize(c_grid_desc_m_n_);
    }

    __host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
                                                      index_t M01,
                                                      index_t N01,
                                                      index_t KSplit)
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01);
        const auto N00 = math::integer_divide_ceil(N0, N01);

        const auto ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_pass_through_transform(KSplit),
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(KSplit, M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor);

        return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor;
    }

    CGridDesc_M_N c_grid_desc_m_n_;
    index_t M01_, N01_, KSplit_;
    using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1));
    UnderlyingMap underlying_map_;
};

template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx,
                                                const CTileDim& c_tile_dim)
{
    bool is_valid = false;

    const index_t m_block = c_tile_dim[Number<0>{}];
    const index_t n_block = c_tile_dim[Number<1>{}];

    if constexpr(CTileIdx::Size() == 2)
    {
        const index_t m_block_idx = c_tile_idx[Number<0>{}];
        const index_t n_block_idx = c_tile_idx[Number<1>{}];
        if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block)
        {
            is_valid = true;
        }
    }
    else if constexpr(CTileIdx::Size() == 3)
    {
        const index_t ksplit_idx  = c_tile_idx[Number<0>{}];
        const index_t m_block_idx = c_tile_idx[Number<1>{}];
        const index_t n_block_idx = c_tile_idx[Number<2>{}];
        if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block)
        {
            is_valid = true;
        }
        ignore = ksplit_idx;
    }

    return is_valid;
}

// This wrapper class is for grouped gemm: it subtracts blockIdx by a value so that the
// workgroups assigned to a given gemm problem have their top index offset into the range
// [0, grid_size_per_gemm)
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMap
{
    using underlying_type = UnderlyingBlockToCTileMap;

    __host__ __device__ OffsettedBlockToCTileMap() = default;
    __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
                                                 index_t block_start)
    {
        block_to_ctile_map_ = block_to_ctile_map;
        block_start_        = block_start;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        return block_to_ctile_map_.CalculateBottomIndex(
            make_multi_index(idx_top[Number<0>{}] - block_start_));
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                             const CTileDim& c_tile_dim) const
    {
        return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
    }

    template <typename CGridDesc_M_N>
    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
    }

    template <typename CGridDesc_M_N>
    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
    }

    __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
    {
        return block_to_ctile_map_.CalculateGridSize(M, N);
    }

    UnderlyingBlockToCTileMap block_to_ctile_map_;
    index_t block_start_;
};
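
// Grouped-gemm usage sketch (illustrative only; `M`, `N` and `gemm_block_start` are
// hypothetical values supplied by the grouped-gemm driver):
#if 0
using UnderlyingMap = BlockToCTileMap_M00_N0_M01Adapt<128, 128>;

const auto map =
    OffsettedBlockToCTileMap<UnderlyingMap>(UnderlyingMap(M, N), gemm_block_start);

// a workgroup whose global blockIdx.x lies in [gemm_block_start, gemm_block_start +
// grid_size_per_gemm) sees a local top index in [0, grid_size_per_gemm)
const auto c_tile_idx = map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
#endif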
// second version with 2 offsets
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMap2
{
    using underlying_type = UnderlyingBlockToCTileMap;

    __host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map,
                                                  index_t group_offset,
                                                  index_t tile_offset)
        : block_to_ctile_map_{block_to_ctile_map},
          group_offset_{group_offset},
          tile_offset_{tile_offset}
    {
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        return block_to_ctile_map_.CalculateBottomIndex(
            make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                             const CTileDim& c_tile_dim) const
    {
        return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
    }

    template <typename CGridDesc_M_N>
    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
    }

    __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
    {
        return block_to_ctile_map_.CalculateGridSize(M, N);
    }

    __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
    UnderlyingBlockToCTileMap block_to_ctile_map_;
    index_t group_offset_;
    index_t tile_offset_;
};

/**
 * Simple tile mapping which creates 3D grid of block of threads.
 */
template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_3DGrid_KSplit
{

    __host__ __device__ BlockToCTileMap_3DGrid_KSplit() = default;

    __host__ __device__ constexpr auto
    CalculateGridSize(index_t M, index_t N, index_t k_split) const
    {
        // Create 3D grid
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);
        return make_tuple(N0, M0, k_split);
    }

    template <typename TopIdx>
    __device__ constexpr auto CalculateBottomIndex(const TopIdx&) const
    {
        return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    template <typename CGridDesc_M_N>
    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }
};
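
// Note (added for clarity): CalculateGridSize() returns (N0, M0, k_split), i.e. a launch
// of dim3(N0, M0, k_split), so CalculateBottomIndex() hands back
// (blockIdx.z, blockIdx.y, blockIdx.x) = (k-slice, m0, n0) with no index arithmetic.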

enum StreamKReductionStrategy
{
    Atomic = 0, // sk blocks use atomics to do the reduction
    Reduction,  // dedicated workgroups are responsible for the reduction operation
};

template <uint32_t MPerBlock_,
          uint32_t NPerBlock_,
          uint32_t KPerBlock_,
          StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
          uint32_t TileSwizzleSubM_                   = 8>
struct BlockToCTileMap_GemmStreamK
{
    static constexpr uint32_t min_k_iters_per_sk_block = 2;
    static constexpr uint32_t MPerBlock                = MPerBlock_;
    static constexpr uint32_t NPerBlock                = NPerBlock_;
    static constexpr uint32_t KPerBlock                = KPerBlock_;
    static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
    static constexpr uint32_t tile_swizzle_sub_m                = TileSwizzleSubM_;

    //--------------------------------------
    // pass to device
    uint32_t sk_num_blocks;
    uint32_t sk_num_big_blocks;
    uint32_t dp_start_block_idx;
    uint32_t reduction_start_block_idx;
    uint32_t k_iters_per_big_block;
    MDiv2 n_tiles;
    MDiv k_iters_per_tile;
    MDiv eqav_tiles_big;    // for reduction
    MDiv eqav_tiles_little; // for reduction

    // MDiv tile_swizzle_sub_m_rem;
    //--------------------------------------

    // prefer construct on host
    __host__ BlockToCTileMap_GemmStreamK(uint32_t m,
                                         uint32_t n,
                                         uint32_t k,
                                         uint32_t num_cu,
                                         uint32_t occupancy,
                                         uint32_t sk_blocks = 0xffffffff)
    {
        uint32_t num_tiles =
            math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
        k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));

        // One cu can hold one wg at a time, from the whole chip's point of view.
        // If the number of wg equals num_cu, we call it 1 dispatch;
        // if the number of wg is 2x num_cu, we call it 2 dispatches.
        // One dispatch can deliver as many wg as num_cu (full dispatch), or fewer than
        // num_cu (partial dispatch).
        //
        uint32_t full_dispatches         = num_tiles / num_cu;
        uint32_t full_dispatch_tiles     = full_dispatches * num_cu;
        uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;

        uint32_t sk_occupancy = occupancy;
        uint32_t dp_tiles     = full_dispatch_tiles;
        uint32_t sk_tiles     = partial_dispatche_tiles;

        if(full_dispatches < occupancy)
        {
            // in this case, we allocate all blocks as sk blocks
            // sk_occupancy = occupancy - full_dispatches;
            sk_occupancy = 1; // TODO: single occ seems better
            dp_tiles     = full_dispatch_tiles;
            sk_tiles     = partial_dispatche_tiles;
        }
        else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
        {
            // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
            //      occupancy = 3, full_dispatches = 5, 8, 11 ...
            //      occupancy = 4, full_dispatches = 7, 11 ...
            sk_occupancy = 1; // leave 1 slot for sk occupancy
            dp_tiles     = full_dispatch_tiles;
            sk_tiles     = partial_dispatche_tiles;
        }
        else
        {
            // otherwise, we take 1 dispatch away from dp and, together with the partial
            // dispatch, construct the sk dispatch
            sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
            dp_tiles     = full_dispatch_tiles - num_cu;
            sk_tiles     = partial_dispatche_tiles + num_cu;
        }

        // uint32_t dp_iters_per_block = k_iters_per_tile.get();
        uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
        uint32_t dp_num_blocks  = 0;

        {
            uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
            uint32_t max_sk_tiles =
                (sk_tiles >= num_cu) ? num_cu * sk_occupancy
                                     : math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);

            // if we used dp for the sk blocks, how many iters would we need
            uint32_t dp_for_sk_iters = k_iters_per_tile.get();

            uint32_t best_sk_score =
                NumericLimits<int32_t>::Max(); // we need to find the smallest sk iters
            for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
                tentative_sk_blocks++)
            {
                uint32_t tentative_sk_iters_per_block =
                    (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
                uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
                uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;

                // TODO: carefully adjust this parameter
                //       the more sk_blocks_per_tile, the worse the overhead
                uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
                if(tentative_sk_blocks % sk_tiles != 0)
                {
                    // penalty for uneven divide
                    cross_sk_blocks_overhead +=
                        sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
                }

                uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;

                if(tentative_sk_score < best_sk_score)
                {
                    best_sk_score = tentative_sk_score;
                    sk_num_blocks = tentative_sk_blocks;
                }
            }

            if(best_sk_score >= dp_for_sk_iters)
            {
                sk_num_blocks = 0;
            }

            // give a chance to control the number of sk blocks
            sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;

            if(sk_num_blocks == 0)
            {
                sk_num_big_blocks     = 0;
                k_iters_per_big_block = 0;

                dp_num_blocks      = num_tiles; // all tiles become dp blocks
                dp_start_block_idx = 0;
                sk_total_iters     = 0; // clear the sk iters
            }
            else
            {
                // k_iters_per_sk_block is the floor of the average number of iters each sk
                // block loops over. We need to decide how many iters each sk block gets:
                // let m = k_iters_per_sk_block;
                // some of the sk blocks (little) will cover m iters, some (big) will cover m+1.
                // We have
                //  1) l + b = sk_blocks
                //  2) l * m + b * (m + 1) = sk_total_iters
                //     => (l + b) * m + b = sk_total_iters
                //     => sk_blocks * m + b = sk_total_iters
                //     => b = sk_total_iters - m * sk_blocks
                // NOTE: big could be zero
                uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
                sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
                k_iters_per_big_block = k_iters_per_sk_block + 1;

                dp_num_blocks      = dp_tiles;
                dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
            }
        }
        n_tiles                   = MDiv2(math::integer_divide_ceil(n, NPerBlock));
        reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;

        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            uint32_t upper_big    = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
            uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
            eqav_tiles_big        = MDiv(upper_big / k_iters_per_tile.get());
            eqav_tiles_little     = MDiv(upper_little / k_iters_per_tile.get());
        }

#if 0
        printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
               "sk_num_blocks:%d, "
               "sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
               "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
               "sk_tiles:%u, workspace(acc float):%u\n",
               num_cu,
               occupancy,
               get_grid_dims().x,
               num_tiles,
               dp_tiles,
               sk_num_big_blocks,
               sk_num_blocks,
               sk_total_iters,
               dp_start_block_idx,
               dp_iters_per_block,
               dp_num_blocks,
               k_iters_per_tile.get(),
               k_iters_per_big_block,
               reduction_start_block_idx,
               get_sk_tiles(),
               get_workspace_size(sizeof(float)));
#endif
    }

    __host__ __device__ uint32_t get_sk_total_iters() const
    {
        uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
                                  (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
        return sk_total_iters;
    }

    __host__ __device__ uint32_t get_sk_tiles() const
    {
        // tiles for sk
        uint32_t sk_total_iters = get_sk_total_iters();
        return k_iters_per_tile.div(sk_total_iters);
    }

    __host__ __device__ dim3 get_grid_dims() const
    {
        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
        }
        else
            return dim3(reduction_start_block_idx, 1, 1);
    }

    __device__ uint32_t get_block_idx() const
    {
        // TODO: swizzle block index for better locality
        return __builtin_amdgcn_readfirstlane(blockIdx.x);
    }

    __device__ void
    get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
    {
        if(block_idx < sk_num_big_blocks)
        {
            iter_start = block_idx * k_iters_per_big_block;
            iter_end   = iter_start + k_iters_per_big_block;
        }
        else if(block_idx < sk_num_blocks)
        {
            iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
                         (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
            iter_end = iter_start + (k_iters_per_big_block - 1);
        }
        else if(block_idx >= dp_start_block_idx)
        {
            uint32_t sk_total_iters     = get_sk_total_iters();
            uint32_t dp_iters_per_block = k_iters_per_tile.get();
            iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
            iter_end   = iter_start + dp_iters_per_block;
        }
    }

    __device__ uint32_t get_current_iter_length(uint32_t iter_start,
                                                uint32_t iter_end,
                                                uint32_t total_iter_length) const
    {
        uint32_t iter_length_mod, iter_length_quo /*unused*/;
        k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
        uint32_t current_iter_length = math::min(
            iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
        return current_iter_length;
    }
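
    // Worked example (added for clarity): k_iters_per_tile = 8, iter range [5, 19).
    // iter_end % 8 = 3, so the first call returns 3: the last 3 iters (16..18) belong to
    // the tile containing iter_end - 1. A block thus peels its iter range off the back,
    // tile by tile.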

    __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }

    __device__ void
    get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
    {
        k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
    }

    __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
    {
        uint32_t m_tile_idx, n_tile_idx;
        uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
        n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);

        // swizzle tile
        uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);

        uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;

        const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
                                     ? tile_swizzle_sub_m
                                     : tile_swizzle_sub_m_rem;

        uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
        m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
        m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;

        uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;

        uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;

        n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
        m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
        return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
                          n_tile_idx_with_adapt);
    }

    __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
    {
        static constexpr uint32_t alignment = 128;
        uint32_t acc_buffer_bytes =
            MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
        return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
    }

    __host__ __device__ uint32_t get_workspace_size_for_semaphore() const
    {
        return get_sk_tiles() * sizeof(uint32_t);
    }

    __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
    {
        return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
    }

    __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
                                                        const MDiv& eqav_tiles_) const
    {
        uint32_t tile_idx_       = tiles_ == 0 ? 0 : (tiles_ - 1);
        uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
        uint32_t quo_, rem_;
        eqav_tiles_.divmod(tile_idx_, quo_, rem_);
        return quo_ * max_eqav_tiles_ + rem_;
    }

    __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
                                                          uint32_t iters_per_sk_block_) const
    {
        return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
                                    1);
    }

    __host__ __device__ uint32_t get_total_acc_buffers() const
    {
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        uint32_t tiles_cover_little_blocks = get_tiles_cover_sk_block(
            sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);

        uint32_t total_intersec_big =
            get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big);
        uint32_t total_intersec_little =
            get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little);

        return sk_num_blocks + total_intersec_big + total_intersec_little;
    }

    __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
    {
        // TODO: from big to little
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        if(tile_idx_ < tiles_cover_big_blocks)
        {
            uint32_t touched_sk_blocks =
                (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
                k_iters_per_big_block;
            uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big);
            return touched_sk_blocks + current_intersec;
        }
        else
        {
            uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
            uint32_t tile_idx_little_reverse   = get_sk_tiles() - tile_idx_;
            uint32_t touched_sk_blocks =
                (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
                iters_per_little_sk_block;
            uint32_t current_intersec =
                get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little);
            return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
        }
    }

    __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
    {
        uint32_t iters_per_big_sk_block    = k_iters_per_big_block;
        uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
        if(block_idx_ < sk_num_big_blocks)
        {
            uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
                                                          k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big);
            return block_idx_ + current_intersec;
        }
        else
        {
            uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
            uint32_t touched_tiles            = k_iters_per_tile.div(
                block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little);
            return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
        }
    }
};
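
// Worked example (added for clarity) of the big/little split above:
// k_iters_per_tile = 8, sk_tiles = 3 -> sk_total_iters = 24; with sk_num_blocks = 5,
// m = 24 / 5 = 4, so b = sk_num_big_blocks = 24 - 4 * 5 = 4 big blocks doing
// m + 1 = 5 iters each and one little block doing 4 iters: 4 * 5 + 1 * 4 = 24.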

template <uint32_t MPerBlock_,
          uint32_t NPerBlock_,
          uint32_t KPerBlock_,
          StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
          uint32_t TileSwizzleSubM_                   = 8,
          index_t GroupNum                            = 8,
          index_t M01_                                = 4>
struct BlockToCTileMap_GemmStreamK_v2
{
    static constexpr uint32_t min_k_iters_per_sk_block = 2;
    static constexpr uint32_t MPerBlock                = MPerBlock_;
    static constexpr uint32_t NPerBlock                = NPerBlock_;
    static constexpr uint32_t KPerBlock                = KPerBlock_;
    static constexpr uint32_t tile_swizzle_sub_m       = TileSwizzleSubM_;

    //--------------------------------------
    // pass to device
    mutable uint32_t sk_num_blocks;
    uint32_t sk_num_big_blocks;
    uint32_t dp_start_block_idx;
    uint32_t reduction_start_block_idx;
    uint32_t k_iters_per_big_block;
    MDiv2 n_tiles;
    MDiv k_iters_per_tile;
    MDiv equiv_tiles_big;    // for reduction
    MDiv equiv_tiles_little; // for reduction
    StreamKReductionStrategy reduction_strategy;

    // prefer construct on host
    __host__ __device__ BlockToCTileMap_GemmStreamK_v2(
        uint32_t m,
        uint32_t n,
        uint32_t k,
        uint32_t grid_size   = 1,
        uint32_t streamk_sel = 1,
        StreamKReductionStrategy reduction_strategy_ = StreamKReductionStrategy::Atomic)
        : reduction_strategy(reduction_strategy_)
    {

        // total output tiles
        uint32_t num_tiles =
            math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
        k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));

        uint32_t dp_tiles, dp_num_blocks, sk_total_iters;

        // Ensure grid_size is at least 1 to avoid division by zero
        grid_size = math::max(grid_size, 1u);

        // default to regular DP GEMM if sk blocks == 0
        if(streamk_sel == 0)
        {
            sk_num_blocks         = 0;
            dp_tiles              = num_tiles;
            sk_num_big_blocks     = 0;
            k_iters_per_big_block = 0;

            dp_num_blocks      = num_tiles; // all tiles become dp blocks
            dp_start_block_idx = 0;
            sk_total_iters     = 0; // clear the sk iters
        }
        // 2-tile sk + DP GEMM
        else
        {
            // check if there's enough work for DP + stream-k
            bool bigEnough = num_tiles > grid_size;

            // Select between stream-k strategies
            // Add safety checks to prevent zero or negative values
            uint32_t sk_tiles = 0;
            if(streamk_sel == 1) // 1-tile stream-k
            {
                sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles;

                // Ensure sk_tiles is at least 1
                sk_tiles = math::max(sk_tiles, 1u);
            }
            else if(streamk_sel == 2) // 2-tile stream-k
            {
                sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles;

                // Ensure sk_tiles is at least 1 but not more than num_tiles
                sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles);
            }
            else if(streamk_sel == 3) // 3-tile stream-k
            {
                sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size)
                                                         : num_tiles;

                // Ensure sk_tiles is at least 1 but not more than num_tiles
                sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles);
            }
            else if(streamk_sel == 4) // 4-tile stream-k
            {
                sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size)
                                                         : num_tiles;

                // Ensure sk_tiles is at least 1 but not more than num_tiles
                sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles);
            }

            sk_num_blocks = sk_tiles;
            // Remaining tiles are DP tiles
            dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;

            sk_total_iters = k_iters_per_tile.get() * sk_tiles;

            // k_iters_per_sk_block is the floor of the average number of iters each sk
            // block loops over. We need to decide how many iters each sk block gets:
            // let m = k_iters_per_sk_block;
            // some of the sk blocks (little) will cover m iters, some (big) will cover m+1.
            // We have
            //  1) l + b = sk_blocks
            //  2) l * m + b * (m + 1) = sk_total_iters
            //     => (l + b) * m + b = sk_total_iters
            //     => sk_blocks * m + b = sk_total_iters
            //     => b = sk_total_iters - m * sk_blocks
            // NOTE: big could be zero

            // Add safety check for sk_num_blocks to prevent division by zero
            if(sk_num_blocks > 0)
            {
                uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
                sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
                k_iters_per_big_block = k_iters_per_sk_block + 1;
            }
            else
            {
                // Fallback to default GEMM if no stream-k blocks
                sk_num_blocks         = 0;
                sk_num_big_blocks     = 0;
                k_iters_per_big_block = 0;
                dp_tiles              = num_tiles;
                dp_num_blocks         = num_tiles;
                dp_start_block_idx    = 0;
                sk_total_iters        = 0;
            }

            dp_num_blocks      = dp_tiles;
            dp_start_block_idx = sk_num_blocks;
        }

        n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
        // Using multiple blocks for parallel reduction
        reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;

        if(reduction_strategy == StreamKReductionStrategy::Reduction)
        {
            // Add additional safety checks
            if(k_iters_per_big_block > 0 && k_iters_per_tile.get() > 0)
            {
                uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
                uint32_t upper_little =
                    math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
                equiv_tiles_big    = MDiv(upper_big / k_iters_per_tile.get());
                equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
            }
            else
            {
                // Default safe values
                equiv_tiles_big    = MDiv(1);
                equiv_tiles_little = MDiv(1);
            }
        }
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    __host__ __device__ uint32_t get_sk_total_iters() const
    {
        uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
                                  (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
        return sk_total_iters;
    }

    __host__ __device__ uint32_t get_sk_tiles() const
    {
        // tiles for sk
        uint32_t sk_total_iters = get_sk_total_iters();
        return k_iters_per_tile.div(sk_total_iters);
    }

    __host__ __device__ index_t get_grid_dims() const
    {
        if(reduction_strategy == StreamKReductionStrategy::Reduction)
        {
            // return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
            return reduction_start_block_idx + get_sk_tiles();
        }
        else
            return reduction_start_block_idx;
    }

    __device__ uint32_t get_block_idx() const
    {
        // TODO: swizzle block index for better locality
        return __builtin_amdgcn_readfirstlane(blockIdx.x);
    }

    __device__ void
    get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
    {
        if(block_idx < sk_num_big_blocks)
        {
            iter_start = block_idx * k_iters_per_big_block;
            iter_end   = iter_start + k_iters_per_big_block;
        }
        else if(block_idx < sk_num_blocks)
        {
            iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
                         (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
            iter_end = iter_start + (k_iters_per_big_block - 1);
        }
        else if(block_idx >= dp_start_block_idx)
        {
            uint32_t sk_total_iters     = get_sk_total_iters();
            uint32_t dp_iters_per_block = k_iters_per_tile.get();
            iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
            iter_end   = iter_start + dp_iters_per_block;
        }
    }

    __device__ uint32_t get_current_iter_length(uint32_t iter_start,
                                                uint32_t iter_end,
                                                uint32_t total_iter_length) const
    {
        uint32_t iter_length_mod, iter_length_quo /*unused*/;
        k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
        uint32_t current_iter_length = math::min(
            iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
        return current_iter_length;
    }

    __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }

    __device__ void
    get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
    {
        k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
    }

    __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
    {
        uint32_t m_tile_idx, n_tile_idx;
        uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
        n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);

        // swizzle tile
        uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);

        uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;

        const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
                                     ? tile_swizzle_sub_m
                                     : tile_swizzle_sub_m_rem;

        uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
        m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
        m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;

        uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;

        uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;

        n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
        m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
        return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
                          n_tile_idx_with_adapt);
    }

    __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
    {
        static constexpr uint32_t alignment = 128;
        uint32_t acc_buffer_bytes =
            MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
        return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
    }

    __host__ __device__ uint32_t get_workspace_size_for_semaphore() const
    {
        return get_sk_tiles() * sizeof(uint32_t);
    }

    __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
    {
        return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
    }

    __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
                                                        const MDiv& equiv_tiles_) const
    {
        uint32_t tile_idx_        = tiles_ == 0 ? 0 : (tiles_ - 1);
        uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1;
        uint32_t quo_, rem_;
        equiv_tiles_.divmod(tile_idx_, quo_, rem_);
        return quo_ * max_equiv_tiles_ + rem_;
    }

    __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
                                                          uint32_t iters_per_sk_block_) const
    {
        return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
                                    1);
    }

    __host__ __device__ uint32_t get_total_acc_buffers() const
    {
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        uint32_t tiles_cover_little_blocks = get_tiles_cover_sk_block(
            sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);

        uint32_t total_intersec_big =
            get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big);
        uint32_t total_intersec_little =
            get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little);

        return sk_num_blocks + total_intersec_big + total_intersec_little;
    }

    __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
    {
        // TODO: from big to little
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        if(tile_idx_ < tiles_cover_big_blocks)
        {
            uint32_t touched_sk_blocks =
                (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
                k_iters_per_big_block;
            uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big);
            return touched_sk_blocks + current_intersec;
        }
        else
        {
            uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
            uint32_t tile_idx_little_reverse   = get_sk_tiles() - tile_idx_;
            uint32_t touched_sk_blocks =
                (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
                iters_per_little_sk_block;
            uint32_t current_intersec =
                get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little);
            return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
        }
    }

    __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
    {
        uint32_t iters_per_big_sk_block    = k_iters_per_big_block;
        uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
        if(block_idx_ < sk_num_big_blocks)
        {
            uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
                                                          k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big);
            return block_idx_ + current_intersec;
        }
        else
        {
            uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
            uint32_t touched_tiles            = k_iters_per_tile.div(
                block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little);
            return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
        }
    }
};
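
// Device-side usage sketch (illustrative only; the template/runtime values and the
// back-to-front peeling loop are hypothetical, not prescribed by this header):
#if 0
using Map = BlockToCTileMap_GemmStreamK_v2<256, 128, 32>;
Map map(/*m=*/4096, /*n=*/4096, /*k=*/4096, /*grid_size=*/304, /*streamk_sel=*/2);

uint32_t iter_start, iter_end;
map.get_block_itr(map.get_block_idx(), iter_start, iter_end);
while(iter_start < iter_end)
{
    uint32_t tile_idx, iter_offset;
    map.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset);
    const auto spatial_idx = map.tile_to_spatial(tile_idx, 4096, 4096);
    // ... run the k-iteration slice for this C tile, then peel it off the back:
    iter_end -= map.get_current_iter_length(iter_start, iter_end, iter_end - iter_start);
}
#endif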

} // namespace ck
static constexpr uint32_t tile_swizzle_sub_m
Definition: block_to_ctile_map.hpp:1424
uint32_t reduction_start_block_idx
Definition: block_to_ctile_map.hpp:1431
__host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_, uint32_t iters_per_sk_block_) const
Definition: block_to_ctile_map.hpp:1714
__host__ __device__ BlockToCTileMap_GemmStreamK_v2(uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size=1, uint32_t streamk_sel=1, StreamKReductionStrategy reduction_strategy_=StreamKReductionStrategy::Atomic)
Definition: block_to_ctile_map.hpp:1440
uint32_t sk_num_blocks
Definition: block_to_ctile_map.hpp:1428
__device__ uint32_t get_block_idx() const
Definition: block_to_ctile_map.hpp:1609
__device__ void get_tile_idx_with_offset(uint32_t iter, uint32_t &tile_idx, uint32_t &iter_offset) const
Definition: block_to_ctile_map.hpp:1652
StreamKReductionStrategy reduction_strategy
Definition: block_to_ctile_map.hpp:1437
static constexpr uint32_t MPerBlock
Definition: block_to_ctile_map.hpp:1421
Definition: block_to_ctile_map.hpp:1021
uint32_t k_iters_per_big_block
Definition: block_to_ctile_map.hpp:1035
__host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
Definition: block_to_ctile_map.hpp:1326
__device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
Definition: block_to_ctile_map.hpp:1389
__host__ __device__ uint32_t get_sk_total_iters() const
Definition: block_to_ctile_map.hpp:1212
__host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_, uint32_t iters_per_sk_block_) const
Definition: block_to_ctile_map.hpp:1341
static constexpr uint32_t MPerBlock
Definition: block_to_ctile_map.hpp:1023
uint32_t dp_start_block_idx
Definition: block_to_ctile_map.hpp:1033
__host__ __device__ uint32_t get_sk_tiles() const
Definition: block_to_ctile_map.hpp:1219
static constexpr uint32_t KPerBlock
Definition: block_to_ctile_map.hpp:1025
__host__ __device__ uint32_t get_total_acc_buffers() const
Definition: block_to_ctile_map.hpp:1348
__device__ uint32_t get_current_iter_length(uint32_t iter_start, uint32_t iter_end, uint32_t total_iter_length) const
Definition: block_to_ctile_map.hpp:1265
static constexpr uint32_t NPerBlock
Definition: block_to_ctile_map.hpp:1024
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
Definition: block_to_ctile_map.hpp:1363
uint32_t reduction_start_block_idx
Definition: block_to_ctile_map.hpp:1034
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
Definition: block_to_ctile_map.hpp:1313
MDiv k_iters_per_tile
Definition: block_to_ctile_map.hpp:1037
__device__ void get_tile_idx_with_offset(uint32_t iter, uint32_t &tile_idx, uint32_t &iter_offset) const
Definition: block_to_ctile_map.hpp:1279
static constexpr uint32_t tile_swizzle_sub_m
Definition: block_to_ctile_map.hpp:1027
BlockToCTileMap_GemmStreamK(uint32_t m, uint32_t n, uint32_t k, uint32_t num_cu, uint32_t occupancy, uint32_t sk_blocks=0xffffffff)
Definition: block_to_ctile_map.hpp:1045
static constexpr StreamKReductionStrategy ReductionStrategy
Definition: block_to_ctile_map.hpp:1026
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
Definition: block_to_ctile_map.hpp:1284
__device__ uint32_t get_tile_idx(uint32_t iter) const
Definition: block_to_ctile_map.hpp:1276
__host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_, const MDiv &eqav_tiles_) const
Definition: block_to_ctile_map.hpp:1331
__device__ uint32_t get_block_idx() const
Definition: block_to_ctile_map.hpp:1236
__device__ void get_block_itr(uint32_t block_idx, uint32_t &iter_start, uint32_t &iter_end) const
Definition: block_to_ctile_map.hpp:1243
MDiv eqav_tiles_little
Definition: block_to_ctile_map.hpp:1039
uint32_t sk_num_blocks
Definition: block_to_ctile_map.hpp:1031
MDiv2 n_tiles
Definition: block_to_ctile_map.hpp:1036
MDiv eqav_tiles_big
Definition: block_to_ctile_map.hpp:1038
static constexpr uint32_t min_k_iters_per_sk_block
Definition: block_to_ctile_map.hpp:1022
uint32_t sk_num_big_blocks
Definition: block_to_ctile_map.hpp:1032
__host__ __device__ dim3 get_grid_dims() const
Definition: block_to_ctile_map.hpp:1226
__host__ __device__ uint32_t get_workspace_size_for_semaphore() const
Definition: block_to_ctile_map.hpp:1321
Definition: block_to_ctile_map.hpp:270
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:297
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: block_to_ctile_map.hpp:282
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt()=default
static constexpr auto I1
Definition: block_to_ctile_map.hpp:272
__host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition: block_to_ctile_map.hpp:291
static constexpr auto I0
Definition: block_to_ctile_map.hpp:271
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition: block_to_ctile_map.hpp:383
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M, index_t N, index_t M01=8)
Definition: block_to_ctile_map.hpp:275
Definition: block_to_ctile_map.hpp:719
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:754
static constexpr auto I2
Definition: block_to_ctile_map.hpp:722
static constexpr auto I0
Definition: block_to_ctile_map.hpp:720
__host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=1, index_t N01=1, index_t KSplit=1)
Definition: block_to_ctile_map.hpp:727
__host__ BlockToCTileMap_KSplit_M00_N00_M01_N01()=default
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:772
static constexpr auto I3
Definition: block_to_ctile_map.hpp:723
static constexpr auto I1
Definition: block_to_ctile_map.hpp:721
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: block_to_ctile_map.hpp:763
__host__ constexpr __device__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:740
Definition: block_to_ctile_map.hpp:540
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=8, index_t KSplit=1)
Definition: block_to_ctile_map.hpp:548
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt()=default
static constexpr auto I0
Definition: block_to_ctile_map.hpp:541
static constexpr auto I1
Definition: block_to_ctile_map.hpp:542
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition: block_to_ctile_map.hpp:593
constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:555
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:566
static constexpr auto I2
Definition: block_to_ctile_map.hpp:543
static constexpr auto I3
Definition: block_to_ctile_map.hpp:544
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition: block_to_ctile_map.hpp:599
Definition: block_to_ctile_map.hpp:616
__host__ __device__ BlockToCTileMap_M00_N00_M01_N01()=default
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:660
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:645
__host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=1, index_t N01=1)
Definition: block_to_ctile_map.hpp:624
constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:631
static constexpr auto I0
Definition: block_to_ctile_map.hpp:617
static constexpr auto I3
Definition: block_to_ctile_map.hpp:620
static constexpr auto I1
Definition: block_to_ctile_map.hpp:618
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: block_to_ctile_map.hpp:651
static constexpr auto I2
Definition: block_to_ctile_map.hpp:619
__host__ constexpr __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition: block_to_ctile_map.hpp:245
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: block_to_ctile_map.hpp:157
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt()=default
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt &&)=default
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt & operator=(BlockToCTileMap_M00_N0_M01Adapt &&)=default
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt &)=default
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt & operator=(const BlockToCTileMap_M00_N0_M01Adapt &)=default
static constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: block_to_ctile_map.hpp:166
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=8)
Definition: block_to_ctile_map.hpp:150
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01=8)
Definition: block_to_ctile_map.hpp:138
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:178
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition: block_to_ctile_map.hpp:172
Definition: block_to_ctile_map.hpp:260
Definition: block_to_ctile_map.hpp:24
constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:38
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:66
static constexpr auto I3
Definition: block_to_ctile_map.hpp:28
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=1)
Definition: block_to_ctile_map.hpp:32
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:51
static constexpr auto I2
Definition: block_to_ctile_map.hpp:27
static constexpr auto I0
Definition: block_to_ctile_map.hpp:25
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01()=default
static constexpr auto I1
Definition: block_to_ctile_map.hpp:26
__host__ constexpr __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: block_to_ctile_map.hpp:57
__host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition: block_to_ctile_map.hpp:450
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N &c_grid_desc_m_n, index_t N01=8)
Definition: block_to_ctile_map.hpp:428
static constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: block_to_ctile_map.hpp:444
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt & operator=(const BlockToCTileMap_N00_M0_N01Adapt &)=default
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition: block_to_ctile_map.hpp:524
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt & operator=(BlockToCTileMap_N00_M0_N01Adapt &&)=default
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:456
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt &)=default
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt &&)=default
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01=8)
Definition: block_to_ctile_map.hpp:417
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: block_to_ctile_map.hpp:435
Definition: block_to_ctile_map.hpp:398
Definition: magic_division.hpp:208
__host__ __device__ void divmod(uint32_t dividend_, uint32_t divisor_, uint32_t &quotient_, uint32_t &remainder_) const
Definition: magic_division.hpp:230
Definition: magic_division.hpp:166
__host__ __device__ uint32_t get() const
Definition: magic_division.hpp:204
__host__ __device__ void divmod(uint32_t dividend_, uint32_t &quotient_, uint32_t &remainder_) const
Definition: magic_division.hpp:198
__host__ __device__ uint32_t div(uint32_t dividend_) const
Definition: magic_division.hpp:192
__host__ static constexpr __device__ T Max()
Definition: numeric_limits.hpp:311
Definition: block_to_ctile_map.hpp:919
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: block_to_ctile_map.hpp:939
index_t tile_offset_
Definition: block_to_ctile_map.hpp:959
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:946
UnderlyingBlockToCTileMap block_to_ctile_map_
Definition: block_to_ctile_map.hpp:957
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:932
__host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map, index_t group_offset, index_t tile_offset)
Definition: block_to_ctile_map.hpp:922
UnderlyingBlockToCTileMap underlying_type
Definition: block_to_ctile_map.hpp:920
index_t group_offset_
Definition: block_to_ctile_map.hpp:958
__device__ void UpdateTileOffset(index_t offset)
Definition: block_to_ctile_map.hpp:956
__host__ constexpr __device__ index_t CalculateGridSize(index_t M, index_t N) const
Definition: block_to_ctile_map.hpp:951
Definition: block_to_ctile_map.hpp:871
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: block_to_ctile_map.hpp:890
__host__ constexpr __device__ index_t CalculateGridSize(index_t M, index_t N) const
Definition: block_to_ctile_map.hpp:908
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:897
constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:903
index_t block_start_
Definition: block_to_ctile_map.hpp:914
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:883
__host__ __device__ OffsettedBlockToCTileMap()=default
__host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
Definition: block_to_ctile_map.hpp:875
UnderlyingBlockToCTileMap underlying_type
Definition: block_to_ctile_map.hpp:872
UnderlyingBlockToCTileMap block_to_ctile_map_
Definition: block_to_ctile_map.hpp:913
Definition: sequence.hpp:43
Definition: integral_constant.hpp:20