Composable Kernel: include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp Source File
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"

namespace ck_tile {

/// @brief The Stream-K GEMM kernel host arguments.
struct StreamKHostArgs : public UniversalGemmHostArgs<>
{
    CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
                                          const void* b_ptr_,
                                          void* c_ptr_,
                                          index_t M_,
                                          index_t N_,
                                          index_t K_,
                                          index_t stride_A_,
                                          index_t stride_B_,
                                          index_t stride_C_,
                                          StreamKReductionStrategy reduction_strategy_,
                                          uint32_t num_sk_blocks_ = 0xffffffff)
        : UniversalGemmHostArgs<>({a_ptr_},
                                  {b_ptr_},
                                  {/*ds_ptr*/},
                                  c_ptr_,
                                  /*k_batch_ =*/1,
                                  M_,
                                  N_,
                                  K_,
                                  {stride_A_},
                                  {stride_B_},
                                  {/*stride_Ds_*/},
                                  stride_C_),
          reduction_strategy{reduction_strategy_},
          num_sk_blocks{num_sk_blocks_}
    {
    }

    StreamKReductionStrategy reduction_strategy;
    uint32_t num_sk_blocks;
};
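
// Illustrative usage (not part of the original header): constructing host arguments for a
// Stream-K GEMM. a_dev/b_dev/c_dev, the problem sizes, and the strides are placeholders supplied
// by the caller; StreamKReductionStrategy::Atomic is assumed to be the enumerator selecting the
// atomic accumulation path, which is the only strategy this kernel currently reports as supported.
//
//   ck_tile::StreamKHostArgs host_args(a_dev, b_dev, c_dev,
//                                      M, N, K,
//                                      stride_A, stride_B, stride_C,
//                                      ck_tile::StreamKReductionStrategy::Atomic);
//
// Leaving num_sk_blocks_ at its default (0xffffffff) appears to act as an "auto" sentinel,
// leaving the stream-k block count to the tile partitioner (assumption based on the default value).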

template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct StreamKKernel
{
    /// @brief Inject the UniversalGemmKernel to support execution of all necessary functions.
    using UniversalGemmKernel =
        UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

    static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;

    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

    /// @brief Specify the layout configurations for A, B, and C.
    using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
    using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
    using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;

    /// @brief Specify the data type configurations for A, B, and C.
    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    /// @brief ALayout and ADataType are expected to be scalars, not a tuple.
    static_assert(!is_detected<is_tuple, ALayout>::value &&
                      !is_detected<is_tuple, ADataType>::value,
                  "ALayout and ADataType must be scalars.");

    /// @brief BLayout and BDataType are expected to be scalars, not a tuple.
    static_assert(!is_detected<is_tuple, BLayout>::value &&
                      !is_detected<is_tuple, BDataType>::value,
                  "BLayout and BDataType must be scalars.");

    /// @brief CLayout and CDataType are expected to be scalars, not a tuple.
    static_assert(!is_detected<is_tuple, CLayout>::value &&
                      !is_detected<is_tuple, CDataType>::value,
                  "CLayout and CDataType must be scalars.");

    struct StreamKKernelArgs : UniversalGemmKernel::KernelArgs
    {
        /// @brief The strategy used by work groups to compute final results in the C tensor.
        StreamKReductionStrategy reduction_strategy;
        /// @brief The number of stream k blocks.
        uint32_t num_sk_blocks;
        /// @brief A pointer to a buffer in device memory for accumulating partial results via the
        /// reduction strategy.
        void* workspace_ptr;
        /// @brief An instance of the TilePartitioner class for assisting with mapping workgroups
        /// to the C tensor.
        TilePartitioner tile_partitioner;
    };

    using Kernel     = StreamKKernel;
    using KernelArgs = StreamKKernelArgs;

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        using P_ = GemmPipeline;
        using WarpTile = typename P_::BlockGemmShape::WarpTile;

        return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
                      concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
                      concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})),
                      concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
                      concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
        // clang-format on
    }

    /// @brief Compute the grid size for the Stream-K kernel using the tile_partitioner.
    CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
    {
        return tile_partitioner.GridSize();
    }

    /// @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
    CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
    {
        return UniversalGemmKernel::MaxOccupancyGridSize(s);
    }

    CK_TILE_HOST static constexpr auto BlockSize() -> dim3
    {
        return dim3(kBlockSize);
    }

    /// @brief Constructs kernel arguments for the Stream-K kernel.
    CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
                                                         int num_cu    = NumCU(),
                                                         int occupancy = Occupancy())
    {
        return StreamKKernelArgs{{host_args.as_ptr,
                                  host_args.bs_ptr,
                                  host_args.ds_ptr,
                                  host_args.e_ptr,
                                  host_args.M,
                                  host_args.N,
                                  host_args.K,
                                  host_args.stride_As,
                                  host_args.stride_Bs,
                                  host_args.stride_Ds,
                                  host_args.stride_E,
                                  host_args.k_batch},
                                 host_args.reduction_strategy,
                                 host_args.num_sk_blocks,
                                 // The workspace pointer is set to nullptr because we must first
                                 // instantiate the TilePartitioner to get the necessary size.
                                 /*workspace_ptr =*/nullptr,
                                 TilePartitioner{static_cast<uint32_t>(host_args.M),
                                                 static_cast<uint32_t>(host_args.N),
                                                 static_cast<uint32_t>(host_args.K),
                                                 static_cast<uint32_t>(num_cu),
                                                 static_cast<uint32_t>(occupancy),
                                                 host_args.num_sk_blocks}};
    }
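
    // Illustrative host-side flow (not part of the original header): how MakeKernelArgs combines
    // with the workspace and launch helpers defined in this class. Kernel stands for a concrete
    // StreamKKernel instantiation; the actual launch goes through ck_tile's kernel-launch
    // utilities and is only sketched here.
    //
    //   auto kargs = Kernel::MakeKernelArgs(host_args);   // NumCU()/Occupancy() defaults
    //   if(!Kernel::IsSupportedArgument(kargs)) { /* reject unsupported configuration */ }
    //
    //   if(const uint32_t bytes = Kernel::GetWorkSpaceSize(kargs); bytes != 0)
    //   {
    //       // Only the reduction strategy needs a workspace; the atomic strategy returns 0.
    //       void* ws = nullptr;
    //       hip_check_error(hipMalloc(&ws, bytes));
    //       Kernel::SetWorkSpacePointer(kargs, ws);
    //   }
    //
    //   const dim3 grids  = Kernel::GridSize(kargs.tile_partitioner);
    //   const dim3 blocks = Kernel::BlockSize();
    //   // ... launch Kernel{} with (grids, blocks, kargs) via ck_tile's launch utilities ...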

    template <bool UseDefaultScheduler = true>
    CK_TILE_DEVICE static void
    RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
            const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
            const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
            CDataType* c_ptr,
            void* smem_ptr_0,
            const typename UniversalGemmKernel::KernelArgs& kargs,
            const index_t num_loop,
            const index_t block_idx_m,
            const index_t block_idx_n,
            const index_t k_size)
    {
        // Create GEMM tensor views, pad views and tile windows.
        const auto& gemm_tensor_views_tuple =
            UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);

        const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows =
            UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        // Run the GEMM cooperatively by the whole workgroup.
        const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
        const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
        const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);

        // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
        // has_hot_loop and tail_num here. This is a similar pattern to the one used by grouped
        // GEMM. In this case, we call the GemmPipeline's operator() overload that takes both
        // has_hot_loop and tail_num.
        const bool has_hot_loop   = GemmPipeline::BlockHasHotloop(num_loop);
        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);

        const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
                                                  bs_block_window[UniversalGemmKernel::I0],
                                                  num_loop,
                                                  has_hot_loop,
                                                  tail_num,
                                                  smem_ptr_0);

        if(UseDefaultScheduler || (get_warp_id() == 0))
        {
            // Run the epilogue pipeline.
            auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);

            EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
        }
    }

    CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
    {
        if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
            }
            return false;
        }
        return UniversalGemmKernel::IsSupportedArgument(kargs);
    }

    /// @brief Computes the buffer size needed to store accumulation results for Stream-K.
    CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
    {
        // For the reduction strategy, we need to determine the amount of device space for
        // accumulation results and semaphores.
        if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction)
        {
            return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType));
        }

        // Otherwise, no additional space is needed since blocks atomically store their results.
        return 0;
    }

    /// @brief Sets the kargs' workspace_ptr to the given workspace_ptr.
    CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
    {
        kargs.workspace_ptr = workspace_ptr;
    }

    /// @brief Entry point for the Stream-K kernel, performing the main Stream-K loop.
    CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const
    {
        // Allocate LDS.
        __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];

        uint32_t block_idx = ck_tile::get_block_1d_id();

        bool is_padding_block =
            amd_wave_read_first_lane(block_idx >= kargs.tile_partitioner.sk_num_blocks &&
                                     block_idx < kargs.tile_partitioner.dp_start_block_idx);

        // Padding blocks only exist so that the DP blocks are aligned with the number of CUs;
        // they should not take part in the GEMM.
        if(is_padding_block)
            return;

        // Determine the K offsets of the first and final macro tiles in the A and B tensors along
        // the K dimension.
        uint32_t iter_start, iter_end;
        kargs.tile_partitioner.GetBlockItr(block_idx, iter_start, iter_end);

        // Main Stream-K loop.
        while(true)
        {
            // Determine the number of macro tiles in A and B this WG is responsible for in the
            // current C macro tile.
            uint32_t current_iter_length = amd_wave_read_first_lane(
                kargs.tile_partitioner.GetCurrentIterLength(iter_start, iter_end));

            // Determine the 1D tile_idx and the iter_offset for this WG.
            // The tile_idx is the 1D macro tile index in the C tensor.
            // The iter_offset is the starting macro tile index in the K dimension for the WG in
            // the current iteration of the while loop.
            uint32_t tile_idx, iter_offset;
            kargs.tile_partitioner.GetTileIdxWithOffset(iter_start, tile_idx, iter_offset);

            // Get the 2D tile index in the C tensor for this WG using the 1D index (i.e. tile_idx).
            auto spatial_idx = kargs.tile_partitioner.GetOutputTileIndex(tile_idx);

            // Get the offsets into the A, B, and C tensors.
            index_t i_m = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I0] *
                                               TilePartitioner::MPerBlock);
            index_t i_n = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I1] *
                                               TilePartitioner::NPerBlock);
            index_t i_k = static_cast<index_t>(iter_offset) * TilePartitioner::KPerBlock;

            // Determine the total size along the K dimension the WG is using in this iteration
            // (used to construct tensor views).
            index_t k_size = static_cast<index_t>(current_iter_length * TilePartitioner::KPerBlock);

            // Update pointer offsets for A, B, and C.
            const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k;
            const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k;
            CDataType* c_ptr       = static_cast<CDataType*>(kargs.e_ptr);

            // Run the GEMM pipeline and epilogue.
            RunGemm({a_ptr},
                    {b_ptr},
                    {/*ds_ptr*/},
                    c_ptr,
                    smem_ptr_0,
                    kargs,
                    current_iter_length,
                    i_m,
                    i_n,
                    k_size);

            // Prepare for the next Stream-K loop iteration.
            iter_start += current_iter_length;
            if(iter_end <= iter_start)
                break;
            block_sync_lds();
        }
    }
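
    // Worked example of the loop above (illustrative assumption, not from the original source):
    // with the canonical Stream-K mapping the flattened iteration space is tile-major, i.e.
    // tile_idx = iter / iters_per_tile and iter_offset = iter % iters_per_tile, where
    // iters_per_tile = ceil(K / KPerBlock). Suppose iters_per_tile == 8 and a WG is assigned
    // [iter_start, iter_end) == [6, 14):
    //   pass 1: tile_idx = 0, iter_offset = 6, current_iter_length = min(8 - 6, 14 - 6) = 2,
    //           so the WG computes K-iterations 6..7 of C tile 0;
    //   pass 2: iter_start = 8, tile_idx = 1, iter_offset = 0,
    //           current_iter_length = min(8 - 0, 14 - 8) = 6, i.e. K-iterations 0..5 of C tile 1;
    //   iter_start then reaches 14 == iter_end and the loop exits. Partial tile results are
    //   combined according to the selected reduction strategy (atomics in the supported case).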

    private:
    CK_TILE_HOST static int NumCU()
    {
        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
        hip_check_error(hipGetDevice(&dev));
        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
        int num_cu = dev_prop.multiProcessorCount;

        return num_cu;
    }

    CK_TILE_HOST static int Occupancy()
    {
        int occupancy;

        // Since an occupancy of 1 is valid for Stream-K, we set min_block_per_cu to 1.
        constexpr int min_block_per_cu = 1;
        const auto kernel              = kentry<min_block_per_cu, Kernel, KernelArgs>;

        hip_check_error(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));

        return occupancy;
    }
};

} // namespace ck_tile