include/ck_tile/core/algorithm/static_encoding_pattern.hpp Source File

include/ck_tile/core/algorithm/static_encoding_pattern.hpp Source File#

Composable Kernel: include/ck_tile/core/algorithm/static_encoding_pattern.hpp Source File
static_encoding_pattern.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
13 
14 namespace ck_tile {
15 
21 {
31  warp_raked,
37 };
38 
40 {
41 };
42 
55 template <index_t BlockSize,
56  index_t YPerTile,
57  index_t XPerTile,
58  index_t VecSize,
59  tile_distribution_pattern DistributionPattern>
61 {
62 };
63 
64 // Thread raked
65 template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
67  YPerTile,
68  XPerTile,
69  VecSize,
72 {
73 
74  // TODO: lift the XPerTile % VecSize == 0 requirement below (see GGemmMultiDSplitk).
75  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
76  static constexpr index_t warp_size = get_warp_size(); // # of threads in one wavefront
77  static constexpr index_t num_warps = BlockSize / get_warp_size(); // # of wavefronts in workgroup
78  static constexpr index_t X1 = VecSize; // # of X elements per thread
79  static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
80 
81  // # of rows in Y dim accessed by a single wavefront in one iteration
82  static constexpr index_t Y1 = warp_size / X0;
83  static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
84 
85  static constexpr index_t Y0 = num_warps; // one Y0 slot per wavefront
86  // Y2 below is the two-step computation folded into one expression:
87  // YPerWarp = YPerTile / Y0; Y2 = YPerWarp / Y1;
88  static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
89 
90  static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * Y1 * Y0 must cover whole workgroup!");
91  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
92 
94  {
101  sequence<2, 1>>{});
102  }
103 
105  {
112  sequence<1, 2>>{});
113  }
114 };
115 
116 // Warp raked
117 template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
119  YPerTile,
120  XPerTile,
121  VecSize,
124 {
125 
126  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
127  static constexpr index_t warp_size = get_warp_size(); // # of threads in one wavefront
128  static constexpr index_t num_warps = BlockSize / get_warp_size(); // # of wavefronts in workgroup
129  static constexpr index_t X1 = VecSize; // # of X elements per thread
130  static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
131 
132  static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
133  static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
134 
135  static constexpr index_t Y0 = num_warps; // one Y0 slot per wavefront
136  static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y0 must cover whole workgroup!");
137 
138  static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
139  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
140 
142  {
149  sequence<1, 1>>{});
150  }
151 
153  {
160  sequence<1, 1>>{});
161  }
162 };
163 
164 // Block raked
165 template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
167  YPerTile,
168  XPerTile,
169  VecSize,
172 {
173 
174  // TODO: lift the XPerTile % VecSize == 0 requirement below (see GGemmMultiDSplitk).
175  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
176  static constexpr index_t warp_size = get_warp_size(); // # of threads in one wavefront
177  static constexpr index_t num_warps = BlockSize / get_warp_size(); // # of wavefronts in workgroup
178  static constexpr index_t X1 = VecSize; // # of X elements per thread
179  static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
180  static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
181  static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
182  static constexpr index_t Y1 = num_warps; // one Y1 slot per wavefront
183  static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
184  static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters needed to cover YPerTile
185  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
186 
188  {
195  sequence<0, 1>>{});
196  }
197 
199  {
206  sequence<1, 0>>{});
207  }
208 };
209 
210 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE index_t get_warp_size()
Definition: arch.hpp:51
int32_t index_t
Definition: integer.hpp:9
tile_distribution_pattern
Enumeration describing static tile distribution patterns.
Definition: static_encoding_pattern.hpp:21
@ block_raked
Block raked pattern - aka linear.
@ thread_raked
Thread raked pattern.
@ warp_raked
Warp raked pattern.
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
static constexpr CK_TILE_HOST_DEVICE auto MakeShuffled2DStaticTileDistribution()
Definition: static_encoding_pattern.hpp:198
static constexpr CK_TILE_HOST_DEVICE auto Make2DStaticTileDistribution()
Definition: static_encoding_pattern.hpp:187
static constexpr CK_TILE_HOST_DEVICE auto MakeShuffled2DStaticTileDistribution()
Definition: static_encoding_pattern.hpp:152
static constexpr CK_TILE_HOST_DEVICE auto Make2DStaticTileDistribution()
Definition: static_encoding_pattern.hpp:141
static constexpr CK_TILE_HOST_DEVICE auto MakeShuffled2DStaticTileDistribution()
Definition: static_encoding_pattern.hpp:104
static constexpr CK_TILE_HOST_DEVICE auto Make2DStaticTileDistribution()
Definition: static_encoding_pattern.hpp:93
Class creating 2D static tile distribution with different load/store patterns.
Definition: static_encoding_pattern.hpp:61
Definition: static_encoding_pattern.hpp:40
Definition: sequence.hpp:52
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192