Composable Kernel: include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/math.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/magic_division.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"

#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
#include <limits>
#include <stdlib.h>
#endif

namespace ck {

// Rows of column-vectors
template <index_t MPerBlock,
          index_t NPerBlock,
          typename CGridDesc_M_N,
          bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_M00_N0_M01
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01() = default;

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
                                                             index_t M01 = 1)
        : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01))
    {
    }

    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01_);

        const index_t grid_size = M00 * M01_ * N0;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        return underlying_map_.CalculateBottomIndex(idx_top);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                                       const CTileDim& c_tile_dim) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
        else
            return true;
    }

    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return true; // validity check moved to kernel

        const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        if(M0 % M01_ == 0)
        {
            return true;
        }
        else
        {
            return false;
        }
    }

    private:
    __host__ __device__ static constexpr auto
    GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01)
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01);

        const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
                       make_unmerge_transform(make_tuple(M00, M01)),
                       make_pass_through_transform(N0)),
            make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));

        const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(1, M00, N0, M01))),
            make_tuple(Sequence<0, 1, 2, 3>{}),
            make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_n0_m01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    index_t M01_;
    using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1));
    UnderlyingMap underlying_map_;
};
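
// --- Usage sketch (illustrative, not from the original source) ------------
// A minimal host-side example of driving this map. The descriptor helper
// header and the concrete sizes below are assumptions for illustration.
#if 0
#include "ck/tensor_description/tensor_descriptor_helper.hpp"

void example_m00_n0_m01()
{
    using namespace ck;
    // hypothetical 1024x2048 C matrix tiled into 128x128 blocks
    const auto c_desc = make_naive_tensor_descriptor_packed(make_tuple(1024, 2048));
    using CDesc       = decltype(c_desc);

    BlockToCTileMap_M00_N0_M01<128, 128, CDesc> b2c(c_desc, /*M01=*/4);

    // M0 = 8, N0 = 16, M00 = 2 -> grid_size = 2 * 4 * 16 = 128
    if(b2c.CheckValidity(c_desc)) // true iff M0 is divisible by M01
    {
        const index_t grid_size = b2c.CalculateGridSize(c_desc);
        (void)grid_size;
    }
}
#endif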

// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_M00_N0_M01Adapt;

template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
        const BlockToCTileMap_M00_N0_M01Adapt&) = default;
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
        BlockToCTileMap_M00_N0_M01Adapt&&) = default;
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
    operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
    operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;

    __host__
        __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
        : M_(M), N_(N), M01_(M01)
    {
#if 0
        if(get_thread_global_1d_id() == 0)
        {
            printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
        }
#endif
    }

    template <typename CGridDesc_M_N>
    __host__
        __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
                                                             index_t M01 = 8)
        : BlockToCTileMap_M00_N0_M01Adapt(
              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
    {
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    template <typename CGridDesc_M_N>
    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ constexpr bool
    CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

        index_t idx_N0 = block_1d_id % N0;
        index_t idx_M0 = block_1d_id / N0;

        const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

        index_t idx_M00          = idx_M0 / M01_;
        index_t idx_M01          = idx_M0 % M01_;
        index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        // remap the linear id column-major inside each group of M01 rows, so
        // consecutive workgroups stay within a narrow band of M
        return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                          idx_N0_M01_local / M01_adapt);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                                       const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    private:
    index_t M_;
    index_t N_;
    index_t M01_;
};
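
// --- Arithmetic sketch (illustrative, not from the original source) -------
// Plain-integer rehearsal of CalculateBottomIndex above, usable on the host
// to check that every (m0, n0) tile is visited exactly once when block ids
// run over [0, M0 * N0). Names and the driver function are hypothetical.
#if 0
#include <cassert>
#include <vector>

void check_m00_n0_m01_adapt_coverage(int M0, int N0, int M01)
{
    std::vector<int> hits(M0 * N0, 0);
    for(int bid = 0; bid < M0 * N0; ++bid)
    {
        const int idx_N0    = bid % N0;
        const int idx_M0    = bid / N0;
        const int M01_adapt = (idx_M0 < M0 - M0 % M01) ? M01 : M0 % M01;
        const int idx_M00   = idx_M0 / M01;
        const int idx_M01   = idx_M0 % M01;
        const int local     = idx_N0 + idx_M01 * N0;
        const int m0        = local % M01_adapt + idx_M00 * M01;
        const int n0        = local / M01_adapt;
        hits[m0 * N0 + n0] += 1;
    }
    for(int h : hits)
        assert(h == 1); // each C tile is owned by exactly one workgroup
}
#endif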

// keep the redundant type argument for backward compatibility
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{
    using BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>::
        BlockToCTileMap_M00_N0_M01Adapt;
};

// Grouped Rows of column-vectors WGP mapping
// Optimized for gfx94x-like multi-die chips

template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;
    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M,
                                                                index_t N,
                                                                index_t M01 = 8)
        : M_(M), N_(N), M01_(M01)
    {
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    template <typename CGridDesc_M_N>
    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);

        if(M0 == 1)
        {
            return make_tuple(0, block_1d_id);
        }
        else if(N0 == 1)
        {
            return make_tuple(block_1d_id, 0);
        }
        // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
        else
        {
            const auto group_size    = math::integer_divide_ceil(M0 * N0, GroupNum);
            const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
            auto group_id_x          = block_1d_id % GroupNum;
            auto group_id_y          = block_1d_id / GroupNum;
            auto remap_block_1d_id =
                group_id_x <= big_group_num
                    ? group_id_x * group_size + group_id_y
                    : group_id_x * group_size + big_group_num - group_id_x + group_id_y;

            index_t idx_N0 = remap_block_1d_id % N0;
            index_t idx_M0 = remap_block_1d_id / N0;

            const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

            index_t idx_M00          = idx_M0 / M01_;
            index_t idx_M01          = idx_M0 % M01_;
            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

            return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                              idx_N0_M01_local / M01_adapt);
        }
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    private:
    index_t M_;
    index_t N_;
    index_t M01_;
};
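
// --- Arithmetic sketch (illustrative, not from the original source) -------
// Host-side rehearsal of the group remapping above: linear block ids are
// split round-robin across GroupNum groups, then re-linearized so that ids
// in the same group become contiguous. All names and values are hypothetical.
#if 0
#include <cstdio>

void show_grouped_remap(int M0, int N0, int GroupNum)
{
    const int num_blocks    = M0 * N0;
    const int group_size    = (num_blocks + GroupNum - 1) / GroupNum;
    const int big_group_num = GroupNum - (group_size * GroupNum - num_blocks);
    for(int bid = 0; bid < num_blocks; ++bid)
    {
        const int gx    = bid % GroupNum; // which group (die) this id lands on
        const int gy    = bid / GroupNum; // position within that group
        const int remap = gx <= big_group_num
                              ? gx * group_size + gy
                              : gx * group_size + big_group_num - gx + gy;
        std::printf("bid %3d -> group %2d, remapped id %3d\n", bid, gx, remap);
    }
}
#endif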

// columns of row-vectors
// This C-tile map dynamically adjusts N01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_N00_M0_N01Adapt;

template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt() = default;

    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt&) =
        default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt&&) =
        default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
    operator=(const BlockToCTileMap_N00_M0_N01Adapt&) = default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
    operator=(BlockToCTileMap_N00_M0_N01Adapt&&) = default;

    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01 = 8)
        : M_(M), N_(N), N01_(N01)
    {
#if 0
        if(get_thread_global_1d_id() == 0)
        {
            printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_);
        }
#endif
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
                                                        index_t N01 = 8)
        : BlockToCTileMap_N00_M0_N01Adapt(
              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), N01)
    {
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    template <typename CGridDesc_M_N>
    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
    }

    template <typename CGridDesc_M_N>
    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

        index_t idx_M0 = block_1d_id % M0;
        index_t idx_N0 = block_1d_id / M0;

        const auto N01_adapt = (idx_N0 < N0 - N0 % N01_) ? N01_ : N0 % N01_;

        index_t idx_N00          = idx_N0 / N01_;
        index_t idx_N01          = idx_N0 % N01_;
        index_t idx_M0_N01_local = idx_M0 + idx_N01 * M0;

        return make_tuple(idx_M0_N01_local / N01_adapt,
                          idx_M0_N01_local % N01_adapt + idx_N00 * N01_);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    private:
    index_t M_;
    index_t N_;
    index_t N01_;
};

// 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt() = default;

    __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
                                                               index_t M01    = 8,
                                                               index_t KSplit = 1)
        : M01_(M01), KSplit_(KSplit), c_grid_desc_m_n_(c_grid_desc_m_n)
    {
    }

    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const index_t grid_size = M0 * N0 * KSplit_;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0 * KSplit_); // hide groups

        const index_t idx_ksplit = block_1d_id / (M0 * N0);
        block_1d_id              = block_1d_id % (M0 * N0);

        index_t idx_N0 = block_1d_id % N0;
        index_t idx_M0 = block_1d_id / N0;

        const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

        index_t idx_M00          = idx_M0 / M01_;
        index_t idx_M01          = idx_M0 % M01_;
        index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        return make_tuple(idx_ksplit,
                          idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                          idx_N0_M01_local / M01_adapt);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    private:
    index_t M01_;
    index_t KSplit_;
    CGridDesc_M_N c_grid_desc_m_n_;
};
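
// --- Arithmetic sketch (illustrative, not from the original source) -------
// The KSplit variant prepends a k-split coordinate: the flat block id is
// reduced modulo M0 * N0 * KSplit, the high part selects the K slice, and
// the remainder is swizzled exactly like BlockToCTileMap_M00_N0_M01Adapt.
// The standalone helper below is hypothetical, written for host-side checks.
#if 0
#include <tuple>

std::tuple<int, int, int> ksplit_decompose(int bid, int M0, int N0, int KSplit, int M01)
{
    bid                 = bid % (M0 * N0 * KSplit); // hide batch/group dims
    const int idx_k     = bid / (M0 * N0);          // which K slice
    bid                 = bid % (M0 * N0);
    const int idx_N0    = bid % N0;
    const int idx_M0    = bid / N0;
    const int M01_adapt = (idx_M0 < M0 - M0 % M01) ? M01 : M0 % M01;
    const int idx_M00   = idx_M0 / M01;
    const int idx_M01   = idx_M0 % M01;
    const int local     = idx_N0 + idx_M01 * N0;
    return {idx_k, local % M01_adapt + idx_M00 * M01, local / M01_adapt};
}
#endif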

// Blocks of row-vectors
template <index_t MPerBlock,
          index_t NPerBlock,
          typename CGridDesc_M_N,
          bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_M00_N00_M01_N01
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ BlockToCTileMap_M00_N00_M01_N01() = default;

    __host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n,
                                                        index_t M01 = 1,
                                                        index_t N01 = 1)
        : M01_(M01), N01_(N01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01))
    {
    }

    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01_);
        const auto N00 = math::integer_divide_ceil(N0, N01_);

        const index_t grid_size = M00 * M01_ * N00 * N01_;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        return underlying_map_.CalculateBottomIndex(idx_top);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                             const CTileDim& c_tile_dim) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
        else
            return true;
    }

    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return true; // validity check moved to kernel

        const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
        if(M0 % M01_ == 0 && N0 % N01_ == 0)
        {
            return true;
        }
        else
        {
            return false;
        }
    }

    private:
    __host__ __device__ static constexpr auto
    GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01);
        const auto N00 = math::integer_divide_ceil(N0, N01);

        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(1, M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    index_t M01_, N01_;
    using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1));
    UnderlyingMap underlying_map_;
};

// 2D slices of row-vectors in 3D space
template <index_t MPerBlock,
          index_t NPerBlock,
          typename CGridDesc_M_N,
          bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_KSplit_M00_N00_M01_N01
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ BlockToCTileMap_KSplit_M00_N00_M01_N01() = default;

    __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n,
                                                    index_t M01    = 1,
                                                    index_t N01    = 1,
                                                    index_t KSplit = 1)
        : c_grid_desc_m_n_(c_grid_desc_m_n),
          M01_(M01),
          N01_(N01),
          KSplit_(KSplit),
          underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit))
    {
    }

    __host__ __device__ constexpr index_t
    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01_);
        const auto N00 = math::integer_divide_ceil(N0, N01_);

        const index_t grid_size = M00 * M01_ * N00 * N01_ * KSplit_;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        static_assert(TopIdx::Size() == 1);

        return underlying_map_.CalculateBottomIndex(
            make_multi_index(idx_top[I0] % CalculateGridSize()));
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                             const CTileDim& c_tile_dim) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
        else
            return true;
    }

    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return true; // validity check moved to kernel

        const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
        if(M0 % M01_ == 0 && N0 % N01_ == 0)
        {
            return true;
        }
        else
        {
            return false;
        }
    }

    private:
    __device__ constexpr index_t CalculateGridSize() const
    {
        return CalculateGridSize(c_grid_desc_m_n_);
    }

    __host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
                                                      index_t M01,
                                                      index_t N01,
                                                      index_t KSplit)
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01);
        const auto N00 = math::integer_divide_ceil(N0, N01);

        const auto ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_pass_through_transform(KSplit),
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(KSplit, M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor);

        return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor;
    }

    CGridDesc_M_N c_grid_desc_m_n_;
    index_t M01_, N01_, KSplit_;
    using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1));
    UnderlyingMap underlying_map_;
};

template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx,
                                                const CTileDim& c_tile_dim)
{
    bool is_valid = false;

    const index_t m_block = c_tile_dim[Number<0>{}];
    const index_t n_block = c_tile_dim[Number<1>{}];

    if constexpr(CTileIdx::Size() == 2)
    {
        const index_t m_block_idx = c_tile_idx[Number<0>{}];
        const index_t n_block_idx = c_tile_idx[Number<1>{}];
        if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block)
        {
            is_valid = true;
        }
    }
    else if constexpr(CTileIdx::Size() == 3)
    {
        const index_t ksplit_idx  = c_tile_idx[Number<0>{}];
        const index_t m_block_idx = c_tile_idx[Number<1>{}];
        const index_t n_block_idx = c_tile_idx[Number<2>{}];
        if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block)
        {
            is_valid = true;
        }
        ignore = ksplit_idx;
    }

    return is_valid;
}
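
// --- Usage sketch (illustrative, not from the original source) ------------
// With DeviceCTileIndexCheck enabled, the maps above defer to this helper;
// a kernel would typically bail out early on an invalid tile. The template
// below is a hypothetical wrapper showing the intended call pattern.
#if 0
template <typename B2C, typename CTileIdx, typename CTileDim>
__device__ void guarded_tile_work(const B2C& b2c, const CTileIdx& idx, const CTileDim& dim)
{
    if(!b2c.ValidCTileIndex(idx, dim))
        return; // out-of-range workgroup contributes nothing
    // ... compute on tile idx ...
}
#endif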

// This wrapper class is for grouped gemm: it subtracts a fixed offset from blockIdx so that the
// workgroups assigned to a given gemm problem see top indices in the range
// [0, grid_size_per_gemm)
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMap
{
    using underlying_type = UnderlyingBlockToCTileMap;

    __host__ __device__ OffsettedBlockToCTileMap() = default;
    __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
                                                 index_t block_start)
    {
        block_to_ctile_map_ = block_to_ctile_map;
        block_start_        = block_start;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        return block_to_ctile_map_.CalculateBottomIndex(
            make_multi_index(idx_top[Number<0>{}] - block_start_));
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                             const CTileDim& c_tile_dim) const
    {
        return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
    }

    template <typename CGridDesc_M_N>
    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
    }

    template <typename CGridDesc_M_N>
    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
    }

    __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
    {
        return block_to_ctile_map_.CalculateGridSize(M, N);
    }

    UnderlyingBlockToCTileMap block_to_ctile_map_;
    index_t block_start_;
};
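
// --- Usage sketch (illustrative, not from the original source) ------------
// In grouped GEMM each problem owns a contiguous block-id range; wrapping
// its map with the first id of that range makes the underlying map see
// indices starting at zero again. Names and values below are hypothetical.
#if 0
template <typename UnderlyingMap>
__device__ void grouped_gemm_dispatch(const UnderlyingMap& map,
                                      ck::index_t group_grid_start,
                                      ck::index_t block_id)
{
    ck::OffsettedBlockToCTileMap<UnderlyingMap> offsetted(map, group_grid_start);
    const auto tile_idx = offsetted.CalculateBottomIndex(ck::make_multi_index(block_id));
    // tile_idx is now relative to this gemm problem's own C grid
    (void)tile_idx;
}
#endif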
// second version with 2 offsets
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMap2
{
    using underlying_type = UnderlyingBlockToCTileMap;

    __host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map,
                                                  index_t group_offset,
                                                  index_t tile_offset)
        : block_to_ctile_map_{block_to_ctile_map},
          group_offset_{group_offset},
          tile_offset_{tile_offset}
    {
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        return block_to_ctile_map_.CalculateBottomIndex(
            make_multi_index(idx_top[Number<0>{}] - group_offset_ + tile_offset_));
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                             const CTileDim& c_tile_dim) const
    {
        return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
    }

    template <typename CGridDesc_M_N>
    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
    }

    __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
    {
        return block_to_ctile_map_.CalculateGridSize(M, N);
    }

    __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
    UnderlyingBlockToCTileMap block_to_ctile_map_;
    index_t group_offset_;
    index_t tile_offset_;
};

/**
 * @brief Simple tile mapping which creates 3D grid of block of threads.
 */
template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_3DGrid_KSplit
{

    __host__ __device__ BlockToCTileMap_3DGrid_KSplit() = default;

    __host__ __device__ constexpr auto
    CalculateGridSize(index_t M, index_t N, index_t k_split) const
    {
        // Create 3D grid
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);
        return make_tuple(N0, M0, k_split);
    }

    template <typename TopIdx>
    __device__ constexpr auto CalculateBottomIndex(const TopIdx&) const
    {
        return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    template <typename CGridDesc_M_N>
    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }
};
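
// --- Usage sketch (illustrative, not from the original source) ------------
// Unlike the 1D maps above, this map expects a 3D launch; the tuple from
// CalculateGridSize feeds dim3 directly (N0 -> x, M0 -> y, k_split -> z).
// The launcher function, kernel, and block size below are hypothetical and
// assume the HIP runtime.
#if 0
void launch_with_3d_grid(ck::index_t M, ck::index_t N, ck::index_t k_split)
{
    ck::BlockToCTileMap_3DGrid_KSplit<128, 128> b2c;
    const auto dims = b2c.CalculateGridSize(M, N, k_split);
    dim3 grid(dims.At(ck::Number<0>{}), dims.At(ck::Number<1>{}), dims.At(ck::Number<2>{}));
    // hipLaunchKernelGGL(some_gemm_kernel, grid, dim3(256), 0, nullptr, ...);
    (void)grid;
}
#endif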

enum StreamKReductionStrategy
{
    Atomic = 0, // stream-k blocks reduce partial results with atomics
    Reduction,  // dedicated workgroups perform the reduction operation
};

template <uint32_t MPerBlock_,
          uint32_t NPerBlock_,
          uint32_t KPerBlock_,
          StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
          uint32_t TileSwizzleSubM_                   = 8>
struct BlockToCTileMap_GemmStreamK
{
    static constexpr uint32_t min_k_iters_per_sk_block          = 2;
    static constexpr uint32_t MPerBlock                         = MPerBlock_;
    static constexpr uint32_t NPerBlock                         = NPerBlock_;
    static constexpr uint32_t KPerBlock                         = KPerBlock_;
    static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
    static constexpr uint32_t tile_swizzle_sub_m                = TileSwizzleSubM_;

    //--------------------------------------
    // pass to device
    uint32_t sk_num_blocks;
    uint32_t sk_num_big_blocks;
    uint32_t dp_start_block_idx;
    uint32_t reduction_start_block_idx;
    uint32_t k_iters_per_big_block;
    MDiv2 n_tiles;
    MDiv k_iters_per_tile;
    MDiv eqav_tiles_big;    // for reduction
    MDiv eqav_tiles_little; // for reduction

    // MDiv tile_swizzle_sub_m_rem;
    //--------------------------------------

    // prefer construct on host
    __host__ BlockToCTileMap_GemmStreamK(uint32_t m,
                                         uint32_t n,
                                         uint32_t k,
                                         uint32_t num_cu,
                                         uint32_t occupancy,
                                         uint32_t sk_blocks = 0xffffffff)
    {
        uint32_t num_tiles =
            math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
        k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));

        // one cu can hold one wg at one time, from the whole chip's point of view
        // if the number of wg is the same as num_cu, we call it 1 dispatch
        // if the number of wg is 2x num_cu, we call it 2 dispatches.
        // one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial
        // dispatch)
        //
        uint32_t full_dispatches         = num_tiles / num_cu;
        uint32_t full_dispatch_tiles     = full_dispatches * num_cu;
        uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;

        uint32_t sk_occupancy = occupancy;
        uint32_t dp_tiles     = full_dispatch_tiles;
        uint32_t sk_tiles     = partial_dispatche_tiles;

        if(full_dispatches < occupancy)
        {
            // in this case, we allocate all blocks as sk blocks
            // sk_occupancy = occupancy - full_dispatches;
            sk_occupancy = 1; // TODO: single occ seems better
            dp_tiles     = full_dispatch_tiles;
            sk_tiles     = partial_dispatche_tiles;
        }
        else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
        {
            // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
            // occupancy = 3, full_dispatches = 5, 8, 11 ...
            // occupancy = 4, full_dispatches = 7, 11 ...
            sk_occupancy = 1; // leave 1 slot for sk occupancy
            dp_tiles     = full_dispatch_tiles;
            sk_tiles     = partial_dispatche_tiles;
        }
        else
        {
            // otherwise, take 1 dispatch from dp and, together with the partial dispatch,
            // construct the sk dispatch
            sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
            dp_tiles     = full_dispatch_tiles - num_cu;
            sk_tiles     = partial_dispatche_tiles + num_cu;
        }

        // uint32_t dp_iters_per_block = k_iters_per_tile.get();
        uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
        uint32_t dp_num_blocks  = 0;

        {
            uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
            uint32_t max_sk_tiles =
                (sk_tiles >= num_cu) ? num_cu * sk_occupancy
                                     : math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);

            // if we used dp blocks for the sk tiles, how many iters would we need
            uint32_t dp_for_sk_iters = k_iters_per_tile.get();

            uint32_t best_sk_score =
                NumericLimits<int32_t>::Max(); // we need to find the smallest sk iters
            for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
                tentative_sk_blocks++)
            {
                uint32_t tentative_sk_iters_per_block =
                    (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
                uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
                uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;

                // TODO: carefully adjust this parameter
                // the more sk_blocks_per_tile, the worse the overhead
                uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
                if(tentative_sk_blocks % sk_tiles != 0)
                {
                    // penalty for uneven divide
                    cross_sk_blocks_overhead +=
                        sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
                }

                uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;

                if(tentative_sk_score < best_sk_score)
                {
                    best_sk_score = tentative_sk_score;
                    sk_num_blocks = tentative_sk_blocks;
                }
            }

            if(best_sk_score >= dp_for_sk_iters)
            {
                sk_num_blocks = 0;
            }

            // give a chance to control the number of sk blocks
            sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;

            if(sk_num_blocks == 0)
            {
                sk_num_big_blocks     = 0;
                k_iters_per_big_block = 0;

                dp_num_blocks      = num_tiles; // all tiles become dp blocks
                dp_start_block_idx = 0;
                sk_total_iters     = 0; // clear the sk iters
            }
            else
            {
                // k_iters_per_sk_block is the floor of the average iters each sk block covers.
                // we need to decide how many iters for each sk block
                // let m = k_iters_per_sk_block
                // some of the sk blocks (little) will cover m iters, some (big) will cover m+1
                // we have
                // 1) l + b = sk_blocks
                // 2) l * m + b * (m + 1) = sk_total_iters
                // => (l + b) * m + b = sk_total_iters
                // => sk_blocks * m + b = sk_total_iters
                // => b = sk_total_iters - m * sk_blocks
                // NOTE: big could be zero
                uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
                sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
                k_iters_per_big_block = k_iters_per_sk_block + 1;

                dp_num_blocks      = dp_tiles;
                dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
            }
        }

        n_tiles                   = MDiv2(math::integer_divide_ceil(n, NPerBlock));
        reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;

        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            uint32_t upper_big    = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
            uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
            eqav_tiles_big        = MDiv(upper_big / k_iters_per_tile.get());
            eqav_tiles_little     = MDiv(upper_little / k_iters_per_tile.get());
        }

#if 0
        printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
               "sk_num_blocks:%d, "
               "sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
               "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
               "sk_tiles:%u, workspace(acc float):%u\n",
               num_cu,
               occupancy,
               get_grid_dims().x,
               num_tiles,
               dp_tiles,
               sk_num_big_blocks,
               sk_num_blocks,
               sk_total_iters,
               dp_start_block_idx,
               dp_iters_per_block,
               dp_num_blocks,
               k_iters_per_tile.get(),
               k_iters_per_big_block,
               reduction_start_block_idx,
               get_sk_tiles(),
               get_workspace_size(sizeof(float)));
#endif
    }

    __host__ __device__ uint32_t get_sk_total_iters() const
    {
        uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
                                  (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
        return sk_total_iters;
    }

    __host__ __device__ uint32_t get_sk_tiles() const
    {
        // tiles for sk
        uint32_t sk_total_iters = get_sk_total_iters();
        return k_iters_per_tile.div(sk_total_iters);
    }

    __host__ __device__ dim3 get_grid_dims() const
    {
        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
        }
        else
            return dim3(reduction_start_block_idx, 1, 1);
    }

    __device__ uint32_t get_block_idx() const
    {
        // TODO: swizzle block index for better locality
        return __builtin_amdgcn_readfirstlane(blockIdx.x);
    }

    __device__ void
    get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
    {
        if(block_idx < sk_num_big_blocks)
        {
            iter_start = block_idx * k_iters_per_big_block;
            iter_end   = iter_start + k_iters_per_big_block;
        }
        else if(block_idx < sk_num_blocks)
        {
            iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
                         (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
            iter_end = iter_start + (k_iters_per_big_block - 1);
        }
        else if(block_idx >= dp_start_block_idx)
        {
            uint32_t sk_total_iters     = get_sk_total_iters();
            uint32_t dp_iters_per_block = k_iters_per_tile.get();
            iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
            iter_end   = iter_start + dp_iters_per_block;
        }
    }

    __device__ uint32_t get_current_iter_length(uint32_t iter_start,
                                                uint32_t iter_end,
                                                uint32_t total_iter_length) const
    {
        uint32_t iter_length_mod, iter_length_quo /*unused*/;
        k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
        uint32_t current_iter_length = math::min(
            iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
        return current_iter_length;
    }

    __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }

    __device__ void
    get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
    {
        k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
    }

    __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
    {
        uint32_t m_tile_idx, n_tile_idx;
        uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
        n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);

        // swizzle tile
        uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);

        uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;

        const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
                                     ? tile_swizzle_sub_m
                                     : tile_swizzle_sub_m_rem;

        uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
        m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
        m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;

        uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;

        uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;

        n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
        m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
        return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
                          n_tile_idx_with_adapt);
    }

    __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
    {
        static constexpr uint32_t alignment = 128;
        uint32_t acc_buffer_bytes =
            MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
        return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
    }

    __host__ __device__ uint32_t get_workspace_size_for_semaphore() const
    {
        return get_sk_tiles() * sizeof(uint32_t);
    }

    __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
    {
        return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
    }

    __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
                                                        const MDiv& eqav_tiles_) const
    {
        uint32_t tile_idx_       = tiles_ == 0 ? 0 : (tiles_ - 1);
        uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
        uint32_t quo_, rem_;
        eqav_tiles_.divmod(tile_idx_, quo_, rem_);
        return quo_ * max_eqav_tiles_ + rem_;
    }

    __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
                                                          uint32_t iters_per_sk_block_) const
    {
        return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
                                    1);
    }

    __host__ __device__ uint32_t get_total_acc_buffers() const
    {
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        uint32_t tiles_cover_little_blocks =
            get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);

        uint32_t total_intersec_big =
            get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big);
        uint32_t total_intersec_little =
            get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little);

        return sk_num_blocks + total_intersec_big + total_intersec_little;
    }

    __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
    {
        // TODO: from big to little
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        if(tile_idx_ < tiles_cover_big_blocks)
        {
            uint32_t touched_sk_blocks =
                (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
                k_iters_per_big_block;
            uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big);
            return touched_sk_blocks + current_intersec;
        }
        else
        {
            uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
            uint32_t tile_idx_little_reverse   = get_sk_tiles() - tile_idx_;
            uint32_t touched_sk_blocks =
                (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
                iters_per_little_sk_block;
            uint32_t current_intersec =
                get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little);
            return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
        }
    }

    __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
    {
        uint32_t iters_per_big_sk_block    = k_iters_per_big_block;
        uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
        if(block_idx_ < sk_num_big_blocks)
        {
            uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
                                                          k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big);
            return block_idx_ + current_intersec;
        }
        else
        {
            uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
            uint32_t touched_tiles            = k_iters_per_tile.div(
                block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little);
            return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
        }
    }
};
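
// --- Arithmetic sketch (illustrative, not from the original source) -------
// The big/little split in the constructors above solves l + b = sk_blocks
// and l*m + b*(m+1) = sk_total_iters, i.e. b = sk_total_iters - m*sk_blocks
// with m = sk_total_iters / sk_blocks. A quick host-side check of that
// identity (the helper name is hypothetical):
#if 0
#include <cassert>

void check_streamk_split(unsigned sk_total_iters, unsigned sk_blocks)
{
    const unsigned m = sk_total_iters / sk_blocks;     // iters per little block
    const unsigned b = sk_total_iters - m * sk_blocks; // number of big blocks
    const unsigned l = sk_blocks - b;                  // number of little blocks
    assert(l * m + b * (m + 1) == sk_total_iters);     // every iter is covered
}
#endif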

template <uint32_t MPerBlock_,
          uint32_t NPerBlock_,
          uint32_t KPerBlock_,
          StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
          uint32_t TileSwizzleSubM_                   = 8,
          index_t GroupNum                            = 8,
          index_t M01_                                = 4>
struct BlockToCTileMap_GemmStreamK_v2
{
    static constexpr uint32_t min_k_iters_per_sk_block = 2;
    static constexpr uint32_t MPerBlock                = MPerBlock_;
    static constexpr uint32_t NPerBlock                = NPerBlock_;
    static constexpr uint32_t KPerBlock                = KPerBlock_;
    static constexpr uint32_t tile_swizzle_sub_m       = TileSwizzleSubM_;

    //--------------------------------------
    // pass to device
    uint32_t sk_num_blocks;
    uint32_t sk_num_big_blocks;
    uint32_t dp_start_block_idx;
    uint32_t reduction_start_block_idx;
    uint32_t k_iters_per_big_block;
    MDiv2 n_tiles;
    MDiv k_iters_per_tile;
    MDiv equiv_tiles_big;    // for reduction
    MDiv equiv_tiles_little; // for reduction
    StreamKReductionStrategy reduction_strategy;

    // prefer construct on host
    __host__ __device__ BlockToCTileMap_GemmStreamK_v2(
        uint32_t m,
        uint32_t n,
        uint32_t k,
        uint32_t grid_size   = 1,
        uint32_t streamk_sel = 1,
        StreamKReductionStrategy reduction_strategy_ = StreamKReductionStrategy::Atomic)
        : reduction_strategy(reduction_strategy_)
    {

        // total output tiles
        uint32_t num_tiles =
            math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
        k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));

        uint32_t dp_tiles, dp_num_blocks, sk_total_iters;

        // Ensure grid_size is at least 1 to avoid division by zero
        grid_size = math::max(grid_size, 1u);

        // default to regular DP GEMM if sk blocks == 0
        if(streamk_sel == 0)
        {
            sk_num_blocks         = 0;
            dp_tiles              = num_tiles;
            sk_num_big_blocks     = 0;
            k_iters_per_big_block = 0;

            dp_num_blocks      = num_tiles; // all tiles become dp blocks
            dp_start_block_idx = 0;
            sk_total_iters     = 0; // clear the sk iters
        }
        // 2-tile sk + DP GEMM
        else
        {
            // check if there's enough work for DP + stream-k
            bool bigEnough = num_tiles > grid_size;

            // Select between stream-k strategies
            // Add safety checks to prevent zero or negative values
            uint32_t sk_tiles = 0;
            if(streamk_sel == 1) // 1-tile stream-k
            {
                sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles;

                // Ensure sk_tiles is at least 1
                sk_tiles = math::max(sk_tiles, 1u);
            }
            else if(streamk_sel == 2) // 2-tile stream-k
            {
                sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles;

                // Ensure sk_tiles is at least 1 but not more than num_tiles
                sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles);
            }
            else if(streamk_sel == 3) // 3-tile stream-k
            {
                sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size)
                                                         : num_tiles;

                // Ensure sk_tiles is at least 1 but not more than num_tiles
                sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles);
            }
            else if(streamk_sel == 4) // 4-tile stream-k
            {
                sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size)
                                                         : num_tiles;

                // Ensure sk_tiles is at least 1 but not more than num_tiles
                sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles);
            }

            sk_num_blocks = sk_tiles;
            // Remaining tiles are DP tiles
            dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;

            sk_total_iters = k_iters_per_tile.get() * sk_tiles;

            // k_iters_per_sk_block is the floor of the average iters each sk block covers.
            // we need to decide how many iters for each sk block
            // let m = k_iters_per_sk_block
            // some of the sk blocks (little) will cover m iters, some (big) will cover m+1
            // we have
            // 1) l + b = sk_blocks
            // 2) l * m + b * (m + 1) = sk_total_iters
            // => (l + b) * m + b = sk_total_iters
            // => sk_blocks * m + b = sk_total_iters
            // => b = sk_total_iters - m * sk_blocks
            // NOTE: big could be zero

            // Add safety check for sk_num_blocks to prevent division by zero
            if(sk_num_blocks > 0)
            {
                uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
                sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
                k_iters_per_big_block = k_iters_per_sk_block + 1;
            }
            else
            {
                // Fallback to default GEMM if no stream-k blocks
                sk_num_blocks         = 0;
                sk_num_big_blocks     = 0;
                k_iters_per_big_block = 0;
                dp_tiles              = num_tiles;
                dp_num_blocks         = num_tiles;
                dp_start_block_idx    = 0;
                sk_total_iters        = 0;
            }

            dp_num_blocks      = dp_tiles;
            dp_start_block_idx = sk_num_blocks;
        }

        n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
        // Using multiple blocks for parallel reduction
        reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;

        if(reduction_strategy == StreamKReductionStrategy::Reduction)
        {
            // Add additional safety checks
            if(k_iters_per_big_block > 0 && k_iters_per_tile.get() > 0)
            {
                uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
                uint32_t upper_little =
                    math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
                equiv_tiles_big    = MDiv(upper_big / k_iters_per_tile.get());
                equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
            }
            else
            {
                // Default safe values
                equiv_tiles_big    = MDiv(1);
                equiv_tiles_little = MDiv(1);
            }
        }
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }
    __host__ __device__ uint32_t get_sk_total_iters() const
    {
        uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
                                  (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
        return sk_total_iters;
    }

    __host__ __device__ uint32_t get_sk_tiles() const
    {
        // tiles for sk
        uint32_t sk_total_iters = get_sk_total_iters();
        return k_iters_per_tile.div(sk_total_iters);
    }

    __host__ __device__ index_t get_grid_dims() const
    {
        if(reduction_strategy == StreamKReductionStrategy::Reduction)
        {
            // return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
            return reduction_start_block_idx + get_sk_tiles();
        }
        else
            return reduction_start_block_idx;
    }

    __device__ uint32_t get_block_idx() const
    {
        // TODO: swizzle block index for better locality
        return __builtin_amdgcn_readfirstlane(blockIdx.x);
    }

    __device__ void
    get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
    {
        if(block_idx < sk_num_big_blocks)
        {
            iter_start = block_idx * k_iters_per_big_block;
            iter_end   = iter_start + k_iters_per_big_block;
        }
        else if(block_idx < sk_num_blocks)
        {
            iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
                         (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
            iter_end = iter_start + (k_iters_per_big_block - 1);
        }
        else if(block_idx >= dp_start_block_idx)
        {
            uint32_t sk_total_iters     = get_sk_total_iters();
            uint32_t dp_iters_per_block = k_iters_per_tile.get();
            iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
            iter_end   = iter_start + dp_iters_per_block;
        }
    }

    __device__ uint32_t get_current_iter_length(uint32_t iter_start,
                                                uint32_t iter_end,
                                                uint32_t total_iter_length) const
    {
        uint32_t iter_length_mod, iter_length_quo /*unused*/;
        k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
        uint32_t current_iter_length = math::min(
            iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
        return current_iter_length;
    }

    __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }

    __device__ void
    get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
    {
        k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
    }

    __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
    {
        uint32_t m_tile_idx, n_tile_idx;
        uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
        n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);

        // swizzle tile
        uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);

        uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;

        const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
                                     ? tile_swizzle_sub_m
                                     : tile_swizzle_sub_m_rem;

        uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
        m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
        m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;

        uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;

        uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;

        n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
        m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
        return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
                          n_tile_idx_with_adapt);
    }

    __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
    {
        static constexpr uint32_t alignment = 128;
        uint32_t acc_buffer_bytes =
            MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
        return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
    }

    __host__ __device__ uint32_t get_workspace_size_for_semaphore() const
    {
        return get_sk_tiles() * sizeof(uint32_t);
    }

    __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
    {
        return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
    }

    __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
                                                        const MDiv& equiv_tiles_) const
    {
        uint32_t tile_idx_        = tiles_ == 0 ? 0 : (tiles_ - 1);
        uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1;
        uint32_t quo_, rem_;
        equiv_tiles_.divmod(tile_idx_, quo_, rem_);
        return quo_ * max_equiv_tiles_ + rem_;
    }

    __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
                                                          uint32_t iters_per_sk_block_) const
    {
        return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
                                    1);
    }

    __host__ __device__ uint32_t get_total_acc_buffers() const
    {
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        uint32_t tiles_cover_little_blocks =
            get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);

        uint32_t total_intersec_big =
            get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big);
        uint32_t total_intersec_little =
            get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little);

        return sk_num_blocks + total_intersec_big + total_intersec_little;
    }

    __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
    {
        // TODO: from big to little
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        if(tile_idx_ < tiles_cover_big_blocks)
        {
            uint32_t touched_sk_blocks =
                (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
                k_iters_per_big_block;
            uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big);
            return touched_sk_blocks + current_intersec;
        }
        else
        {
            uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
            uint32_t tile_idx_little_reverse   = get_sk_tiles() - tile_idx_;
            uint32_t touched_sk_blocks =
                (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
                iters_per_little_sk_block;
            uint32_t current_intersec =
                get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little);
            return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
        }
    }

    __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
    {
        uint32_t iters_per_big_sk_block    = k_iters_per_big_block;
        uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
        if(block_idx_ < sk_num_big_blocks)
        {
            uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
                                                          k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big);
            return block_idx_ + current_intersec;
        }
        else
        {
            uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
            uint32_t touched_tiles            = k_iters_per_tile.div(
                block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little);
            return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
        }
    }
};
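
// --- Worked example (illustrative, not from the original source) ----------
// For the Reduction strategy the workspace holds one MPerBlock x NPerBlock
// accumulator tile per partial result plus one semaphore per sk tile, with
// the accumulator region rounded up to a 128-byte boundary. Assuming, for
// illustration, 256x128 tiles, 10 acc buffers, 4 sk tiles, and float acc:
//   acc bytes  = 256 * 128 * 10 * sizeof(float) = 1310720 (already aligned)
//   semaphores = 4 * sizeof(uint32_t)           = 16
//   total      = get_workspace_size(4)          = 1310736 bytes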

} // namespace ck
StreamKReductionStrategy reduction_strategy
Definition: block_to_ctile_map.hpp:1438
static constexpr uint32_t MPerBlock
Definition: block_to_ctile_map.hpp:1422
Definition: block_to_ctile_map.hpp:1022
uint32_t k_iters_per_big_block
Definition: block_to_ctile_map.hpp:1036
__host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
Definition: block_to_ctile_map.hpp:1327
__device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
Definition: block_to_ctile_map.hpp:1390
__host__ __device__ uint32_t get_sk_total_iters() const
Definition: block_to_ctile_map.hpp:1213
__host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_, uint32_t iters_per_sk_block_) const
Definition: block_to_ctile_map.hpp:1342
static constexpr uint32_t MPerBlock
Definition: block_to_ctile_map.hpp:1024
uint32_t dp_start_block_idx
Definition: block_to_ctile_map.hpp:1034
__host__ __device__ uint32_t get_sk_tiles() const
Definition: block_to_ctile_map.hpp:1220
static constexpr uint32_t KPerBlock
Definition: block_to_ctile_map.hpp:1026
__host__ __device__ uint32_t get_total_acc_buffers() const
Definition: block_to_ctile_map.hpp:1349
__device__ uint32_t get_current_iter_length(uint32_t iter_start, uint32_t iter_end, uint32_t total_iter_length) const
Definition: block_to_ctile_map.hpp:1266
static constexpr uint32_t NPerBlock
Definition: block_to_ctile_map.hpp:1025
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
Definition: block_to_ctile_map.hpp:1364
uint32_t reduction_start_block_idx
Definition: block_to_ctile_map.hpp:1035
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
Definition: block_to_ctile_map.hpp:1314
MDiv k_iters_per_tile
Definition: block_to_ctile_map.hpp:1038
__device__ void get_tile_idx_with_offset(uint32_t iter, uint32_t &tile_idx, uint32_t &iter_offset) const
Definition: block_to_ctile_map.hpp:1280
static constexpr uint32_t tile_swizzle_sub_m
Definition: block_to_ctile_map.hpp:1028
BlockToCTileMap_GemmStreamK(uint32_t m, uint32_t n, uint32_t k, uint32_t num_cu, uint32_t occupancy, uint32_t sk_blocks=0xffffffff)
Definition: block_to_ctile_map.hpp:1046
static constexpr StreamKReductionStrategy ReductionStrategy
Definition: block_to_ctile_map.hpp:1027
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
Definition: block_to_ctile_map.hpp:1285
__device__ uint32_t get_tile_idx(uint32_t iter) const
Definition: block_to_ctile_map.hpp:1277
__host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_, const MDiv &eqav_tiles_) const
Definition: block_to_ctile_map.hpp:1332
__device__ uint32_t get_block_idx() const
Definition: block_to_ctile_map.hpp:1237
__device__ void get_block_itr(uint32_t block_idx, uint32_t &iter_start, uint32_t &iter_end) const
Definition: block_to_ctile_map.hpp:1244
MDiv eqav_tiles_little
Definition: block_to_ctile_map.hpp:1040
uint32_t sk_num_blocks
Definition: block_to_ctile_map.hpp:1032
MDiv2 n_tiles
Definition: block_to_ctile_map.hpp:1037
MDiv eqav_tiles_big
Definition: block_to_ctile_map.hpp:1039
static constexpr uint32_t min_k_iters_per_sk_block
Definition: block_to_ctile_map.hpp:1023
uint32_t sk_num_big_blocks
Definition: block_to_ctile_map.hpp:1033
__host__ __device__ dim3 get_grid_dims() const
Definition: block_to_ctile_map.hpp:1227
__host__ __device__ uint32_t get_workspace_size_for_semaphore() const
Definition: block_to_ctile_map.hpp:1322
Definition: block_to_ctile_map.hpp:271
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:298
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: block_to_ctile_map.hpp:283
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt()=default
static constexpr auto I1
Definition: block_to_ctile_map.hpp:273
__host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition: block_to_ctile_map.hpp:292
static constexpr auto I0
Definition: block_to_ctile_map.hpp:272
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition: block_to_ctile_map.hpp:384
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M, index_t N, index_t M01=8)
Definition: block_to_ctile_map.hpp:276
Definition: block_to_ctile_map.hpp:720
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:755
static constexpr auto I2
Definition: block_to_ctile_map.hpp:723
static constexpr auto I0
Definition: block_to_ctile_map.hpp:721
__host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=1, index_t N01=1, index_t KSplit=1)
Definition: block_to_ctile_map.hpp:728
__host__ BlockToCTileMap_KSplit_M00_N00_M01_N01()=default
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:773
static constexpr auto I3
Definition: block_to_ctile_map.hpp:724
static constexpr auto I1
Definition: block_to_ctile_map.hpp:722
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: block_to_ctile_map.hpp:764
__host__ constexpr __device__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:741
Definition: block_to_ctile_map.hpp:541
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=8, index_t KSplit=1)
Definition: block_to_ctile_map.hpp:549
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt()=default
static constexpr auto I0
Definition: block_to_ctile_map.hpp:542
static constexpr auto I1
Definition: block_to_ctile_map.hpp:543
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition: block_to_ctile_map.hpp:594
constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:556
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:567
static constexpr auto I2
Definition: block_to_ctile_map.hpp:544
static constexpr auto I3
Definition: block_to_ctile_map.hpp:545
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition: block_to_ctile_map.hpp:600
Definition: block_to_ctile_map.hpp:617
__host__ __device__ BlockToCTileMap_M00_N00_M01_N01()=default
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:661
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:646
__host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=1, index_t N01=1)
Definition: block_to_ctile_map.hpp:625
constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:632
static constexpr auto I0
Definition: block_to_ctile_map.hpp:618
static constexpr auto I3
Definition: block_to_ctile_map.hpp:621
static constexpr auto I1
Definition: block_to_ctile_map.hpp:619
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: block_to_ctile_map.hpp:652
static constexpr auto I2
Definition: block_to_ctile_map.hpp:620
__host__ constexpr __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition: block_to_ctile_map.hpp:246
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: block_to_ctile_map.hpp:157
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt()=default
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt &&)=default
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt & operator=(BlockToCTileMap_M00_N0_M01Adapt &&)=default
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt &)=default
__host__ constexpr __device__ bool CheckValidity(const CGridDesc_M_N &) const
Definition: block_to_ctile_map.hpp:173
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt & operator=(const BlockToCTileMap_M00_N0_M01Adapt &)=default
static constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: block_to_ctile_map.hpp:166
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=8)
Definition: block_to_ctile_map.hpp:150
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01=8)
Definition: block_to_ctile_map.hpp:138
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:179
Definition: block_to_ctile_map.hpp:261
Definition: block_to_ctile_map.hpp:24
constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:38
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:66
static constexpr auto I3
Definition: block_to_ctile_map.hpp:28
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=1)
Definition: block_to_ctile_map.hpp:32
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:51
static constexpr auto I2
Definition: block_to_ctile_map.hpp:27
static constexpr auto I0
Definition: block_to_ctile_map.hpp:25
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01()=default
static constexpr auto I1
Definition: block_to_ctile_map.hpp:26
__host__ constexpr __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: block_to_ctile_map.hpp:57
__host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition: block_to_ctile_map.hpp:451
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N &c_grid_desc_m_n, index_t N01=8)
Definition: block_to_ctile_map.hpp:429
static constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: block_to_ctile_map.hpp:445
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt & operator=(const BlockToCTileMap_N00_M0_N01Adapt &)=default
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition: block_to_ctile_map.hpp:525
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt & operator=(BlockToCTileMap_N00_M0_N01Adapt &&)=default
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:457
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt &)=default
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt &&)=default
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01=8)
Definition: block_to_ctile_map.hpp:418
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: block_to_ctile_map.hpp:436
Definition: block_to_ctile_map.hpp:399
Definition: magic_division.hpp:204
__host__ __device__ void divmod(uint32_t dividend_, uint32_t divisor_, uint32_t &quotient_, uint32_t &remainder_) const
Definition: magic_division.hpp:226
Definition: magic_division.hpp:162
__host__ __device__ uint32_t get() const
Definition: magic_division.hpp:200
__host__ __device__ void divmod(uint32_t dividend_, uint32_t &quotient_, uint32_t &remainder_) const
Definition: magic_division.hpp:194
__host__ __device__ uint32_t div(uint32_t dividend_) const
Definition: magic_division.hpp:188
__host__ static constexpr __device__ T Max()
Definition: numeric_limits.hpp:311
Definition: block_to_ctile_map.hpp:920
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: block_to_ctile_map.hpp:940
index_t tile_offset_
Definition: block_to_ctile_map.hpp:960
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:947
UnderlyingBlockToCTileMap block_to_ctile_map_
Definition: block_to_ctile_map.hpp:958
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:933
__host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map, index_t group_offset, index_t tile_offset)
Definition: block_to_ctile_map.hpp:923
UnderlyingBlockToCTileMap underlying_type
Definition: block_to_ctile_map.hpp:921
index_t group_offset_
Definition: block_to_ctile_map.hpp:959
__device__ void UpdateTileOffset(index_t offset)
Definition: block_to_ctile_map.hpp:957
__host__ constexpr __device__ index_t CalculateGridSize(index_t M, index_t N) const
Definition: block_to_ctile_map.hpp:952
Definition: block_to_ctile_map.hpp:872
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: block_to_ctile_map.hpp:891
__host__ constexpr __device__ index_t CalculateGridSize(index_t M, index_t N) const
Definition: block_to_ctile_map.hpp:909
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:898
constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:904
index_t block_start_
Definition: block_to_ctile_map.hpp:915
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:884
__host__ __device__ OffsettedBlockToCTileMap()=default
__host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
Definition: block_to_ctile_map.hpp:876
UnderlyingBlockToCTileMap underlying_type
Definition: block_to_ctile_map.hpp:873
UnderlyingBlockToCTileMap block_to_ctile_map_
Definition: block_to_ctile_map.hpp:914
Definition: sequence.hpp:43
Definition: integral_constant.hpp:20