include/ck_tile/core/algorithm/static_encoding_pattern.hpp File Reference
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/utility/print.hpp"
Namespaces
  ck_tile

Enumerations
  enum class ck_tile::tile_distribution_pattern { thread_raked, warp_raked, block_raked }
    Enumeration describing static tile distribution patterns.

Functions
  constexpr const char* ck_tile::tile_distribution_pattern_to_string(tile_distribution_pattern pattern)
  template<index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize, tile_distribution_pattern DistributionPattern, index_t NumWaveGroups>
  CK_TILE_HOST_DEVICE void ck_tile::print(const tile_distribution_encoding_pattern_2d<BlockSize, YPerTile, XPerTile, VecSize, DistributionPattern, NumWaveGroups>&)
Detailed Description
We're defining the data access pattern for a 2D window (XPerTile by YPerTile) for BlockSize threads in a thread block. The X dimension is considered contiguous in memory, so a single instruction can access several adjacent, properly aligned elements (a vector); the access pattern along the X tile dimension is therefore parameterized only by the suggested vector size VecSize. A single memory access cannot cover more than MaxVecSize = TileElementsPerThread = TileSize / BlockSize elements, so the actual vector size along the X dimension is X0 = min(MaxVecSize, VecSize). This leaves X1 = XPerTile / X0 threads per tile in the X dimension. X1 is also the number of threads per warp in the X dimension; that is, the X dimension is not split between warps, each warp accesses the X dimension entirely, and there is no iteration in the X dimension. The tuple <X0, X1> defines the X-axis access pattern. This part is common to all 2D distribution patterns; a worked example follows below.
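As a concrete illustration of the X-axis arithmetic above, here is a minimal standalone sketch in plain C++ (not the library itself); the tile and block sizes are made-up example values, not taken from this header:

```cpp
#include <algorithm>

// Illustrative (made-up) sizes: a 64 x 64 tile accessed by 256 threads.
constexpr int BlockSize = 256;
constexpr int XPerTile  = 64;
constexpr int YPerTile  = 64;
constexpr int VecSize   = 8; // suggested vector size

constexpr int TileSize   = XPerTile * YPerTile;   // 4096 elements
constexpr int MaxVecSize = TileSize / BlockSize;  // 16 elements per thread
constexpr int X0 = std::min(MaxVecSize, VecSize); // actual vector size: 8
constexpr int X1 = XPerTile / X0;                 // threads along X: 8

static_assert(X0 * X1 == XPerTile, "<X0, X1> must cover the X dimension");
```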
What differs between the 2D distribution patterns is the Y-axis access pattern. It has three components: (1) the number of Y-axis elements (rows) per warp for a single instruction access, (2) the number of warps per thread block, and (3) the number of iterations needed to cover the entire Y axis.
'Raked' here describes how data is partitioned across the different levels of processing granularity, i.e., whether a thread, a warp, or the whole block accesses a contiguous region of data. The qualifier for 'raked' names the level of the warp/thread hierarchy at which the iteration over the Y tile dimension happens; the iteration can be logically inserted as a tile dimension in 3 ways: (1) after thread -> thread-raked, (2) between warp and thread -> warp-raked, (3) before warp -> block-raked.
Thread raked
Y0 is the number of warps, obtained from the equation Y0 * WarpSize == BlockSize.
Y1 is the number of rows accessed by a warp within a single iteration, computed from the equation Y1 * X1 == WarpSize.
Y2 is the number of iterations needed to cover the tile, computed from the equation Y0 * Y1 * Y2 == YPerTile. (See the sketch below.)
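Continuing the same illustrative numbers, a plain-C++ sketch of the thread-raked Y split; WarpSize = 64 (the AMD wavefront size) is an assumption about the target:

```cpp
// Same illustrative numbers as the X-axis sketch above.
constexpr int BlockSize = 256, YPerTile = 64, X1 = 8, WarpSize = 64;

constexpr int Y0 = BlockSize / WarpSize; // number of warps: 4 (Y0 * WarpSize == BlockSize)
constexpr int Y1 = WarpSize / X1;        // rows per warp: 8   (Y1 * X1 == WarpSize)
constexpr int Y2 = YPerTile / (Y0 * Y1); // iterations: 2      (Y0 * Y1 * Y2 == YPerTile)

static_assert(Y0 * Y1 * Y2 == YPerTile, "<Y0, Y1, Y2> must cover the Y dimension");
```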
Warp raked
Y0 is the number of warps, obtained the same way as for the thread-raked pattern: Y0 * WarpSize == BlockSize.
Y1 is the number of iterations needed to cover the tile: Y0 * Y1 * Y2 == YPerTile; compute Y2 first from the equation below.
Y2 is the number of rows accessed by a warp in a single iteration: Y2 * X1 == WarpSize. (See the sketch below.)
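The same illustrative numbers for the warp-raked split; only the roles of Y1 and Y2 change relative to thread-raked:

```cpp
// Same illustrative numbers as the sketches above.
constexpr int BlockSize = 256, YPerTile = 64, X1 = 8, WarpSize = 64;

constexpr int Y0 = BlockSize / WarpSize; // number of warps: 4 (Y0 * WarpSize == BlockSize)
constexpr int Y2 = WarpSize / X1;        // rows per warp: 8   (Y2 * X1 == WarpSize)
constexpr int Y1 = YPerTile / (Y0 * Y2); // iterations: 2      (Y0 * Y1 * Y2 == YPerTile)

static_assert(Y0 * Y1 * Y2 == YPerTile, "<Y0, Y1, Y2> must cover the Y dimension");
```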
Block raked
Y0 is the number of iterations needed to cover the tile: Y0 * Y1 * Y2 == YPerTile; compute Y1 and Y2 first from the equations below.
Y1 is the number of warps: Y1 * WarpSize == BlockSize.
Y2 is the number of rows accessed by a warp in a single iteration: Y2 * X1 == WarpSize. (See the sketch below.)
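And the block-raked split with the same illustrative numbers; the iteration dimension moves to the outermost position, Y0:

```cpp
// Same illustrative numbers as the sketches above.
constexpr int BlockSize = 256, YPerTile = 64, X1 = 8, WarpSize = 64;

constexpr int Y1 = BlockSize / WarpSize; // number of warps: 4 (Y1 * WarpSize == BlockSize)
constexpr int Y2 = WarpSize / X1;        // rows per warp: 8   (Y2 * X1 == WarpSize)
constexpr int Y0 = YPerTile / (Y1 * Y2); // iterations: 2      (Y0 * Y1 * Y2 == YPerTile)

static_assert(Y0 * Y1 * Y2 == YPerTile, "<Y0, Y1, Y2> must cover the Y dimension");
```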
In all cases, the tuple <Y0, Y1, Y2> defines the Y-axis access pattern.
Selection
Thread-raked is used for element-wise operations because it yields a thread-major memory order. Warp-raked is used for matrix multiplication because vectorization happens at the warp level. Block-raked is used mostly for reductions, where each block's partial result is combined at the global level with atomics. A combined usage sketch follows below.
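Finally, a hedged sketch of how the entities declared on this page might be exercised together. It assumes tile_distribution_encoding_pattern_2d is default-constructible (plausible for a static encoding, but not confirmed by this page), with the template parameter order taken from the print() declaration above; the parameter values themselves are illustrative:

```cpp
#include <cstdio>

#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"

int main()
{
    using namespace ck_tile;

    // Parameter order per the print() declaration above:
    // BlockSize, YPerTile, XPerTile, VecSize, DistributionPattern, NumWaveGroups.
    using pattern_t = tile_distribution_encoding_pattern_2d<
        256, 64, 64, 8, tile_distribution_pattern::thread_raked, 1>;

    // Both entry points are declared on this page.
    print(pattern_t{}); // dump the derived encoding
    std::printf("%s\n",
                tile_distribution_pattern_to_string(tile_distribution_pattern::thread_raked));
    return 0;
}
```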