Composable Kernel: include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp Source File

streamk_gemm_kernel.hpp
Go to the documentation of this file.
1 // Copyright © Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
7 #include "ck_tile/ops/common.hpp"
9 
10 namespace ck_tile {
11 namespace reboot {
12 
21 {
22  CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
23  const void* b_ptr_,
24  void* c_ptr_,
25  index_t M_,
26  index_t N_,
27  index_t K_,
28  index_t stride_A_,
29  index_t stride_B_,
30  index_t stride_C_,
31  StreamKReductionStrategy reduction_strategy_)
32  : UniversalGemmHostArgs<>({a_ptr_},
33  {b_ptr_},
34  {/*ds_ptr*/},
35  c_ptr_,
36  /*k_batch_ =*/1,
37  M_,
38  N_,
39  K_,
40  {stride_A_},
41  {stride_B_},
42  {/*stride_Ds_*/},
43  stride_C_),
44  reduction_strategy{reduction_strategy_}
45  {
46  }
47 
49 };
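// Illustrative sketch (values and pointers are placeholders, not prescribed by this header):
// constructing the host arguments above for a row-major problem.
//
//   ck_tile::reboot::StreamKHostArgs host_args{a_dev_ptr,          // A: M x K
//                                              b_dev_ptr,          // B: K x N
//                                              c_dev_ptr,          // C: M x N
//                                              /*M_=*/3840, /*N_=*/4096, /*K_=*/2048,
//                                              /*stride_A_=*/2048, // row-major A: ld = K
//                                              /*stride_B_=*/4096, // row-major B: ld = N
//                                              /*stride_C_=*/4096, // row-major C: ld = N
//                                              ck_tile::StreamKReductionStrategy::Atomic};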
50 
55 // The main kernel functions are the operator() functions. There is one for the Persistent
56 // and one for the Non-Persistent data-parallel section of the Stream-K algorithm.
57 //
58 // Both the Non-Persistent and Persistent kernels make use of `BaseGemm()` and
59 // `StreamKGemm()`. `BaseGemm()` computes offsets into the A, B, and C tensors, then calls
60 // `RunGemm()`, which runs the GEMM pipeline and epilogue. `StreamKGemm()` performs the
61 // main Stream-K algorithm. Each iteration of the Stream-K loop calls `BaseGemm()`.
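// Call-flow sketch of the entry points described above:
//
//   operator() (Non-Persistent): block_idx <  dp_ctas -> BaseGemm()   (full-K data-parallel tile)
//                                block_idx >= dp_ctas -> StreamKGemm()
//   operator() (Persistent):     loop over owned DP tiles -> BaseGemm(), then StreamKGemm()
//
//   StreamKGemm(): while(iter_start < iter_end) { ... BaseGemm() ... }
//   BaseGemm():    compute (i_m, i_n) and K offsets -> RunGemm() (GEMM pipeline + epilogue)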
62 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
64 {
69 
72 
73  using TilePartitioner = TilePartitioner_;
74  using GemmPipeline = GemmPipeline_;
75  using EpiloguePipeline = EpiloguePipeline_;
76 
77  static_assert(
78  TilePartitioner::PERSISTENT == PersistentDP,
79  "Persistent flag from TilePartitioner must match Persistent flag from UniversalGemm.");
80 
82  using ALayout = typename GemmPipeline::ALayout;
83  using BLayout = typename GemmPipeline::BLayout;
84  using CLayout = typename GemmPipeline::CLayout;
85 
87  using ADataType = typename GemmPipeline::ADataType;
88  using BDataType = typename GemmPipeline::BDataType;
89  using CDataType = typename EpiloguePipeline::ODataType;
90 
91  template <typename T>
93 
95  static_assert(!is_tuple_v<ALayout> && !is_tuple_v<ADataType>,
96  "ALayout and ADataType must be scalars.");
97 
99  static_assert(!is_tuple_v<BLayout> && !is_tuple_v<BDataType>,
100  "BLayout and BDataType must be scalars.");
101 
103  static_assert(!is_tuple_v<CLayout> && !is_tuple_v<CDataType>,
104  "CLayout and CDataType must be scalars.");
105 
107  {
108  StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid)
109  : UniversalGemmKernelArgs{host_args.as_ptr,
110  host_args.bs_ptr,
111  host_args.ds_ptr,
112  host_args.e_ptr,
113  host_args.M,
114  host_args.N,
115  host_args.K,
116  host_args.stride_As,
117  host_args.stride_Bs,
118  host_args.stride_Ds,
119  host_args.stride_E,
120  host_args.k_batch},
122  // The workspace pointer is set to nullptr because we must first
123  // instantiate the TilePartitioner to get the necessary size
124  workspace_ptr{nullptr},
125  tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
126 
127  {
128  }
129 
138  };
139 
142 
143  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
144  {
145  // clang-format off
146  using P_ = GemmPipeline;
147  using WarpTile = typename P_::BlockGemmShape::WarpTile;
148 
149  return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
150  concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
151  concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})),
152  concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
153  concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
154  // clang-format on
155  }
156 
159  CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
160  {
161  return tile_partitioner.grid_size();
162  }
163 
168  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
169  {
171  }
172 
173  CK_TILE_HOST static constexpr auto BlockSize() -> dim3
174  {
176  }
177 
186  int num_cu = NumCU(),
187  int occupancy = Occupancy())
188  {
189  const index_t grid = num_cu * occupancy;
190 
191  return StreamKKernelArgs{host_args, grid};
192  }
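// For example, on a hypothetical device with 304 CUs and an occupancy of one block per CU,
// the partitioner is constructed with grid = 304, i.e. one workgroup per CU.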
193 
194  template <bool UseDefaultScheduler = true>
195  CK_TILE_DEVICE static void
196  RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
197  const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
198  const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
199  CDataType* c_ptr,
200  void* smem_ptr_0,
201  const typename UniversalGemmKernel::KernelArgs& kargs,
202  const index_t num_loop,
203  const index_t block_idx_m,
204  const index_t block_idx_n,
205  const index_t k_size)
206  {
207  // Create Gemm tensor views, pad views and tile windows
208  const auto& gemm_tensor_views_tuple =
209  UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
210  as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
211 
212  const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
213  auto gemm_tile_windows =
214  UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
215 
216  // Run GEMM cooperatively by the whole workgroup.
217  const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
218  const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
219  const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
220 
221  // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
222  // has_hot_loop and tail_num here. This is a pattern similar to the one used by grouped GEMM.
223  // In this case, we call the GemmPipeline's operator() overload that takes both has_hot_loop
224  // and tail_num.
225  const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
226  const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
227 
228  const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
229  bs_block_window[UniversalGemmKernel::I0],
230  num_loop,
231  has_hot_loop,
232  tail_num,
233  smem_ptr_0);
234 
235  if(UseDefaultScheduler || (get_warp_id() == 0))
236  {
237  // Run Epilogue Pipeline
238  auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
239 
240  EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
241  }
242  }
243 
245  {
247  {
248  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
249  {
250  CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
251  }
252  return false;
253  }
255  }
256 
260  {
261  return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType));
262  }
263 
266  CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
267  {
268  kargs.workspace_ptr = workspace_ptr;
269  }
270 
282  index_t tile_idx,
283  index_t num_loop,
284  index_t i_k_a,
285  index_t i_k_b,
286  index_t k_size,
287  void* smem_ptr_0) const
288  {
289  const auto c_macro_tile_idx = kargs.tile_partitioner.get_output_tile_index(tile_idx);
290  index_t i_m = c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
291  index_t i_n = c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
292 
293  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
294  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
295  CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
296 
297  // Run the GEMM pipeline and Epilogue.
298  RunGemm(
299  {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
300  }
301 
309  CK_TILE_DEVICE void
310  StreamKGemm(StreamKKernelArgs& kargs, index_t cta_idx, void* smem_ptr_0) const
311  {
312  index_t iter_start, iter_end;
313  kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, cta_idx);
314 
315  while(iter_start < iter_end)
316  {
317  // Get the 1D tile index in the C tensor that this workgroup will work on for this
318  // iteration of the loop.
319  index_t tile_idx =
320  amd_wave_read_first_lane(kargs.tile_partitioner.get_tile_index(iter_start));
321 
322  // Get the start and end boundaries for the current tile.
323  index_t tile_iter_start, tile_iter_end;
324  kargs.tile_partitioner.get_tile_boundaries(tile_iter_start, tile_iter_end, tile_idx);
325 
326  // Get the start and end iteration within the current tile for the workgroup.
327  index_t local_iter_start = amd_wave_read_first_lane(
328  kargs.tile_partitioner.get_local_iter(iter_start, tile_iter_start));
329  index_t local_iter_end =
330  amd_wave_read_first_lane(kargs.tile_partitioner.get_local_iter_end(
331  tile_iter_start, iter_end, tile_iter_end));
332 
333  // Get the iteration length.
334  index_t num_loop_sk = local_iter_end - local_iter_start;
335 
336  // Determine the total size along the K dimension the workgroup is using in this
337  // iteration (used to construct tensor views).
338  index_t k_size = num_loop_sk * TilePartitioner::KPerBlock;
339 
340  // Get the K offsets for the A and B tensors
341  auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(
342  local_iter_start, kargs.stride_As[0], kargs.stride_Bs[0]);
343 
344  if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Atomic)
345  {
346  BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
347  }
348  else
349  {
350  // TODO: Apply reduction logic.
351  }
352 
353  // Prepare for next Stream-K loop iteration.
354  iter_start = tile_iter_end;
355  block_sync_lds();
356  }
357  }
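// Worked example with hypothetical numbers, assuming the partitioner splits every output tile
// into iters_per_tile K-slices: with iters_per_tile == 8, a workgroup assigned the global
// iteration range [6, 14) first lands in tile 0 (iters [0, 8)), so it computes local iters
// [6, 8) of that tile with num_loop_sk == 2; iter_start then advances to the tile boundary (8)
// and the workgroup computes local iters [0, 6) of tile 1 before reaching iter_end. Under the
// Atomic strategy each partial result is accumulated directly into the C tile.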
358 
366  template <bool U = PersistentDP>
367  CK_TILE_DEVICE typename std::enable_if_t<!U> operator()(StreamKKernelArgs kargs) const
368  {
369  // Allocate LDS
370  __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
371 
372  index_t block_idx = ck_tile::get_block_1d_id();
373  index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
374  index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas();
375  bool is_dp_ctas = block_idx < dp_ctas;
376 
377  // Check if this block is in the data-parallel section
378  if(is_dp_ctas)
379  {
380  BaseGemm(kargs, block_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
381  }
382  else
383  {
384  // Stream-K
385  StreamKGemm(kargs, block_idx - dp_ctas, smem_ptr_0);
386  }
387  }
388 
397  template <bool U = PersistentDP>
398  CK_TILE_DEVICE typename std::enable_if_t<U> operator()(StreamKKernelArgs kargs) const
399  {
400  // Allocate LDS
401  __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
402 
403  index_t block_idx = ck_tile::get_block_1d_id();
404  index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
405 
406  // Data-parallel section
407  for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
408  tile_idx += kargs.tile_partitioner.get_grid())
409  {
410  BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
411  }
412 
413  // Stream-K section
414  StreamKGemm(kargs, block_idx, smem_ptr_0);
415  }
416 
417  private:
424  template <typename ALayout, typename BLayout>
426  GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
427  {
428  index_t stride_offset_a;
429  index_t stride_offset_b;
430  if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
431  {
432  stride_offset_a = stride_a;
433  }
434  else
435  {
436  stride_offset_a = 1;
437  }
438 
439  if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
440  {
441  stride_offset_b = stride_b;
442  }
443  else
444  {
445  stride_offset_b = 1;
446  }
447 
448  index_t base_offset = iter_offset * TilePartitioner::KPerBlock;
449 
450  return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b);
451  }
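// Example of the offset arithmetic above with hypothetical values KPerBlock == 32 and
// iter_offset == 3: row-major A is contiguous along K, so the A offset is 3 * 32 elements;
// column-major A advances by its leading dimension per K step, giving 3 * 32 * stride_a.
// B mirrors this: row-major B advances by stride_b per K step, while column-major B is
// contiguous along K.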
452 
453  CK_TILE_HOST static int NumCU()
454  {
455  hipDeviceProp_t dev_prop;
456  hipDevice_t dev;
457  hip_check_error(hipGetDevice(&dev));
458  hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
459  int num_cu = dev_prop.multiProcessorCount;
460 
461  return num_cu;
462  }
463 
468  CK_TILE_HOST static int Occupancy()
469  {
470  int occupancy;
471 
472  // Since an occupancy of 1 is valid for Stream-K, we set min_block_per_cu to 1
473  constexpr int min_block_per_cu = 1;
474  const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
475 
476  hip_check_error(
477  hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
478 
479  return occupancy;
480  }
481 };
482 } // namespace reboot
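// Host-side sketch for the Stream-K kernel above (illustrative only; the kernel struct name
// StreamKKernel, the pipeline/partitioner typedefs, and the launch helpers are assumptions that
// follow the usual CK-Tile example scaffolding rather than requirements of this header):
//
//   using Kernel = ck_tile::reboot::StreamKKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
//   auto kargs = Kernel::MakeKernelArgs(host_args); // grid = NumCU() * Occupancy()
//   if(!Kernel::IsSupportedArgument(kargs))
//   {
//       // Fall back to another GEMM path; only the Atomic reduction strategy is accepted here.
//   }
//   const dim3 grids  = Kernel::GridSize(kargs.tile_partitioner);
//   const dim3 blocks = Kernel::BlockSize();
//   // Launch through ck_tile::launch_kernel / make_kernel as in the other CK-Tile GEMM examples.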
483 
492 {
493  CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
494  const void* b_ptr_,
495  void* c_ptr_,
496  index_t M_,
497  index_t N_,
498  index_t K_,
499  index_t stride_A_,
500  index_t stride_B_,
501  index_t stride_C_,
502  StreamKReductionStrategy reduction_strategy_,
503  uint32_t num_sk_blocks_ = 0xffffffff)
504  : UniversalGemmHostArgs<>({a_ptr_},
505  {b_ptr_},
506  {/*ds_ptr*/},
507  c_ptr_,
508  /*k_batch_ =*/1,
509  M_,
510  N_,
511  K_,
512  {stride_A_},
513  {stride_B_},
514  {/*stride_Ds_*/},
515  stride_C_),
516  reduction_strategy{reduction_strategy_},
517  num_sk_blocks{num_sk_blocks_}
518  {
519  }
520 
523 };
524 
525 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
527 {
532 
534 
538 
543 
548 
550  static_assert(!is_detected<is_tuple, ALayout>::value &&
551  !is_detected<is_tuple, ADataType>::value,
552  "ALayout and ADataType must be scalars.");
553 
555  static_assert(!is_detected<is_tuple, BLayout>::value &&
556  !is_detected<is_tuple, BDataType>::value,
557  "BLayout and BDataType must be scalars.");
558 
560  static_assert(!is_detected<is_tuple, CLayout>::value &&
561  !is_detected<is_tuple, CDataType>::value,
562  "CLayout and CDataType must be scalars.");
563 
565  {
576  };
577 
580 
581  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
582  {
583  // clang-format off
584  using P_ = GemmPipeline;
585  using WarpTile = typename P_::BlockGemmShape::WarpTile;
586 
587  return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
588  concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
589  concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})),
590  concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
591  concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
592  // clang-format on
593  }
594 
597  CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
598  {
599  return tile_partitioner.GridSize();
600  }
601 
606  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
607  {
609  }
610 
611  CK_TILE_HOST static constexpr auto BlockSize() -> dim3
612  {
614  }
615 
624  int num_cu = NumCU(),
625  int occupancy = Occupancy())
626  {
627  return StreamKKernelArgs{{host_args.as_ptr,
628  host_args.bs_ptr,
629  host_args.ds_ptr,
630  host_args.e_ptr,
631  host_args.M,
632  host_args.N,
633  host_args.K,
634  host_args.stride_As,
635  host_args.stride_Bs,
636  host_args.stride_Ds,
637  host_args.stride_E,
638  host_args.k_batch},
639  host_args.reduction_strategy,
640  host_args.num_sk_blocks,
641  // The workspace pointer is set to nullptr because we must first
642  // instantiate the TilePartitioner to get the necessary size
643  /*workspace_ptr =*/nullptr,
644  TilePartitioner{static_cast<uint32_t>(host_args.M),
645  static_cast<uint32_t>(host_args.N),
646  static_cast<uint32_t>(host_args.K),
647  static_cast<uint32_t>(num_cu),
648  static_cast<uint32_t>(occupancy),
649  host_args.num_sk_blocks}};
650  }
651 
652  template <bool UseDefaultScheduler = true>
653  CK_TILE_DEVICE static void
654  RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
655  const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
656  const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
657  CDataType* c_ptr,
658  void* smem_ptr_0,
659  const typename UniversalGemmKernel::KernelArgs& kargs,
660  const index_t num_loop,
661  const index_t block_idx_m,
662  const index_t block_idx_n,
663  const index_t k_size)
664  {
665  // Create Gemm tensor views, pad views and tile windows
666  const auto& gemm_tensor_views_tuple =
667  UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
668  as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
669 
670  const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
671  auto gemm_tile_windows =
672  UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
673 
674  // Run GEMM cooperatively by the whole workgroup.
675  const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
676  const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
677  const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
678 
679  // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
680  // has_hot_loop and tail_num here. This is a pattern similar to the one used by grouped GEMM.
681  // In this case, we call the GemmPipeline's operator() overload that takes both has_hot_loop
682  // and tail_num.
683  const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
684  const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
685 
686  const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
687  bs_block_window[UniversalGemmKernel::I0],
688  num_loop,
689  has_hot_loop,
690  tail_num,
691  smem_ptr_0);
692 
693  if(UseDefaultScheduler || (get_warp_id() == 0))
694  {
695  // Run Epilogue Pipeline
696  auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
697 
698  EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
699  }
700  }
701 
703  {
705  {
706  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
707  {
708  CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
709  }
710  return false;
711  }
713  }
714 
718  {
719  // For reduction, we need to determine the amount of device space for accumulation
720  // results and semaphores.
722  {
723  return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType));
724  }
725 
726  // Otherwise, no additional space is needed since blocks atomically store their results.
727  return 0;
728  }
729 
732  CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
733  {
734  kargs.workspace_ptr = workspace_ptr;
735  }
736 
739  {
740  // Allocate LDS
741  __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
742 
743  uint32_t block_idx = ck_tile::get_block_1d_id();
744 
745  bool is_padding_block =
746  amd_wave_read_first_lane(block_idx >= kargs.tile_partitioner.sk_num_blocks &&
747  block_idx < kargs.tile_partitioner.dp_start_block_idx);
748 
749  // Padding blocks ensure that the DP blocks are aligned with the number of CUs; they
750  // do not take part in the GEMM.
751  if(is_padding_block)
752  return;
753 
754  // Determine this workgroup's first and final macro-tile iterations over the A and B
755  // tensors along the K dimension.
756  uint32_t iter_start, iter_end;
757  kargs.tile_partitioner.GetBlockItr(block_idx, iter_start, iter_end);
758 
759  // Main Stream-K loop
760  while(true)
761  {
762  // Determine the number of macro tiles in A and B this WG is responsible for in the
763  // current C macro tile.
764  uint32_t current_iter_length = amd_wave_read_first_lane(
765  kargs.tile_partitioner.GetCurrentIterLength(iter_start, iter_end));
766 
767  // Determine the 1D tile_idx and the iter_offset for this WG.
768  // The tile_idx is the 1D macro tile index in the C tensor.
769  // The iter_offset is the starting macro tile index in the K dimension for the WG in the
770  // current iteration of the while loop.
771  uint32_t tile_idx, iter_offset;
772  kargs.tile_partitioner.GetTileIdxWithOffset(iter_start, tile_idx, iter_offset);
773 
774  // Get the 2D tile index in the C tensor for this WG using the 1D index (i.e. tile_idx)
775  auto spatial_idx = kargs.tile_partitioner.GetOutputTileIndex(tile_idx);
776 
777  // Get the offsets in A, B, C tensors.
778  index_t i_m = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I0] *
779  TilePartitioner::MPerBlock);
780  index_t i_n = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I1] *
781  TilePartitioner::NPerBlock);
782  auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(
783  static_cast<index_t>(iter_offset), kargs.stride_As[0], kargs.stride_Bs[0]);
784 
785  // Determine the total size along the K dimension the WG is using in this iteration
786  // (used to construct tensor views).
787  index_t k_size = static_cast<index_t>(current_iter_length * TilePartitioner::KPerBlock);
788 
789  // Update pointer offsets for A, B, and C.
790  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
791  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
792  CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
793 
794  // Run the GEMM pipeline and Epilogue.
795  RunGemm({a_ptr},
796  {b_ptr},
797  {/*ds_ptr*/},
798  c_ptr,
799  smem_ptr_0,
800  kargs,
801  current_iter_length,
802  i_m,
803  i_n,
804  k_size);
805 
806  // Prepare for next Stream-K loop iteration.
807  iter_start += current_iter_length;
808  if(iter_end <= iter_start)
809  break;
810  block_sync_lds();
811  }
812  }
813 
814  private:
821  template <typename ALayout, typename BLayout>
823  GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
824  {
825  index_t stride_offset_a;
826  index_t stride_offset_b;
827  if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
828  {
829  stride_offset_a = stride_a;
830  }
831  else
832  {
833  stride_offset_a = 1;
834  }
835 
836  if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
837  {
838  stride_offset_b = stride_b;
839  }
840  else
841  {
842  stride_offset_b = 1;
843  }
844 
845  index_t base_offset = iter_offset * TilePartitioner::KPerBlock;
846 
847  return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b);
848  }
849 
850  CK_TILE_HOST static int NumCU()
851  {
852  hipDeviceProp_t dev_prop;
853  hipDevice_t dev;
854  hip_check_error(hipGetDevice(&dev));
855  hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
856  int num_cu = dev_prop.multiProcessorCount;
857 
858  return num_cu;
859  }
860 
865  CK_TILE_HOST static int Occupancy()
866  {
867  int occupancy;
868 
869  // Since an occupancy of 1 is valid for Stream-K, we set min_block_per_cu to 1
870  constexpr int min_block_per_cu = 1;
871  const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
872 
873  hip_check_error(
874  hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
875 
876  return occupancy;
877  }
878 };
879 
880 } // namespace ck_tile
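The second kernel above (in the ck_tile namespace) sizes its reduction workspace through GetWorkSpaceSize(), which is non-zero only for the Reduction strategy. A minimal host-side sketch of attaching that workspace, assuming the kernel struct is aliased as Kernel and host_args was built from the StreamKHostArgs defined above (names are illustrative, not prescribed by this header):

    auto kargs = Kernel::MakeKernelArgs(host_args);
    void* workspace = nullptr;
    if(const uint32_t ws_bytes = Kernel::GetWorkSpaceSize(kargs); ws_bytes != 0)
    {
        // Only the Reduction strategy requests device memory for partial results and
        // semaphores; Atomic returns 0 and accumulates directly into C. Note that
        // IsSupportedArgument() currently accepts only the Atomic strategy.
        hip_check_error(hipMalloc(&workspace, ws_bytes));
        Kernel::SetWorkSpacePointer(kargs, workspace);
    }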
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:245
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
TailNumber
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:21
StreamKReductionStrategy
Definition: streamk_common.hpp:10
CK_TILE_HOST void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:13
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
@ Atomic
Definition: block_to_ctile_map.hpp:1012
@ Reduction
Definition: block_to_ctile_map.hpp:1013
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
unsigned int uint32_t
Definition: stdint.h:126
The Stream K GEMM kernel host arguments.
Definition: streamk_gemm_kernel.hpp:492
uint32_t num_sk_blocks
Definition: streamk_gemm_kernel.hpp:522
ck_tile::StreamKReductionStrategy reduction_strategy
Definition: streamk_gemm_kernel.hpp:521
CK_TILE_HOST StreamKHostArgs(const void *a_ptr_, const void *b_ptr_, void *c_ptr_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_, StreamKReductionStrategy reduction_strategy_, uint32_t num_sk_blocks_=0xffffffff)
Definition: streamk_gemm_kernel.hpp:493
ALayout and ADataType are expected to be scalars, not a tuple.
Definition: streamk_gemm_kernel.hpp:565
StreamKReductionStrategy reduction_strategy
The strategy used by work groups to compute final results in the C tensor.
Definition: streamk_gemm_kernel.hpp:567
uint32_t num_sk_blocks
The number of stream k blocks.
Definition: streamk_gemm_kernel.hpp:569
void * workspace_ptr
A pointer to a buffer in device memory for accumulating partial results via the reduction strategy.
Definition: streamk_gemm_kernel.hpp:572
TilePartitioner tile_partitioner
An instance of the TilePartitioner class for assisting with mapping workgroups to the C tensor.
Definition: streamk_gemm_kernel.hpp:575
Definition: streamk_gemm_kernel.hpp:527
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Inject the UniversalGemmKernel base class to support execution of all necessary functions.
Definition: streamk_gemm_kernel.hpp:531
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Specify the layout configurations for A, B, and C.
Definition: streamk_gemm_kernel.hpp:540
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Specify the data type configurations for A, B, and C.
Definition: streamk_gemm_kernel.hpp:545
static CK_TILE_HOST auto GridSize(const TilePartitioner &tile_partitioner) -> dim3
Compute the grid size for the Stream K kernel using the tile_partitioner.
Definition: streamk_gemm_kernel.hpp:597
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition: streamk_gemm_kernel.hpp:541
static CK_TILE_HOST StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs &host_args, int num_cu=NumCU(), int occupancy=Occupancy())
Constructs kernel arguments for the Stream-K kernel.
Definition: streamk_gemm_kernel.hpp:623
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition: streamk_gemm_kernel.hpp:547
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: streamk_gemm_kernel.hpp:535
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: streamk_gemm_kernel.hpp:537
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, UniversalGemmKernel::NumATensor > &as_ptr, const std::array< const BDataType *, UniversalGemmKernel::NumBTensor > &bs_ptr, const std::array< const void *, UniversalGemmKernel::NumDTensor > &ds_ptr, CDataType *c_ptr, void *smem_ptr_0, const typename UniversalGemmKernel::KernelArgs &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t k_size)
Definition: streamk_gemm_kernel.hpp:654
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: streamk_gemm_kernel.hpp:606
static CK_TILE_HOST void SetWorkSpacePointer(StreamKKernelArgs &kargs, void *workspace_ptr)
Sets the kargs' current workspace_ptr to the given workspace_ptr.
Definition: streamk_gemm_kernel.hpp:732
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition: streamk_gemm_kernel.hpp:546
static constexpr index_t kBlockSize
Definition: streamk_gemm_kernel.hpp:533
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: streamk_gemm_kernel.hpp:536
static CK_TILE_HOST const std::string GetName()
Definition: streamk_gemm_kernel.hpp:581
static constexpr CK_TILE_HOST auto BlockSize() -> dim3
Definition: streamk_gemm_kernel.hpp:611
static CK_TILE_HOST bool IsSupportedArgument(const StreamKKernelArgs &kargs)
Definition: streamk_gemm_kernel.hpp:702
CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const
Entry point for the Stream-K Kernel, performing the main Stream-K loop.
Definition: streamk_gemm_kernel.hpp:738
static CK_TILE_HOST uint32_t GetWorkSpaceSize(const StreamKKernelArgs &kargs)
Computes the buffer size needed to store accumulation results for Stream K.
Definition: streamk_gemm_kernel.hpp:717
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition: streamk_gemm_kernel.hpp:542
The Universal GEMM kernel host arguments.
Definition: universal_gemm_kernel.hpp:32
const std::array< index_t, NumDTensor > stride_Ds
Definition: universal_gemm_kernel.hpp:73
const std::array< index_t, NumBTensor > stride_Bs
Definition: universal_gemm_kernel.hpp:72
index_t K
Definition: universal_gemm_kernel.hpp:70
void * e_ptr
Definition: universal_gemm_kernel.hpp:65
index_t M
Definition: universal_gemm_kernel.hpp:68
const std::array< const void *, NumDTensor > ds_ptr
Definition: universal_gemm_kernel.hpp:62
const std::array< const void *, NumATensor > as_ptr
Definition: universal_gemm_kernel.hpp:60
const std::array< index_t, NumATensor > stride_As
Definition: universal_gemm_kernel.hpp:71
index_t N
Definition: universal_gemm_kernel.hpp:69
index_t stride_E
Definition: universal_gemm_kernel.hpp:76
const std::array< const void *, NumBTensor > bs_ptr
Definition: universal_gemm_kernel.hpp:61
index_t k_batch
Definition: universal_gemm_kernel.hpp:80
The GEMM kernel device arguments.
Definition: universal_gemm_kernel.hpp:86
void * e_ptr
The E output tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:94
std::array< index_t, NumBTensor > stride_Bs
The distance between consecutive elements of non-contiguous dimension (in memory) of Bs tensor.
Definition: universal_gemm_kernel.hpp:106
const std::array< const void *, NumDTensor > ds_ptr
The Ds input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:92
std::array< index_t, NumATensor > stride_As
The distance between consecutive elements of non-contiguous dimension (in memory) of As tensor.
Definition: universal_gemm_kernel.hpp:103
const std::array< const void *, NumATensor > as_ptr
The As input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:88
index_t k_batch
Definition: universal_gemm_kernel.hpp:113
index_t N
GEMM's N dimension size.
Definition: universal_gemm_kernel.hpp:98
index_t stride_E
The distance between consecutive elements of non-contiguous dimension (in memory) of E tensor.
Definition: universal_gemm_kernel.hpp:112
index_t K
GEMM's K dimension size.
Definition: universal_gemm_kernel.hpp:100
const std::array< const void *, NumBTensor > bs_ptr
The Bs input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:90
std::array< index_t, NumDTensor > stride_Ds
The distance between consecutive elements of non-contiguous dimension (in memory) of Ds tensor.
Definition: universal_gemm_kernel.hpp:109
index_t M
GEMM's M dimension size.
Definition: universal_gemm_kernel.hpp:96
static constexpr auto I2
Definition: universal_gemm_kernel.hpp:238
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition: universal_gemm_kernel.hpp:853
static constexpr auto I3
Definition: universal_gemm_kernel.hpp:239
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: universal_gemm_kernel.hpp:754
static constexpr bool PersistentKernel
Definition: universal_gemm_kernel.hpp:217
static constexpr auto I1
Definition: universal_gemm_kernel.hpp:237
static CK_TILE_HOST auto BlockSize()
Definition: universal_gemm_kernel.hpp:290
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: universal_gemm_kernel.hpp:278
static constexpr auto I0
Definition: universal_gemm_kernel.hpp:236
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition: universal_gemm_kernel.hpp:373
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: universal_gemm_kernel.hpp:319
static constexpr index_t kBlockSize
Definition: universal_gemm_kernel.hpp:202
Definition: integral_constant.hpp:13
The Stream K GEMM kernel host arguments.
Definition: streamk_gemm_kernel.hpp:21
ck_tile::StreamKReductionStrategy reduction_strategy
Definition: streamk_gemm_kernel.hpp:48
CK_TILE_HOST StreamKHostArgs(const void *a_ptr_, const void *b_ptr_, void *c_ptr_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_, StreamKReductionStrategy reduction_strategy_)
Definition: streamk_gemm_kernel.hpp:22
ALayout and ADataType are expected to be scalars, not a tuple.
Definition: streamk_gemm_kernel.hpp:107
TilePartitioner tile_partitioner
An instance of the TilePartitioner class for assisting with mapping workgroups to the C tensor.
Definition: streamk_gemm_kernel.hpp:137
StreamKReductionStrategy reduction_strategy
The strategy used by work groups to compute final results in the C tensor.
Definition: streamk_gemm_kernel.hpp:131
void * workspace_ptr
A pointer to a buffer in device memory for accumulating partial results via the reduction strategy.
Definition: streamk_gemm_kernel.hpp:134
StreamKKernelArgs(const StreamKHostArgs &host_args, index_t grid)
Definition: streamk_gemm_kernel.hpp:108
The Stream K GEMM kernel class.
Definition: streamk_gemm_kernel.hpp:64
typename GemmPipeline::ALayout ALayout
Specify the layout configurations for A, B, and C.
Definition: streamk_gemm_kernel.hpp:82
static constexpr CK_TILE_HOST auto BlockSize() -> dim3
Definition: streamk_gemm_kernel.hpp:173
CK_TILE_DEVICE std::enable_if_t< U > operator()(StreamKKernelArgs kargs) const
Entry point for the Stream-K Kernel with persistent DP.
Definition: streamk_gemm_kernel.hpp:398
static CK_TILE_HOST auto GridSize(const TilePartitioner &tile_partitioner) -> dim3
Compute the grid size for the Stream K kernel using the tile_partitioner.
Definition: streamk_gemm_kernel.hpp:159
static constexpr bool is_tuple_v
Definition: streamk_gemm_kernel.hpp:92
static constexpr bool PersistentDP
Definition: streamk_gemm_kernel.hpp:71
EpiloguePipeline_ EpiloguePipeline
Definition: streamk_gemm_kernel.hpp:75
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: streamk_gemm_kernel.hpp:168
typename GemmPipeline::BDataType BDataType
Definition: streamk_gemm_kernel.hpp:88
CK_TILE_DEVICE void BaseGemm(StreamKKernelArgs &kargs, index_t tile_idx, index_t num_loop, index_t i_k_a, index_t i_k_b, index_t k_size, void *smem_ptr_0) const
Computes offsets into A, B, and C tensors then runs the GEMM pipeline and epilogue.
Definition: streamk_gemm_kernel.hpp:281
static CK_TILE_HOST const std::string GetName()
Definition: streamk_gemm_kernel.hpp:143
static CK_TILE_HOST bool IsSupportedArgument(const StreamKKernelArgs &kargs)
Definition: streamk_gemm_kernel.hpp:244
static constexpr index_t kBlockSize
Definition: streamk_gemm_kernel.hpp:70
CK_TILE_DEVICE std::enable_if_t<!U > operator()(StreamKKernelArgs kargs) const
Entry point for the Stream-K Kernel with non-persistent DP.
Definition: streamk_gemm_kernel.hpp:367
static CK_TILE_HOST uint32_t GetWorkSpaceSize(const StreamKKernelArgs &kargs)
Computes the buffer size needed to store accumulation results for Stream K.
Definition: streamk_gemm_kernel.hpp:259
static CK_TILE_HOST StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs &host_args, int num_cu=NumCU(), int occupancy=Occupancy())
Constructs kernel arguments for the Stream-K kernel.
Definition: streamk_gemm_kernel.hpp:185
CK_TILE_DEVICE void StreamKGemm(StreamKKernelArgs &kargs, index_t cta_idx, void *smem_ptr_0) const
Runs the main Stream-K algorithm.
Definition: streamk_gemm_kernel.hpp:310
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, UniversalGemmKernel::NumATensor > &as_ptr, const std::array< const BDataType *, UniversalGemmKernel::NumBTensor > &bs_ptr, const std::array< const void *, UniversalGemmKernel::NumDTensor > &ds_ptr, CDataType *c_ptr, void *smem_ptr_0, const typename UniversalGemmKernel::KernelArgs &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t k_size)
Definition: streamk_gemm_kernel.hpp:196
static CK_TILE_HOST void SetWorkSpacePointer(StreamKKernelArgs &kargs, void *workspace_ptr)
Sets the kargs' current workspace_ptr to the given workspace_ptr.
Definition: streamk_gemm_kernel.hpp:266
typename GemmPipeline::ADataType ADataType
Specify the data type configurations for A, B, and C.
Definition: streamk_gemm_kernel.hpp:87
typename GemmPipeline::BLayout BLayout
Definition: streamk_gemm_kernel.hpp:83
typename EpiloguePipeline::ODataType CDataType
Definition: streamk_gemm_kernel.hpp:89
GemmPipeline_ GemmPipeline
Definition: streamk_gemm_kernel.hpp:74
TilePartitioner_ TilePartitioner
Definition: streamk_gemm_kernel.hpp:73
typename GemmPipeline::CLayout CLayout
Definition: streamk_gemm_kernel.hpp:84
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Inject the UniversalGemmKernel base class to support execution of all necessary functions.
Definition: streamk_gemm_kernel.hpp:68
Definition: stream_config.hpp:30
Definition: tuple.hpp:192
#define CK_TILE_ENV(name)
Definition: env.hpp:145