Composable Kernel: include/ck_tile/core/algorithm/static_encoding_pattern.hpp Source File
static_encoding_pattern.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"

namespace ck_tile {

/// Enumeration describing static tile distribution patterns.
enum class tile_distribution_pattern
{
    /// Thread raked pattern.
    thread_raked,
    /// Warp raked pattern.
    warp_raked,
    /// Block raked pattern - aka linear.
    block_raked,
};

{
};

/// Class creating 2D static tile distribution with different load/store patterns.
template <index_t BlockSize,
          index_t YPerTile,
          index_t XPerTile,
          index_t VecSize,
          tile_distribution_pattern DistributionPattern,
          index_t NumWaveGroups = 1>
struct tile_distribution_encoding_pattern_2d
{
};
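
The primary template above is intentionally empty; all of the behaviour lives in the per-pattern partial specializations that follow. As a rough usage sketch (the 256-thread workgroup and the 32 x 64 tile below are hypothetical values chosen only for illustration):

    #include "ck_tile/core/algorithm/static_encoding_pattern.hpp"

    // Pick a distribution pattern for loading a 32 x 64 tile with 8-wide vector accesses.
    using LoadPattern =
        ck_tile::tile_distribution_encoding_pattern_2d<256, // BlockSize
                                                       32,  // YPerTile
                                                       64,  // XPerTile
                                                       8,   // VecSize
                                                       ck_tile::tile_distribution_pattern::thread_raked>;

    // Inside a kernel, the static distribution is then obtained from the chosen specialization:
    // constexpr auto dist = LoadPattern::make_2d_static_tile_distribution();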

// Thread raked
template <index_t BlockSize,
          index_t YPerTile,
          index_t XPerTile,
          index_t VecSize,
          index_t NumWaveGroups>
struct tile_distribution_encoding_pattern_2d<BlockSize,
                                             YPerTile,
                                             XPerTile,
                                             VecSize,
                                             tile_distribution_pattern::thread_raked,
                                             NumWaveGroups>
{

    // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
    static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
    static constexpr index_t warp_size  = get_warp_size();
    static constexpr index_t num_warps  = BlockSize / get_warp_size();
    static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
    static constexpr index_t X1         = VecSize > LargestVec ? LargestVec : VecSize;
    static constexpr index_t X0         = XPerTile / X1; // # of threads in X dim

    // # of rows in Y dim accessed by single wavefront in one iteration
    static constexpr index_t Y1 = warp_size / X0;
    static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");

    static constexpr index_t Y0 = num_warps / NumWaveGroups;
    // YPerWarp = YPerTile / Y0;
    // Y2 = YPerWarp / Y1;
    static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront

    static_assert(X0 * Y1 * Y0 * NumWaveGroups == BlockSize,
                  "X0 * Y1 * Y0 * NumWaveGroups must cover whole workgroup!");
    static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");

    static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
    {
        if constexpr(NumWaveGroups != 1)
        {
            return make_static_tile_distribution(
                tile_distribution_encoding</* ... */
                                           tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
                                           /* ... */
                                           sequence<1, 1>>{}); // -> <Y2, X1>
        }
        else
        {
            return make_static_tile_distribution(
                tile_distribution_encoding</* ... */
                                           tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
                                           /* ... */
                                           sequence<2, 1>>{}); // -> <Y2, X1>
        }
    }

    static constexpr CK_TILE_HOST_DEVICE auto make_shuffled_2d_static_tile_distribution()
    {
        if constexpr(NumWaveGroups != 1)
        {
            return make_static_tile_distribution(
                tile_distribution_encoding</* ... */
                                           tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
                                           /* ... */
                                           sequence<1, 1>>{}); // -> <X1, Y2>
        }
        else
        {
            return make_static_tile_distribution(
                tile_distribution_encoding</* ... */
                                           tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
                                           /* ... */
                                           sequence<1, 2>>{}); // -> <X1, Y2>
        }
    }
};
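
In this thread-raked specialization the per-thread iteration dimension Y2 is the innermost Y sub-dimension, so each thread walks over Y2 consecutive rows of its warp's slice. A minimal compile-time sketch of the resulting decomposition, assuming get_warp_size() == 64 and a hypothetical 256-thread workgroup with a 64 x 64 tile and 8-wide vectors:

    using ThreadRaked = ck_tile::tile_distribution_encoding_pattern_2d<
        256, 64, 64, 8, ck_tile::tile_distribution_pattern::thread_raked>;

    static_assert(ThreadRaked::X1 == 8, "each thread loads one 8-wide vector");
    static_assert(ThreadRaked::X0 == 8, "8 lanes cover XPerTile = 64");
    static_assert(ThreadRaked::Y1 == 8, "one wavefront covers 8 rows per step");
    static_assert(ThreadRaked::Y0 == 4, "the 4 warps split the Y dimension");
    static_assert(ThreadRaked::Y2 == 2, "each thread owns 2 consecutive rows (the raked dimension)");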

// Warp raked
template <index_t BlockSize,
          index_t YPerTile,
          index_t XPerTile,
          index_t VecSize,
          index_t NumWaveGroups>
struct tile_distribution_encoding_pattern_2d<BlockSize,
                                             YPerTile,
                                             XPerTile,
                                             VecSize,
                                             tile_distribution_pattern::warp_raked,
                                             NumWaveGroups>
{

    static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
    static constexpr index_t warp_size  = get_warp_size();
    static constexpr index_t num_warps  = BlockSize / get_warp_size();
    static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
    static constexpr index_t X1         = VecSize > LargestVec ? LargestVec : VecSize;
    static constexpr index_t X0         = XPerTile / X1; // # of threads in X dim

    static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
    static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");

    static constexpr index_t Y0 = num_warps;
    static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y0 must cover whole workgroup!");

    static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
    static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");

    static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding</* ... */
                                       tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
                                       /* ... */
                                       sequence<1, 1>>{}); // -> <Y1, X1>
    }

    static constexpr CK_TILE_HOST_DEVICE auto make_shuffled_2d_static_tile_distribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding</* ... */
                                       tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
                                       /* ... */
                                       sequence<1, 1>>{}); // -> <X1, Y1>
    }
};
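
Here the iteration dimension is Y1, which sits between the warp index (Y0) and the intra-warp row index (Y2): each warp owns a contiguous band of YPerTile / num_warps rows and sweeps through it one warp-wide row group at a time. Continuing the same hypothetical configuration (64-wide wavefront, 256 threads, 64 x 64 tile, 8-wide vectors):

    using WarpRaked = ck_tile::tile_distribution_encoding_pattern_2d<
        256, 64, 64, 8, ck_tile::tile_distribution_pattern::warp_raked>;

    static_assert(WarpRaked::X0 == 8 && WarpRaked::X1 == 8, "same X split as the thread-raked case");
    static_assert(WarpRaked::Y0 == 4, "4 warps, each owning a contiguous 16-row band");
    static_assert(WarpRaked::Y2 == 8, "one wavefront covers 8 rows per iteration");
    static_assert(WarpRaked::Y1 == 2, "each warp iterates twice to sweep its band");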

// Block raked
template <index_t BlockSize,
          index_t YPerTile,
          index_t XPerTile,
          index_t VecSize,
          index_t NumWaveGroups>
struct tile_distribution_encoding_pattern_2d<BlockSize,
                                             YPerTile,
                                             XPerTile,
                                             VecSize,
                                             tile_distribution_pattern::block_raked,
                                             NumWaveGroups>
{

    // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
    static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
    static constexpr index_t warp_size  = get_warp_size();
    static constexpr index_t num_warps  = BlockSize / get_warp_size();
    static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
    static constexpr index_t X1         = VecSize > LargestVec ? LargestVec : VecSize;
    static constexpr index_t X0         = XPerTile / X1; // # of threads in X dim
    static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
    static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
    static constexpr index_t Y1 = num_warps;
    static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
    static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
    static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");

    static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding</* ... */
                                       tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
                                       /* ... */
                                       sequence<0, 1>>{}); // -> <Y0, X1>
    }

    static constexpr CK_TILE_HOST_DEVICE auto make_shuffled_2d_static_tile_distribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding</* ... */
                                       tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
                                       /* ... */
                                       sequence<1, 0>>{}); // -> <X1, Y0>
    }
};
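
Block raked (aka linear) moves the iteration dimension to the outermost position Y0: in every step the whole workgroup covers one contiguous slab of Y1 * Y2 rows, then strides on to the next slab. With the same hypothetical configuration as in the previous sketches:

    using BlockRaked = ck_tile::tile_distribution_encoding_pattern_2d<
        256, 64, 64, 8, ck_tile::tile_distribution_pattern::block_raked>;

    static_assert(BlockRaked::Y1 == 4 && BlockRaked::Y2 == 8,
                  "one step: 4 warps x 8 rows = a 32-row slab covered by the whole block");
    static_assert(BlockRaked::Y0 == 2, "the block strides through 2 such slabs to cover YPerTile = 64");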

// Helper function to convert enum to string
constexpr const char* tile_distribution_pattern_to_string(tile_distribution_pattern pattern)
{
    switch(pattern)
    {
    case tile_distribution_pattern::thread_raked: return "thread_raked";
    case tile_distribution_pattern::warp_raked: return "warp_raked";
    case tile_distribution_pattern::block_raked: return "block_raked";
    default: return "unknown";
    }
}
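
Because the helper is constexpr it can be used both at compile time and for host/device-side logging, for example:

    constexpr const char* name =
        ck_tile::tile_distribution_pattern_to_string(ck_tile::tile_distribution_pattern::warp_raked);
    // name now points at the literal "warp_raked".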

template <index_t BlockSize,
          index_t YPerTile,
          index_t XPerTile,
          index_t VecSize,
          tile_distribution_pattern DistributionPattern,
          index_t NumWaveGroups>
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d<BlockSize,
                                                                           YPerTile,
                                                                           XPerTile,
                                                                           VecSize,
                                                                           DistributionPattern,
                                                                           NumWaveGroups>&)
{
    using PatternType = tile_distribution_encoding_pattern_2d<BlockSize,
                                                              YPerTile,
                                                              XPerTile,
                                                              VecSize,
                                                              DistributionPattern,
                                                              NumWaveGroups>;

    printf("tile_distribution_encoding_pattern_2d<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
           "VecSize:%d, %s>: ",
           BlockSize,
           YPerTile,
           XPerTile,
           VecSize,
           tile_distribution_pattern_to_string(DistributionPattern));
    printf("{<Y0, Y1, Y2>: <%d, %d, %d>, <X0, X1>: <%d, %d>}\n",
           PatternType::Y0,
           PatternType::Y1,
           PatternType::Y2,
           PatternType::X0,
           PatternType::X1);
}
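
A small usage sketch, not part of this header: calling print() on the hypothetical thread-raked pattern from the earlier examples (64-wide wavefront assumed) produces output in the format of the two printf calls above. The debug_dump_pattern wrapper below exists only for the sketch.

    inline void debug_dump_pattern()
    {
        // Dump the computed decomposition of a pattern while debugging a configuration.
        ck_tile::print(ck_tile::tile_distribution_encoding_pattern_2d<
                       256, 64, 64, 8, ck_tile::tile_distribution_pattern::thread_raked>{});
        // Prints (one line):
        // tile_distribution_encoding_pattern_2d<BlockSize:256, YPerTile:64, XPerTile:64, VecSize:8, thread_raked>: {<Y0, Y1, Y2>: <4, 8, 2>, <X0, X1>: <8, 8>}
    }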

} // namespace ck_tile