/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/algorithm/static_encoding_pattern.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/algorithm/static_encoding_pattern.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/algorithm/static_encoding_pattern.hpp Source File
static_encoding_pattern.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
71 #pragma once
72 
74 #include "ck_tile/core/config.hpp"
81 
82 namespace ck_tile {
83 
// Body of the enumeration of static tile distribution patterns. The enum's
// declaration line and two of its enumerators are absent from this listing
// extract; the cross-reference notes at the end of the page name the
// enumerators as thread_raked, warp_raked and block_raked ("aka linear").
89 {
99  warp_raked, // Warp raked pattern.
105 };
106 
// Empty type body (no members). The declaration line that names this type
// (listing line 107) is not present in this extract.
// NOTE(review): presumably the common base/tag type for the encoding patterns
// defined below - confirm against the original header.
108 {
109 };
110 
// Primary template of the 2D encoding pattern; the page's cross-reference
// notes describe it as "Class creating 2D static tile distribution with
// different load/store patterns". Its body is empty here: the constants and
// factory functions appear only in the per-pattern blocks further down.
// The `struct` line itself (listing line 129) is absent from this extract.
//
// Template parameters (as declared below):
//   BlockSize           - index_t block size; the blocks below divide it by get_warp_size().
//   YPerTile, XPerTile  - index_t tile extents along Y and X.
//   VecSize             - index_t requested vector width along X.
//   DistributionPattern - which tile_distribution_pattern to use.
//   NumWaveGroups       - index_t, defaults to 1.
123 template <index_t BlockSize,
124  index_t YPerTile,
125  index_t XPerTile,
126  index_t VecSize,
127  tile_distribution_pattern DistributionPattern,
128  index_t NumWaveGroups = 1>
130 {
131 };
132 
133 // Thread raked pattern.
// The `struct` header lines of this block (listing lines 139, 143, 145) and
// the leading arguments of each encoding are absent from this extract, so
// only the visible constants, asserts and trailing encoding arguments are
// documented here.
//
// Derived constants:
//   X1 - vector width actually used along X (VecSize clamped to LargestVec).
//   X0 - XPerTile / X1.
//   Y1 - warp_size / X0.
//   Y0 - num_warps / NumWaveGroups.
//   Y2 - YPerTile / (Y1 * Y0).
134 template <index_t BlockSize,
135  index_t YPerTile,
136  index_t XPerTile,
137  index_t VecSize,
138  index_t NumWaveGroups>
140  YPerTile,
141  XPerTile,
142  VecSize,
144  NumWaveGroups>
146 {
147  // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
148  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
149  static constexpr index_t warp_size = get_warp_size();
150  static constexpr index_t num_warps = BlockSize / get_warp_size();
 // Tile element count divided by (num_warps * warp_size); used as the upper bound for X1.
151  static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
 // min(VecSize, LargestVec)
152  static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
153  static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
154 
155  // # of rows in Y dim accessed by single wavefront in one iteration
156  static constexpr index_t Y1 = warp_size / X0;
157  static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
158 
 // num_warps split evenly across NumWaveGroups
159  static constexpr index_t Y0 = num_warps / NumWaveGroups;
160  // YPerWarp = YPerTile / Y0;
161  // Y2 = YPerWarp / Y1;
162  static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
163 
 // The message spells out the exact product that is being checked.
164  static_assert(X0 * Y1 * Y0 * NumWaveGroups == BlockSize,
165  "X0 * Y1 * Y0 * NumWaveGroups must cover whole workgroup!");
166  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
167 
 // Listing line 168 (signature absent from this extract): per the page's
 // cross-reference notes this is make_2d_static_tile_distribution().
 // A different encoding is selected at compile time when NumWaveGroups != 1.
169  {
170  if constexpr(NumWaveGroups != 1)
171  {
176  tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
178  sequence<1, 1>>{}); // -> <Y2, X1>
179  }
180  else
181  {
186  tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
188  sequence<2, 1>>{}); // -> <Y2, X1>
189  }
190  }
191 
 // Listing line 192 (signature absent from this extract): per the page's
 // cross-reference notes this is make_shuffled_2d_static_tile_distribution().
 // Same NumWaveGroups split as above; the trailing comments show <X1, Y2> order.
