Composable Kernel: include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp Source File
streamk_gemm_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
7 #include "ck_tile/ops/common.hpp"
10 
11 namespace ck_tile {
12 
22 struct StreamKHostArgs : public UniversalGemmHostArgs<>
23 {
24  CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
25  const void* b_ptr_,
26  void* c_ptr_,
27  index_t M_,
28  index_t N_,
29  index_t K_,
30  index_t stride_A_,
31  index_t stride_B_,
32  index_t stride_C_)
33  : UniversalGemmHostArgs<>({a_ptr_},
34  {b_ptr_},
35  {/*ds_ptr*/},
36  c_ptr_,
37  /*k_batch_ =*/1,
38  M_,
39  N_,
40  K_,
41  {stride_A_},
42  {stride_B_},
43  {/*stride_Ds_*/},
44  stride_C_)
45  {
46  }
47 };
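// Illustrative usage only (values and pointer names are hypothetical, not part of this
// file): the host fills StreamKHostArgs exactly like a regular GEMM problem descriptor,
// e.g. for a row-major 4096x4096x4096 problem with leading dimensions equal to the
// inner extents:
//
//   StreamKHostArgs host_args{a_dev, b_dev, c_dev,
//                             /*M=*/4096, /*N=*/4096, /*K=*/4096,
//                             /*stride_A=*/4096, /*stride_B=*/4096, /*stride_C=*/4096};
//
// Note that k_batch is hard-wired to 1 above; Stream-K performs its own K splitting
// through the tile partitioner instead of relying on split-K batching.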
48 
62 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
63 struct StreamKKernel
64 {
70  using UniversalGemmKernel = UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
71 
72  static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
73  static constexpr bool PersistentDP = UniversalGemmKernel::PersistentKernel;
74 
75  using TilePartitioner = TilePartitioner_;
76  using GemmPipeline = GemmPipeline_;
77  using EpiloguePipeline = EpiloguePipeline_;
78 
79  static_assert(
80  TilePartitioner::PERSISTENT == PersistentDP,
81  "Persistent flag from TilePartitioner must match Persistent flag from UniversalGemm.");
82 
86  using ALayout = typename GemmPipeline::ALayout;
87  using BLayout = typename GemmPipeline::BLayout;
88  using CLayout = typename GemmPipeline::CLayout;
89 
93  using ADataType = typename GemmPipeline::ADataType;
94  using BDataType = typename GemmPipeline::BDataType;
95  using CDataType = typename EpiloguePipeline::ODataType;
96  using AccDataType = typename EpiloguePipeline::AccDataType;
97 
98  template <typename T>
103  static_assert(!is_tuple_v<ALayout> && !is_tuple_v<ADataType>,
104  "ALayout and ADataType must be scalars.");
105 
109  static_assert(!is_tuple_v<BLayout> && !is_tuple_v<BDataType>,
110  "BLayout and BDataType must be scalars.");
111 
115  static_assert(!is_tuple_v<CLayout> && !is_tuple_v<CDataType>,
116  "CLayout and CDataType must be scalars.");
117 
119  {
120  StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid)
121  : UniversalGemmKernelArgs{host_args.as_ptr,
122  host_args.bs_ptr,
123  host_args.ds_ptr,
124  host_args.e_ptr,
125  host_args.M,
126  host_args.N,
127  host_args.K,
128  host_args.stride_As,
129  host_args.stride_Bs,
130  host_args.stride_Ds,
131  host_args.stride_E,
132  host_args.k_batch},
133  // The workspace pointer is set to nullptr because we must first
134  // instantiate the TilePartitioner to get the necessary size
135  workspace_ptr{nullptr},
136  tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
137 
138  {
139  }
144  void* workspace_ptr;
149  TilePartitioner tile_partitioner;
150  };
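// Note: workspace_ptr is intentionally left as nullptr at construction time. The
// intended host-side sequence (see the sketch after SetWorkSpacePointer below) is to
// build the kernel arguments first, query GetWorkSpaceSize() on them, allocate that
// many bytes of device memory, and only then bind the allocation via
// SetWorkSpacePointer().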
151 
154 
155  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
156  {
157  // clang-format off
158  using P_ = GemmPipeline;
159  using WarpTile = typename P_::BlockGemmShape::WarpTile;
160 
161  return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
162  concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
163  concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})),
164  concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
165  concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
166  // clang-format on
167  }
168 
173  CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
174  {
175  return tile_partitioner.grid_size();
176  }
177 
184  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
185  {
186  return UniversalGemmKernel::MaxOccupancyGridSize(s);
187  }
188 
189  CK_TILE_HOST static constexpr auto BlockSize() -> dim3
190  {
191  return dim3(kBlockSize);
192  }
193 
203  CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
204  int num_cu = NumCU(),
205  int occupancy = Occupancy())
206  {
207  const index_t grid = num_cu * occupancy;
208 
209  return StreamKKernelArgs{host_args, grid};
210  }
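// Illustrative arithmetic with hypothetical device numbers: on a GPU with 304 CUs and a
// measured occupancy of 1 block per CU, MakeKernelArgs constructs the TilePartitioner
// with grid = 304 * 1 = 304, and the partitioner then decides how many of those
// workgroups run ordinary data-parallel tiles and how many cooperate on Stream-K tiles.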
211 
212  template <bool UseDefaultScheduler = true>
213  CK_TILE_DEVICE static void
214  RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
215  const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
216  const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
217  CDataType* c_ptr,
218  void* smem_ptr_0,
219  const typename UniversalGemmKernel::KernelArgs& kargs,
220  const index_t num_loop,
221  const index_t block_idx_m,
222  const index_t block_idx_n,
223  const index_t k_size)
224  {
225  // Create block windows using specialized methods
226  const auto& as_block_window =
227  UniversalGemmKernel::MakeABlockWindows(as_ptr, kargs, k_size, block_idx_m);
228  const auto& bs_block_window =
229  UniversalGemmKernel::MakeBBlockWindows(bs_ptr, kargs, k_size, block_idx_n);
230  const auto& ds_block_window =
231  UniversalGemmKernel::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
232 
233  // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
234  // has_hot_loop and tail_num here. This is a pattern similar to the one used by grouped GEMM. In this
235  // case, we call the GemmPipeline's operator() function that takes both has_hot_loop and
236  // tail_num.
237  const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
238  const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
239 
240  // Run GEMM cooperatively by whole workgroup.
241  const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
242  bs_block_window[UniversalGemmKernel::I0],
243  num_loop,
244  has_hot_loop,
245  tail_num,
246  smem_ptr_0);
247 
248  if(UseDefaultScheduler || (get_warp_id() == 0))
249  {
250  // Run Epilogue Pipeline
251  auto c_block_window =
252  UniversalGemmKernel::template MakeCBlockWindows<TilePartitioner::MemoryOperation>(
253  c_ptr, kargs, block_idx_m, block_idx_n);
254 
255  EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
256  }
257  }
258 
259  CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
260  {
261  return UniversalGemmKernel::IsSupportedArgument(kargs);
262  }
263 
268  CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
269  {
270  return kargs.tile_partitioner.get_workspace_size(sizeof(AccDataType));
271  }
276  CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
277  {
278  kargs.workspace_ptr = workspace_ptr;
279  }
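// A minimal host-side sketch of the workspace flow, assuming the flag buffer must start
// zeroed before launch and using hypothetical names (Kernel is a concrete StreamKKernel
// instantiation, stream is a hipStream_t):
//
//   auto kargs = Kernel::MakeKernelArgs(host_args);
//   void* workspace = nullptr;
//   ck_tile::hip_check_error(hipMallocAsync(&workspace, Kernel::GetWorkSpaceSize(kargs), stream));
//   ck_tile::hip_check_error(hipMemsetAsync(workspace, 0, Kernel::GetWorkSpaceSize(kargs), stream));
//   Kernel::SetWorkSpacePointer(kargs, workspace);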
280 
293  CK_TILE_DEVICE void BaseGemm(StreamKKernelArgs& kargs,
294  index_t tile_idx,
295  index_t num_loop,
296  index_t i_k_a,
297  index_t i_k_b,
298  index_t k_size,
299  void* smem_ptr_0) const
300  {
301  const auto c_macro_tile_idx = kargs.tile_partitioner.get_output_tile_index(tile_idx);
302  index_t i_m = c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
303  index_t i_n = c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
304 
305  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
306  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
307  CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
308 
309  // Run the GEMM pipeline and Epilogue.
310  RunGemm(
311  {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
312  }
313 
321  CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
322  index_t cta_idx) const
323  {
324  auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
325  index_t offset = cta_idx * sizeof(index_t);
326 
327  asm volatile("s_mov_b32 m0, %2\n\t"
328  // Depending on the architecture, the GLC flag will bypass the appropriate
329  // cache level(s) to ensure the write is visible to other workgroups. See the
330  // appropriate ISA for details about the GLC modifier.
331  "s_store_dword %0, %1, %2 glc\n\t"
332  "s_waitcnt lgkmcnt(0)" // Wait for the store to complete
333  :
334  : "s"(1), "s"(sk_flags_ptr), "s"(offset)
335  : "memory");
336  }
337 
345  CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
346  {
347  auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
348  index_t result;
349  index_t offset = cta_idx * sizeof(index_t);
350 
351  do
352  {
353  asm volatile("s_mov_b32 m0, %2\n\t"
354  // Depending on the architecture, the GLC flag will bypass the
355  // appropriate cache level(s) to avoid reading stale flags. See the
356  // appropriate ISA for details about the GLC modifier.
357  "s_load_dword %0, %1, %2 glc\n\t"
358  "s_waitcnt lgkmcnt(0)" // Wait for the load to complete
359  : "=s"(result)
360  : "s"(sk_flags_ptr), "s"(offset)
361  : "memory");
362  } while(result != 1);
363  }
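// SignalStorePartialDone/WaitStorePartialDone form a one-way flag handshake between
// Stream-K CTAs: the producer writes 1 into its slot of the flag buffer with a
// cache-bypassing scalar store, and the consumer spins on a cache-bypassing scalar load
// until it observes that 1. As a rough, hypothetical C++-level analogue (not the code
// used here, which deliberately stays on the scalar path with the GLC modifier), the
// same protocol could be written with HIP atomics at agent scope:
//
//   // producer, after its partial tile is visible in global memory:
//   __hip_atomic_store(&flags[cta_idx], 1, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
//   // consumer, before reading that partial tile:
//   while(__hip_atomic_load(&flags[cta_idx], __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT) != 1) {}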
364 
372  template <typename OAccTile>
373  CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
374  const OAccTile& in_block_tile) const
375  {
376  using BlockType = remove_cvref_t<decltype(in_out_block_tile)>;
377  constexpr auto o_spans = BlockType::get_distributed_spans();
378  sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
379  sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
380  constexpr auto idx = make_tuple(idx0, idx1);
381  in_out_block_tile(idx) = in_out_block_tile[idx] + in_block_tile[idx];
382  });
383  });
384  }
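// Conceptually this is a plain elementwise accumulation over the distributed block tile:
// for every (i, j) element owned by the calling thread,
//   in_out_block_tile(i, j) += in_block_tile(i, j);
// sweep_tile_span only enumerates the slice of the MPerBlock x NPerBlock tile that this
// thread holds in registers.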
385 
395  template <typename DataType, typename OAccTileDist>
396  CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
397  index_t cta_idx,
398  const OAccTileDist& c_block_tile_dist) const
399  {
400  const auto c_block_tile_buffer_size =
401  TilePartitioner::MPerBlock * TilePartitioner::NPerBlock * sizeof(DataType);
402  void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
403  kargs.tile_partitioner.get_flags_buffer_size() +
404  cta_idx * c_block_tile_buffer_size;
405 
406  const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
407  static_cast<DataType*>(partial_buffer_ptr),
409  make_tuple(TilePartitioner::NPerBlock, 1),
410  number<GemmPipeline::GetVectorSizeC()>{},
411  number<1>{});
412 
413  auto partial_tile_window = make_tile_window(
414  partial_tensor_view,
416  {0, 0},
417  c_block_tile_dist);
418 
419  return load_tile(partial_tile_window);
420  }
421 
430  template <typename OAccTile>
431  CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
432  index_t cta_idx,
433  const OAccTile& c_block_tile) const
434  {
435  const auto c_block_tile_buffer_size = TilePartitioner::MPerBlock *
436  TilePartitioner::NPerBlock *
437  sizeof(typename OAccTile::DataType);
438  void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
439  kargs.tile_partitioner.get_flags_buffer_size() +
440  cta_idx * c_block_tile_buffer_size;
441 
442  const auto& partial_tensor_view = make_naive_tensor_view<
443  address_space_enum::global,
444  memory_operation_enum::set,
445  StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE>(
446  static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
448  make_tuple(TilePartitioner::NPerBlock, 1),
449  number<GemmPipeline::GetVectorSizeC()>{},
450  number<1>{});
451 
452  auto partial_tile_window = make_tile_window(
453  partial_tensor_view,
455  {0, 0});
456  store_tile(partial_tile_window, c_block_tile);
457  // Wait for all vector stores for this wavefront to complete
458  s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
459  // Wait for all wavefronts in this workgroup to arrive here before continuing
460  __builtin_amdgcn_s_barrier();
461  }
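// Workspace layout assumed by StorePartial/LoadPartial: a per-CTA flag region of
// get_flags_buffer_size() bytes, followed by one MPerBlock x NPerBlock accumulator tile
// per Stream-K CTA. With hypothetical tile sizes MPerBlock = NPerBlock = 128 and a
// float accumulator, each partial tile occupies 128 * 128 * 4 = 65536 bytes, so CTA k
// stores its partials at
//   workspace_ptr + get_flags_buffer_size() + k * 65536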
462 
473  void StreamKGemm(StreamKKernelArgs& kargs, index_t cta_idx, void* smem_ptr_0) const
474  {
475  index_t iter_start, iter_end;
476  kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, cta_idx);
477 
478  while(iter_start < iter_end)
479  {
480  // Get the 1D tile index in the C tensor that this workgroup will work in for this
481  // iteration of the loop.
482  index_t tile_idx =
483  amd_wave_read_first_lane(kargs.tile_partitioner.get_tile_index(iter_start));
484 
485  // Get the start and end boundaries for the current tile.
486  index_t tile_iter_start, tile_iter_end;
487  kargs.tile_partitioner.get_tile_boundaries(tile_iter_start, tile_iter_end, tile_idx);
488 
489  // Get the start and end iteration within the current tile for the workgroup.
490  index_t local_iter_start = amd_wave_read_first_lane(
491  kargs.tile_partitioner.get_local_iter(iter_start, tile_iter_start));
492  index_t local_iter_end =
493  amd_wave_read_first_lane(kargs.tile_partitioner.get_local_iter_end(
494  tile_iter_start, iter_end, tile_iter_end));
495 
496  // Get the iteration length.
497  index_t num_loop_sk = local_iter_end - local_iter_start;
498 
499  // Determine the total size along the K dimension the workgroup is using in this
500  // iteration (used to construct tensor views).
501  index_t k_size = num_loop_sk * TilePartitioner::KPerBlock;
502 
503  // Get the K offsets for the A and B tensors
504  auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(
505  local_iter_start, kargs.stride_As[0], kargs.stride_Bs[0]);
506 
507  if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Atomic)
508  {
509  BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
510  }
511  else if(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Reduction ||
512  TilePartitioner::ReductionStrategy == StreamKReductionStrategy::TreeReduction)
513  {
514  const auto c_macro_tile_idx =
515  kargs.tile_partitioner.get_output_tile_index(tile_idx);
516  index_t i_m =
517  c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
518  index_t i_n =
519  c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
520 
521  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
522  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
523  CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
524 
525  // Create block windows using specialized methods
526  const auto& as_block_window =
527  UniversalGemmKernel::MakeABlockWindows({a_ptr}, kargs, k_size, i_m);
528  const auto& bs_block_window =
529  UniversalGemmKernel::MakeBBlockWindows({b_ptr}, kargs, k_size, i_n);
530  const auto& ds_block_window =
531  UniversalGemmKernel::MakeDBlockWindows({/*ds_ptr*/}, kargs, i_m, i_n);
532 
533  // Since num_loop can vary per WG and per iteration of the Stream-K while loop,
534  // we compute has_hot_loop and tail_num here. This is a pattern similar to the one
535  // used by grouped GEMM. In this case, we call the GemmPipeline's operator() function
536  // that takes both has_hot_loop and tail_num.
537  const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk);
538  const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk);
539 
540  // Run GEMM cooperatively by whole workgroup.
541  const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
542  bs_block_window[UniversalGemmKernel::I0],
543  num_loop_sk,
544  has_hot_loop,
545  tail_num,
546  smem_ptr_0);
547 
548  auto tile_started = iter_start == tile_iter_start;
549  auto tile_ended = iter_end >= tile_iter_end;
550 
551  if constexpr(TilePartitioner::ReductionStrategy ==
552  StreamKReductionStrategy::Reduction)
553  {
554  if(!tile_started)
555  {
556  StorePartial(kargs, cta_idx, c_block_tile);
557  SignalStorePartialDone(kargs, cta_idx);
558  }
559  else
560  {
561  auto accum_block_tile = c_block_tile;
562  if(!tile_ended)
563  {
564  const index_t iter_per_tile =
565  kargs.tile_partitioner.get_iters_per_tile();
566  const index_t iter_per_cta =
567  kargs.tile_partitioner.get_iters_per_sk_cta();
568  const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
569  int accum_iters = local_iter_end - local_iter_start;
570  int next_cta = cta_idx + 1;
571 
572  while(accum_iters < iter_per_tile)
573  {
574  WaitStorePartialDone(kargs, next_cta);
575 
576  using BlockType = remove_cvref_t<decltype(c_block_tile)>;
577  AddBlockTile(
578  accum_block_tile,
579  LoadPartial<typename BlockType::DataType>(
580  kargs, next_cta, c_block_tile.get_tile_distribution()));
581 
582  accum_iters += iter_per_cta + (next_cta < extra_iters);
583  ++next_cta;
584  }
585  }
586 
587  auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows<
588  TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n);
589  EpiloguePipeline{}(
590  c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
591  }
592  }
593  else // Tree Reduction
594  {
595  auto accum_block_tile = c_block_tile;
596  index_t tile_local_cta_idx =
597  kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx);
598 
599  for(index_t stride = 1;; stride <<= 1)
600  {
601  const index_t partner_cta_idx = cta_idx + stride;
602  const index_t partner_start_iter =
603  kargs.tile_partitioner.get_start_iter(partner_cta_idx);
604  bool partner_in_tile = partner_start_iter < tile_iter_end;
605 
606  // If the partner of the workgroup that started the tile is not in this tile,
607  // then the work for this tile is done and results can be stored in the C
608  // tensor.
609  if(tile_started && !partner_in_tile)
610  {
611  auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows<
612  TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n);
613  EpiloguePipeline{}(
614  c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
615  break;
616  }
617 
618  // It's this workgroup's turn to read from partials.
619  if(tile_local_cta_idx % (stride << 1) == 0)
620  {
621  // If this workgroup's partner is in the tile then it can read from
622  // partials and accumulate results.
623  if(partner_in_tile)
624  {
625  WaitStorePartialDone(kargs, partner_cta_idx);
626  using BlockType = remove_cvref_t<decltype(c_block_tile)>;
627  AddBlockTile(accum_block_tile,
628  LoadPartial<typename BlockType::DataType>(
629  kargs,
630  partner_cta_idx,
631  c_block_tile.get_tile_distribution()));
632  }
633  }
634  // Otherwise, it's this workgroup's turn to write to partials. All
635  // workgroups, except the workgroup that starts the tile, will write to
636  // partials.
637  else
638  {
639  StorePartial(kargs, cta_idx, accum_block_tile);
640  SignalStorePartialDone(kargs, cta_idx);
641  // Once the workgroup writes to partials, it has no more work to do for
642  // this tile.
643  break;
644  }
645  }
646  }
647  }
648  else
649  {
650  static_assert(
651  "An implementation does not exist for the chosen reduction strategy.");
652  }
653 
654  // Prepare for next Stream-K loop iteration.
655  iter_start = tile_iter_end;
656  block_sync_lds();
657  }
658  }
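// A worked, purely illustrative split (the real one is owned by the TilePartitioner):
// suppose 8 output tiles are left for the Stream-K phase, each needing
// iters_per_tile = 16 K-iterations, shared by 6 Stream-K CTAs. The 8 * 16 = 128
// iterations are handed out as contiguous ranges of 21 or 22 iterations per CTA, so most
// CTAs start and/or finish in the middle of a tile. Under Atomic, every CTA writes its
// contribution straight into C with atomic updates; under Reduction/TreeReduction, the
// CTA that owns the start of a tile completes it by accumulating the partial tiles that
// the other contributing CTAs stored in the workspace.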
659 
669  template <bool U = PersistentDP>
670  CK_TILE_DEVICE typename std::enable_if_t<!U> operator()(StreamKKernelArgs kargs) const
671  {
672  // Allocate LDS
673  __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
674 
675  index_t block_idx = ck_tile::get_block_1d_id();
676  index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
677  index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas();
678  bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas();
679 
680  // Check whether this workgroup belongs to the data-parallel section
681  if(is_dp_ctas)
682  {
683  BaseGemm(kargs, block_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
684  }
685  else
686  {
687  // Stream-K
688  StreamKGemm(kargs, block_idx - dp_ctas, smem_ptr_0);
689  }
690  }
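// With the non-persistent data-parallel front end the grid is split purely by block id:
// blocks [0, dp_ctas) each compute one whole output tile with a full K loop
// (k_size = kargs.K), while blocks [dp_ctas, grid) are re-indexed from 0 and execute the
// cooperative Stream-K loop above.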
691 
702  template <bool U = PersistentDP>
703  CK_TILE_DEVICE typename std::enable_if_t<U> operator()(StreamKKernelArgs kargs) const
704  {
705  // Allocate LDS
706  __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
707 
708  index_t block_idx = ck_tile::get_block_1d_id();
709  index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
710 
711  // Data-parallel section
712  for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
713  tile_idx += kargs.tile_partitioner.get_grid())
714  {
715  BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
716  block_sync_lds();
717  }
718 
719  // Stream-K section
720  StreamKGemm(kargs, block_idx, smem_ptr_0);
721  }
722 
723  private:
732  template <typename ALayout, typename BLayout>
733  CK_TILE_DEVICE static auto
734  GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
735  {
736  index_t stride_offset_a;
737  index_t stride_offset_b;
738  if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
739  {
740  stride_offset_a = stride_a;
741  }
742  else
743  {
744  stride_offset_a = 1;
745  }
746 
747  if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
748  {
749  stride_offset_b = stride_b;
750  }
751  else
752  {
753  stride_offset_b = 1;
754  }
755 
756  index_t base_offset = iter_offset * TilePartitioner::KPerBlock;
757 
758  return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b);
759  }
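// Worked example with hypothetical values: for KPerBlock = 32 and iter_offset = 4, the
// workgroup starts 4 * 32 = 128 K-elements into the tile's K range. For a row-major A
// (K contiguous) that is an offset of 128 elements; for a column-major A it becomes
// 128 * stride_a elements. B mirrors this with the RowMajor and ColumnMajor roles
// swapped.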
760 
761  CK_TILE_HOST static int NumCU()
762  {
763  hipDeviceProp_t dev_prop;
764  hipDevice_t dev;
765  ck_tile::hip_check_error(hipGetDevice(&dev));
766  ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
767  int num_cu = dev_prop.multiProcessorCount;
768 
769  return num_cu;
770  }
771 
779  CK_TILE_HOST static int Occupancy()
780  {
781  int occupancy;
782 
783  // Since an occupancy of 1 is valid for Stream-K, we set min_block_per_cu to 1
784  constexpr int min_block_per_cu = 1;
785  const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
786 
787  ck_tile::hip_check_error(
788  hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
789 
790  return max(occupancy, 1);
791  }
792 };
793 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
Definition: cluster_descriptor.hpp:13
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
TailNumber
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ TreeReduction
Definition: streamk_common.hpp:13
CK_TILE_HOST void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:13
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_view.hpp:486
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:24
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:36
@ Atomic
Definition: block_to_ctile_map.hpp:1012
@ Reduction
Definition: block_to_ctile_map.hpp:1013
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
unsigned int uint32_t
Definition: stdint.h:126
Definition: streamk_gemm_coherency.hpp:10
The Stream K GEMM kernel host arguments.
Definition: streamk_gemm_kernel.hpp:23
CK_TILE_HOST StreamKHostArgs(const void *a_ptr_, const void *b_ptr_, void *c_ptr_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
Definition: streamk_gemm_kernel.hpp:24
ALayout and ADataType are expected to be scalars, not a tuple.
Definition: streamk_gemm_kernel.hpp:119
void * workspace_ptr
A pointer to a buffer in device memory for accumulating partial results via the reduction strategy.
Definition: streamk_gemm_kernel.hpp:144
TilePartitioner tile_partitioner
An instance of the TilePartitioner class for assisting with mapping workgroups to the C tensor.
Definition: streamk_gemm_kernel.hpp:149
StreamKKernelArgs(const StreamKHostArgs &host_args, index_t grid)
Definition: streamk_gemm_kernel.hpp:120
The Stream K GEMM kernel class.
Definition: streamk_gemm_kernel.hpp:64
CK_TILE_DEVICE void StreamKGemm(StreamKKernelArgs &kargs, index_t cta_idx, void *smem_ptr_0) const
Runs the main Stream-K algorithm.
Definition: streamk_gemm_kernel.hpp:473
GemmPipeline_ GemmPipeline
Definition: streamk_gemm_kernel.hpp:76
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Inject the UniversalGemmKernel base class to support execution of all necessary functions.
Definition: streamk_gemm_kernel.hpp:70
typename GemmPipeline::BLayout BLayout
Definition: streamk_gemm_kernel.hpp:87
typename GemmPipeline::ALayout ALayout
Specify the layout configurations for A, B, and C.
Definition: streamk_gemm_kernel.hpp:86
typename EpiloguePipeline::ODataType CDataType
Definition: streamk_gemm_kernel.hpp:95
static CK_TILE_HOST auto GridSize(const TilePartitioner &tile_partitioner) -> dim3
Compute the grid size for the Stream K kernel using the tile_partitioner.
Definition: streamk_gemm_kernel.hpp:173
CK_TILE_DEVICE std::enable_if_t< U > operator()(StreamKKernelArgs kargs) const
Entry point for the Stream-K Kernel with persistent DP.
Definition: streamk_gemm_kernel.hpp:703
typename GemmPipeline::CLayout CLayout
Definition: streamk_gemm_kernel.hpp:88
static constexpr bool is_tuple_v
Definition: streamk_gemm_kernel.hpp:99
static CK_TILE_HOST StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs &host_args, int num_cu=NumCU(), int occupancy=Occupancy())
Constructs kernel arguments for the Stream-K kernel.
Definition: streamk_gemm_kernel.hpp:203
typename GemmPipeline::BDataType BDataType
Definition: streamk_gemm_kernel.hpp:94
TilePartitioner_ TilePartitioner
Definition: streamk_gemm_kernel.hpp:75
static constexpr bool PersistentDP
Definition: streamk_gemm_kernel.hpp:73
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, UniversalGemmKernel::NumATensor > &as_ptr, const std::array< const BDataType *, UniversalGemmKernel::NumBTensor > &bs_ptr, const std::array< const void *, UniversalGemmKernel::NumDTensor > &ds_ptr, CDataType *c_ptr, void *smem_ptr_0, const typename UniversalGemmKernel::KernelArgs &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t k_size)
Definition: streamk_gemm_kernel.hpp:214
typename GemmPipeline::ADataType ADataType
Specify the data type configurations for A, B, and C.
Definition: streamk_gemm_kernel.hpp:93
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: streamk_gemm_kernel.hpp:184
static CK_TILE_HOST void SetWorkSpacePointer(StreamKKernelArgs &kargs, void *workspace_ptr)
Sets the kargs' current workspace_ptr to the given workspace_ptr.
Definition: streamk_gemm_kernel.hpp:276
CK_TILE_DEVICE void AddBlockTile(OAccTile &in_out_block_tile, const OAccTile &in_block_tile) const
Adds the values of a block tile to an output block tile.
Definition: streamk_gemm_kernel.hpp:373
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs &kargs, index_t cta_idx) const
Signals that the current thread block (CTA) has completed storing its partial results.
Definition: streamk_gemm_kernel.hpp:321
static constexpr index_t kBlockSize
Definition: streamk_gemm_kernel.hpp:72
typename EpiloguePipeline::AccDataType AccDataType
Definition: streamk_gemm_kernel.hpp:96
EpiloguePipeline_ EpiloguePipeline
Definition: streamk_gemm_kernel.hpp:77
static CK_TILE_HOST const std::string GetName()
Definition: streamk_gemm_kernel.hpp:155
static constexpr CK_TILE_HOST auto BlockSize() -> dim3
Definition: streamk_gemm_kernel.hpp:189
static CK_TILE_HOST bool IsSupportedArgument(const StreamKKernelArgs &kargs)
Definition: streamk_gemm_kernel.hpp:259
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs &kargs, index_t cta_idx, const OAccTile &c_block_tile) const
Stores a partial block tile to the workspace buffer.
Definition: streamk_gemm_kernel.hpp:431
CK_TILE_DEVICE void BaseGemm(StreamKKernelArgs &kargs, index_t tile_idx, index_t num_loop, index_t i_k_a, index_t i_k_b, index_t k_size, void *smem_ptr_0) const
Computes offsets into A, B, and C tensors then runs the GEMM pipeline and epilogue.
Definition: streamk_gemm_kernel.hpp:293
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs &kargs, index_t cta_idx) const
Waits for the thread block (cta_idx) to complete storing its partial results.
Definition: streamk_gemm_kernel.hpp:345
static CK_TILE_HOST uint32_t GetWorkSpaceSize(const StreamKKernelArgs &kargs)
Computes the buffer size needed to store accumulation results for Stream K.
Definition: streamk_gemm_kernel.hpp:268
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs &kargs, index_t cta_idx, const OAccTileDist &c_block_tile_dist) const
Loads a partial block tile from the workspace buffer.
Definition: streamk_gemm_kernel.hpp:396
CK_TILE_DEVICE std::enable_if_t<!U > operator()(StreamKKernelArgs kargs) const
Entry point for the Stream-K Kernel with non-persistent DP.
Definition: streamk_gemm_kernel.hpp:670
The Universal GEMM kernel host arguments.
Definition: universal_gemm_kernel.hpp:32
The GEMM kernel device arguments.
Definition: universal_gemm_kernel.hpp:86
void * e_ptr
The E output tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:94
std::array< index_t, NumBTensor > stride_Bs
The distance between consecutive elements of non-contiguous dimension (in memory) of Bs tensor.
Definition: universal_gemm_kernel.hpp:106
const std::array< const void *, NumDTensor > ds_ptr
The Ds input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:92
std::array< index_t, NumATensor > stride_As
The distance between consecutive elements of non-contiguous dimension (in memory) of As tensor.
Definition: universal_gemm_kernel.hpp:103
const std::array< const void *, NumATensor > as_ptr
The As input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:88
index_t k_batch
Definition: universal_gemm_kernel.hpp:113
index_t N
GEMM's N dimension size.
Definition: universal_gemm_kernel.hpp:98
index_t stride_E
The distance between consecutive elements of non-contiguous dimension (in memory) of E tensor.
Definition: universal_gemm_kernel.hpp:112
index_t K
GEMM's K dimension size.
Definition: universal_gemm_kernel.hpp:100
const std::array< const void *, NumBTensor > bs_ptr
The Bs input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:90
std::array< index_t, NumDTensor > stride_Ds
The distance between consecutive elements of non-contiguous dimension (in memory) of Ds tensor.
Definition: universal_gemm_kernel.hpp:109
index_t M
GEMM's M dimension size.
Definition: universal_gemm_kernel.hpp:96
static CK_TILE_DEVICE auto MakeBBlockWindows(const std::array< const BDataType *, NumBTensor > &bs_ptr, const KernelArgs &kargs, const index_t k_size, const index_t i_n)
Definition: universal_gemm_kernel.hpp:742
static constexpr bool PersistentKernel
Definition: universal_gemm_kernel.hpp:217
static constexpr auto I1
Definition: universal_gemm_kernel.hpp:237
static CK_TILE_DEVICE auto MakeDBlockWindows(const std::array< const void *, NumDTensor > &ds_ptr, const KernelArgs &kargs, const index_t i_m, const index_t i_n)
Definition: universal_gemm_kernel.hpp:893
static CK_TILE_HOST auto BlockSize()
Definition: universal_gemm_kernel.hpp:292
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: universal_gemm_kernel.hpp:280
static constexpr auto I0
Definition: universal_gemm_kernel.hpp:236
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition: universal_gemm_kernel.hpp:400
static CK_TILE_DEVICE auto MakeABlockWindows(const std::array< const ADataType *, NumATensor > &as_ptr, const KernelArgs &kargs, const index_t k_size, const index_t i_m)
Definition: universal_gemm_kernel.hpp:665
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: universal_gemm_kernel.hpp:321
static constexpr index_t kBlockSize
Definition: universal_gemm_kernel.hpp:202
Definition: integral_constant.hpp:13
Definition: coordinate_transform.hpp:1392
Definition: stream_config.hpp:30
Definition: tuple.hpp:192