193  {
194  if constexpr(NumWaveGroups != 1)
195  {
200  tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
202  sequence<1, 1>>{}); // -> <X1, Y2>
203  }
204  else
205  {
210  tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
212  sequence<1, 2>>{}); // -> <X1, Y2>
213  }
214  }
215 };
216 
217 // Warp raked pattern.
// The `struct` header lines of this block (listing lines 223, 227, 229) and
// the leading arguments of each encoding are absent from this extract, so
// only the visible constants, asserts and trailing encoding arguments are
// documented here.
//
// Derived constants:
//   X1 - vector width actually used along X (VecSize clamped to LargestVec).
//   X0 - XPerTile / X1.
//   Y2 - warp_size / X0.
//   Y0 - num_warps.
//   Y1 - YPerTile / (Y2 * Y0).
// NOTE(review): NumWaveGroups is not referenced by any line visible in this
// block - confirm against the original header.
218 template <index_t BlockSize,
219  index_t YPerTile,
220  index_t XPerTile,
221  index_t VecSize,
222  index_t NumWaveGroups>
224  YPerTile,
225  XPerTile,
226  VecSize,
228  NumWaveGroups>
230 {
231
232  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
233  static constexpr index_t warp_size = get_warp_size();
234  static constexpr index_t num_warps = BlockSize / get_warp_size();
 // Tile element count divided by (num_warps * warp_size); used as the upper bound for X1.
235  static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
 // min(VecSize, LargestVec)
236  static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
237  static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
238 
239  static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
240  static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
241 
242  static constexpr index_t Y0 = num_warps;
 // The message names the same factors as the condition (Y0, not Y1, holds num_warps here).
243  static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y0 must cover whole workgroup!");
244 
245  static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
246  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
247 
 // Listing line 248 (signature absent from this extract): per the page's
 // cross-reference notes this is make_2d_static_tile_distribution().
249  {
254  tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
256  sequence<1, 1>>{}); // -> <Y1, X1>
257  }
258 
 // Listing line 259 (signature absent from this extract): per the page's
 // cross-reference notes this is make_shuffled_2d_static_tile_distribution().
260  {
265  tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
267  sequence<1, 1>>{}); // -> <X1, Y1>
268  }
269 };
270 
271 // Block raked pattern (the page's cross-reference notes call it "aka linear").
// The `struct` header lines of this block (listing lines 277, 281, 283) and
// the leading arguments of each encoding are absent from this extract, so
// only the visible constants, asserts and trailing encoding arguments are
// documented here.
//
// Derived constants:
//   X1 - vector width actually used along X (VecSize clamped to LargestVec).
//   X0 - XPerTile / X1.
//   Y2 - warp_size / X0.
//   Y1 - num_warps.
//   Y0 - YPerTile / (Y2 * Y1).
// NOTE(review): NumWaveGroups is not referenced by any line visible in this
// block - confirm against the original header.
272 template <index_t BlockSize,
273  index_t YPerTile,
274  index_t XPerTile,
275  index_t VecSize,
276  index_t NumWaveGroups>
278  YPerTile,
279  XPerTile,
280  VecSize,
282  NumWaveGroups>
284 {
285 
286  // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
287  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
288  static constexpr index_t warp_size = get_warp_size();
289  static constexpr index_t num_warps = BlockSize / get_warp_size();
 // Tile element count divided by (num_warps * warp_size); used as the upper bound for X1.
290  static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
 // min(VecSize, LargestVec)
291  static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
292  static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
293  static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
294  static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
295  static constexpr index_t Y1 = num_warps;
296  static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
297  static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
298  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
299 
 // Listing line 300 (signature absent from this extract): per the page's
 // cross-reference notes this is make_2d_static_tile_distribution().
301  {
306  tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
308  sequence<0, 1>>{}); // -> <Y0, X1>
309  }
310 
 // Listing line 311 (signature absent from this extract): per the page's
 // cross-reference notes this is make_shuffled_2d_static_tile_distribution().
312  {
317  tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
319  sequence<1, 0>>{}); // -> <X1, Y0>
320  }
321 };
322 
323 // Helper function to convert a tile_distribution_pattern value to its name.
// The signature (listing line 324) is absent from this extract; the page's
// cross-reference notes give it as
//   constexpr const char* tile_distribution_pattern_to_string(tile_distribution_pattern pattern)
// Returns a string literal matching the enumerator name, or "unknown" for any
// value not handled by the switch.
325 {
326  switch(pattern)
327  {
328  case tile_distribution_pattern::thread_raked: return "thread_raked";
329  case tile_distribution_pattern::warp_raked: return "warp_raked";
330  case tile_distribution_pattern::block_raked: return "block_raked";
331  default: return "unknown"; // fallback for values outside the three cases above
332  }
333 }
334 
// Debug printer for a 2D encoding pattern type. It writes, via printf, the
// template parameters BlockSize, YPerTile, XPerTile, VecSize and the pattern
// name, followed by the derived constants Y0, Y1, Y2, X0 and X1 of the
// matching pattern type. NumWaveGroups selects the pattern type but is not
// part of the printed text. The argument is unnamed and unused; only its
// type matters.
335 template <index_t BlockSize,
336  index_t YPerTile,
337  index_t XPerTile,
338  index_t VecSize,
339  tile_distribution_pattern DistributionPattern,
340  index_t NumWaveGroups>
// Listing line 341 (absent from this extract) opens the signature; the page's
// cross-reference notes give it as
//   CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d<BlockSize, ...
342  YPerTile,
343  XPerTile,
344  VecSize,
345  DistributionPattern,
346  NumWaveGroups>&)
347 {
 // Alias for the exact pattern type whose static constants are printed below.
348  using PatternType = tile_distribution_encoding_pattern_2d<BlockSize,
349  YPerTile,
350  XPerTile,
351  VecSize,
352  DistributionPattern,
353  NumWaveGroups>;
354 
 // First line part: the template arguments and the pattern name.
355  printf("tile_distribution_encoding_pattern_2d<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
356  "VecSize:%d, %s>: ",
357  BlockSize,
358  YPerTile,
359  XPerTile,
360  VecSize,
361  tile_distribution_pattern_to_string(DistributionPattern));
 // Second part: the derived decomposition, terminated by a newline.
362  printf("{<Y0, Y1, Y2>: <%d, %d, %d>, <X0, X1>: <%d, %d>}\n",
363  PatternType::Y0,
364  PatternType::Y1,
365  PatternType::Y2,
366  PatternType::X0,
367  PatternType::X1);
368 }
369 
370 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr const char * tile_distribution_pattern_to_string(tile_distribution_pattern pattern)
Definition: static_encoding_pattern.hpp:324
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d< BlockSize, YPerTile, XPerTile, VecSize, DistributionPattern, NumWaveGroups > &)
Definition: static_encoding_pattern.hpp:341
tile_distribution_pattern
Enumeration describing static tile distribution patterns.
Definition: static_encoding_pattern.hpp:89
@ block_raked
Block raked pattern - aka linear.
@ thread_raked
Thread raked pattern.
@ warp_raked
Warp raked pattern.
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: sequence.hpp:49
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
Definition: static_encoding_pattern.hpp:168
static constexpr CK_TILE_HOST_DEVICE auto make_shuffled_2d_static_tile_distribution()
Definition: static_encoding_pattern.hpp:192
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
Definition: static_encoding_pattern.hpp:248
static constexpr CK_TILE_HOST_DEVICE auto make_shuffled_2d_static_tile_distribution()
Definition: static_encoding_pattern.hpp:259
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
Definition: static_encoding_pattern.hpp:300
static constexpr CK_TILE_HOST_DEVICE auto make_shuffled_2d_static_tile_distribution()
Definition: static_encoding_pattern.hpp:311
Class creating 2D static tile distribution with different load/store patterns.
Definition: static_encoding_pattern.hpp:130
Definition: static_encoding_pattern.hpp:108
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192