Composable Kernel: include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp Source File

grouped_gemm_kernel.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
#include "ck_tile/host.hpp"

#include <hip/hip_runtime.h>

namespace ck_tile {

/// @brief The Grouped GEMM kernel host arguments.
template <index_t NumDTensor = 0>
struct GroupedGemmHostArgs
{
    CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
                                     const void* b_ptr_,
                                     const std::array<const void*, NumDTensor>& ds_ptr_,
                                     void* e_ptr_,
                                     index_t k_batch_,
                                     index_t M_,
                                     index_t N_,
                                     index_t K_,
                                     index_t stride_A_,
                                     index_t stride_B_,
                                     const std::array<index_t, NumDTensor>& stride_Ds_,
                                     index_t stride_E_)
        : a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          ds_ptr(ds_ptr_),
          e_ptr(e_ptr_),
          M(M_),
          N(N_),
          K(K_),
          stride_A(stride_A_),
          stride_B(stride_B_),
          stride_Ds(stride_Ds_),
          stride_E(stride_E_),
          k_batch(k_batch_)
    {
    }

    const void* a_ptr;
    const void* b_ptr;
    const std::array<const void*, NumDTensor> ds_ptr;
    union
    {
        void* e_ptr;
        void* c_ptr;
    };

    index_t M;
    index_t N;
    index_t K;
    index_t stride_A;
    index_t stride_B;
    const std::array<index_t, NumDTensor> stride_Ds;
    union
    {
        index_t stride_E;
        index_t stride_C;
    };

    index_t k_batch;
};
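
The host builds one GroupedGemmHostArgs per GEMM in the group and hands the whole vector to the kernel helpers below. The following is a minimal sketch of filling a single descriptor, assuming NumDTensor == 0 and layouts whose leading dimensions happen to be K for A and B and N for E; every pointer, shape, and stride here is a placeholder, not a value taken from this header.

#include <array>
#include <vector>
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"

// Hypothetical device buffers and problem size for one group entry.
std::vector<ck_tile::GroupedGemmHostArgs<>> make_descs(const void* a_dev, const void* b_dev, void* e_dev)
{
    const ck_tile::index_t M = 256, N = 512, K = 1024; // placeholder shapes

    std::vector<ck_tile::GroupedGemmHostArgs<>> descs;
    descs.emplace_back(a_dev,                             // a_ptr
                       b_dev,                             // b_ptr
                       std::array<const void*, 0>{},      // ds_ptr (no D tensors)
                       e_dev,                             // e_ptr
                       1,                                 // k_batch (no split-K)
                       M, N, K,
                       K,                                 // stride_A (assumed row-major M x K)
                       K,                                 // stride_B (assumed leading dimension K)
                       std::array<ck_tile::index_t, 0>{}, // stride_Ds
                       N);                                // stride_E (assumed row-major M x N)
    return descs;
}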

template <index_t NumDTensor = 0>
struct GemmTransKernelArg
{
    UniversalGemmKernelArgs<1, 1, NumDTensor> group_karg;
    ck_tile::index_t block_start;
    ck_tile::index_t block_end;

    GemmTransKernelArg() = delete;
    GemmTransKernelArg(UniversalGemmKernelArgs<1, 1, NumDTensor>&& karg,
                       index_t bl_start,
                       index_t bl_end)
        : group_karg{std::move(karg)}, block_start{bl_start}, block_end{bl_end}
    {
    }

    GemmTransKernelArg(UniversalGemmKernelArgs<1, 1, NumDTensor>&& karg)
        : group_karg{std::move(karg)}, block_start{0}, block_end{0}
    {
    }
};

template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GroupedGemmKernel
    : public UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
{
    using Base = UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

    using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
    using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
    using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;

    /// @brief Specify the data type configurations for A, B, C/E.
    using ADataType  = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType  = remove_cvref_t<typename GemmPipeline::BDataType>;
    using CDataType  = remove_cvref_t<typename EpiloguePipeline::ODataType>;
    using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;

    static constexpr index_t NumDTensor_ = DsDataType::size();

    static_assert(
        !is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
        "ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");

    static_assert(
        !is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
        "BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");

    static_assert(!is_detected<is_tuple, CLayout>::value &&
                      !is_detected<is_tuple, CDataType>::value,
                  "C/CLayout and C/EDataType must be scalars.");

    using Kernel                  = GroupedGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
    using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;

    static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
    static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        using P_ = GemmPipeline;

        return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
                      concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
                      concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
                      concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
                      (UsePersistentKernel ? "Persistent" : "NonPersistent"),
                      (NumDTensor_ == 2 ? "MultiD" : "NoMultiD"),
                      (GemmPipeline::DoubleSmemBuffer ? "DoubleSmemBuffer" : "SingleSmemBuffer"));
        // clang-format on
    }

    CK_TILE_HOST static auto
    GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs<>>& gemm_descs) -> std::size_t
    {
        return gemm_descs.size() * sizeof(GemmTransKernelArg<NumDTensor_>);
    }

    CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
    {
        return group_count * sizeof(GemmTransKernelArg<NumDTensor_>);
    }

    CK_TILE_HOST static auto BlockSize() -> dim3
    {
        if(is_wave32())
        {
            return dim3(kBlockSize / 2);
        }
        else
        {
            return dim3(kBlockSize);
        }
    }

    /// @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
    CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
    {
        using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
        const auto kernel     = kentry<1, Kernel, ConstantPointer, index_t>;
        int occupancy;
        HIP_CHECK_ERROR(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
        const int grid_size = get_available_compute_units(s) * occupancy;
        return dim3(grid_size, 1, 1);
    }

    CK_TILE_HOST static auto
    GridSize(const std::vector<GroupedGemmHostArgs<NumDTensor_>>& gemm_descs)
    {
        index_t grid_size = 0;
        for(const auto& it_desc : gemm_descs)
        {
            const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
            grid_size += local_grid_size * it_desc.k_batch;
        }
        return dim3(grid_size, 1, 1);
    }
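
GridSize flattens every group's 2D tile grid, multiplied by its k_batch split, into one 1D launch grid. A self-contained sketch of the same arithmetic, assuming (as an illustration only) that TilePartitioner::GridSize is a ceiling division by a hypothetical 128x128 output tile:

#include <cstdio>

int main()
{
    constexpr int MPerBlock = 128, NPerBlock = 128; // hypothetical tile shape

    // One group with M = 256, N = 512 and k_batch = 2 (placeholder values).
    const int tiles_m = (256 + MPerBlock - 1) / MPerBlock; // 2
    const int tiles_n = (512 + NPerBlock - 1) / NPerBlock; // 4
    const int blocks  = tiles_m * tiles_n * 2;             // 16 workgroups for this group

    std::printf("workgroups contributed by this group: %d\n", blocks);
    return 0;
}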

    CK_TILE_HOST static auto
    MakeKargs(const std::vector<GroupedGemmHostArgs<NumDTensor_>>& gemm_descs)
        -> std::vector<GemmTransKernelArg<NumDTensor_>>
    {
        std::vector<GemmTransKernelArg<NumDTensor_>> gemm_kernel_args_;
        index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
        index_t grid_size   = 0;
        gemm_kernel_args_.reserve(group_count);

        for(std::size_t i = 0; i < gemm_descs.size(); ++i)
        {
            const index_t M = gemm_descs[i].M;
            const index_t N = gemm_descs[i].N;
            const index_t K = gemm_descs[i].K;

            if(M == 0 || N == 0 || K == 0)
            {
                continue;
            }

            const index_t stride_a = gemm_descs[i].stride_A;
            const index_t stride_b = gemm_descs[i].stride_B;
            const index_t stride_e = gemm_descs[i].stride_E;
            auto stride_ds         = gemm_descs[i].stride_Ds;

            const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;

            const index_t block_start = grid_size;
            const index_t block_end   = grid_size + grid_size_grp;

            grid_size += grid_size_grp;

            auto karg = UniversalGemmKernelArgs<1, 1, NumDTensor_>{
                {type_convert<const ADataType*>(gemm_descs[i].a_ptr)},
                {type_convert<const BDataType*>(gemm_descs[i].b_ptr)},
                {gemm_descs[i].ds_ptr},
                type_convert<CDataType*>(gemm_descs[i].e_ptr),
                M,
                N,
                K,
                {stride_a},
                {stride_b},
                stride_ds,
                stride_e,
                gemm_descs[i].k_batch};

            gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
        }

        return gemm_kernel_args_;
    }
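
MakeKargs only assembles the per-group GemmTransKernelArg entries on the host; before launch they still have to be copied into a device buffer of at least GetWorkSpaceSize bytes, whose pointer is what the kernel's operator() receives as gemm_descs_const. Below is a hedged staging sketch using plain HIP calls, assuming a kernel without extra D tensors; error handling and the launch itself (which go through the library's usual helpers) are omitted.

#include <hip/hip_runtime.h>
#include <cstddef>
#include <vector>
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"

// `Kernel` stands for a concrete GroupedGemmKernel instantiation chosen elsewhere.
template <typename Kernel>
void* stage_kernel_args(const std::vector<ck_tile::GroupedGemmHostArgs<>>& gemm_descs)
{
    const auto kargs        = Kernel::MakeKargs(gemm_descs); // skips empty (M/N/K == 0) groups
    const std::size_t bytes = kargs.size() * sizeof(typename decltype(kargs)::value_type);

    void* kargs_dev = nullptr;
    (void)hipMalloc(&kargs_dev, bytes);
    (void)hipMemcpy(kargs_dev, kargs.data(), bytes, hipMemcpyHostToDevice);
    return kargs_dev; // passed to the kernel together with the (possibly reduced) group count
}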

    CK_TILE_HOST static bool
    IsSupportedArgument(const std::vector<GemmTransKernelArg<NumDTensor_>>& kargs)
    {
        for(const auto& karg : kargs)
        {
            if(!Base::IsSupportedArgument(karg.group_karg))
            {
                return false;
            }
        }
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
                            const tuple<index_t, index_t>& block_idx_2d,
                            const index_t block_idx_z) const
    {

        static_assert(GemmPipeline::DoubleSmemBuffer || !GemmPipeline::Preshuffle,
                      "SingleSmemBuffer and Preshuffle cannot both be enabled simultaneously!");

        const auto [iM, iN] = block_idx_2d;

        const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);

        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);

        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) +
                                 splitk_batch_offset.as_k_split_offset[0];
        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) +
                                 splitk_batch_offset.bs_k_split_offset[0];
        CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);

        // allocate LDS
        __shared__ char smem_ptr_0[GetSmemSize()];

        // TO DO:
        // Can we simplify this branching logic?
        if constexpr(GemmPipeline::DoubleSmemBuffer == true)
        {

            __shared__ char smem_ptr_1[GetSmemSize()];
            RunGemmWithPipelineSelection2LDS(a_ptr,
                                             b_ptr,
                                             c_ptr,
                                             kargs.ds_ptr,
                                             smem_ptr_0,
                                             smem_ptr_1,
                                             kargs,
                                             splitk_batch_offset,
                                             i_m,
                                             i_n);
        }
        else // SingleSmemBuffer
        {

            if constexpr(UsePersistentKernel)
            {
                RunGemmWithPipelineSelection(a_ptr,
                                             b_ptr,
                                             kargs.ds_ptr,
                                             c_ptr,
                                             smem_ptr_0,
                                             kargs,
                                             splitk_batch_offset,
                                             i_m,
                                             i_n);
            }
            else // Non-persistent kernel
            {
                Base::RunGemm({a_ptr},
                              {b_ptr},
                              kargs.ds_ptr,
                              c_ptr,
                              smem_ptr_0,
                              kargs,
                              splitk_batch_offset,
                              i_m,
                              i_n);
            }
        }
    }

    /// @brief Runs single GEMM problem cooperatively by whole workgroup.
    CK_TILE_DEVICE static void
    RunGemmWithPipelineSelection(const ADataType* a_ptr,
                                 const BDataType* b_ptr,
                                 const std::array<const void*, NumDTensor_>& ds_ptr,
                                 CDataType* c_ptr,
                                 void* smem_ptr_0,
                                 const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
                                 const typename Base::SplitKBatchOffset& splitk_batch_offset,
                                 const index_t block_idx_m,
                                 const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);

        const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows =
            Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
        const auto& a_block_window = gemm_tile_windows.at(Base::I0);
        const auto& b_block_window = gemm_tile_windows.at(Base::I1);
        const auto& d_block_window = gemm_tile_windows.at(Base::I2);

        // Get hot-loop and tail configuration
        const index_t num_loop =
            amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
        const bool has_hot_loop   = GemmPipeline::BlockHasHotloop(num_loop);
        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);

        // Run GEMM pipeline
        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(Base::I3);
        EpiloguePipeline{}.template
        operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }

    /// @brief Runs single GEMM problem cooperatively by whole workgroup.
    CK_TILE_DEVICE static void
    RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
                                     const BDataType* b_ptr,
                                     CDataType* c_ptr,
                                     const std::array<const void*, NumDTensor_>& ds_ptr,
                                     void* __restrict__ smem_ptr_0,
                                     void* __restrict__ smem_ptr_1,
                                     const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
                                     const typename Base::SplitKBatchOffset& splitk_batch_offset,
                                     const index_t block_idx_m,
                                     const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);

        const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows =
            Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
        const auto& a_block_window = gemm_tile_windows.at(Base::I0);
        const auto& b_block_window = gemm_tile_windows.at(Base::I1);
        const auto& d_block_window = gemm_tile_windows.at(Base::I2);

        // Get hot-loop and tail configuration
        const index_t num_loop =
            amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);

        // Run GEMM pipeline with compile-time branching
        const auto& c_block_tile = [&]() {
            if constexpr(GemmPipeline::Preshuffle)
            {
                // Preshuffle version - without has_hot_loop parameter
                return GemmPipeline{}.template operator()(a_block_window[Base::I0],
                                                          b_block_window[Base::I0],
                                                          num_loop,
                                                          tail_num,
                                                          smem_ptr_0,
                                                          smem_ptr_1);
            }
            else
            {
                // Regular version - with has_hot_loop parameter
                const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
                return GemmPipeline{}.template operator()(a_block_window[Base::I0],
                                                          b_block_window[Base::I0],
                                                          num_loop,
                                                          has_hot_loop,
                                                          tail_num,
                                                          smem_ptr_0,
                                                          smem_ptr_1);
            }
        }();

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(Base::I3);
        EpiloguePipeline{}.template
        operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }

    CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg<NumDTensor_>* gemm_desc_ptr,
                                       index_t block_id,
                                       index_t group_count) const
    {
        index_t left     = 0;
        index_t right    = group_count;
        index_t group_id = index_t((left + right) >> 1);

        while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
                 block_id < gemm_desc_ptr[group_id].block_end)) &&
              left <= right)
        {
            if(block_id < gemm_desc_ptr[group_id].block_start)
            {
                right = group_id;
            }
            else
            {
                left = group_id;
            }
            group_id = index_t((left + right) >> 1);
        }

        return group_id;
    }
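
FindGroupId answers one question: which group's [block_start, block_end) interval contains this workgroup's flattened block id. The following is a simplified host-side illustration of that mapping, using a plain linear scan instead of the kernel's bisection, purely to show the semantics:

#include <cstdio>
#include <vector>

struct Range { int block_start; int block_end; }; // stand-in for GemmTransKernelArg's fields

int find_group_id(const std::vector<Range>& groups, int block_id)
{
    for(int g = 0; g < static_cast<int>(groups.size()); ++g)
        if(block_id >= groups[g].block_start && block_id < groups[g].block_end)
            return g;
    return -1; // not reached when the grid is sized by GridSize()
}

int main()
{
    // Three hypothetical groups occupying blocks [0,16), [16,20), [20,52).
    const std::vector<Range> groups{{0, 16}, {16, 20}, {20, 52}};
    std::printf("block 18 -> group %d\n", find_group_id(groups, 18)); // prints 1
    return 0;
}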

    // For non-persistent kernels
    template <bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
    CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                   index_t group_count) const
    {
        const index_t block_id   = ck_tile::get_block_1d_id();
        const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg<NumDTensor_>*>(
            cast_pointer_to_generic_address_space(gemm_descs_const));

        const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
        const auto& kargs      = gemm_desc_ptr[group_id];

        const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
        const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
            0,
            kargs.group_karg.M,
            kargs.group_karg.N,
            (block_id - kargs.block_start) % grid_size_2d);
        Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d);
    }
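
Within a group, the non-persistent path decomposes the local block id: the remainder modulo the group's 2D grid size selects the output tile (later turned into (iM, iN) by the partitioner) and the quotient selects the split-K slice passed as block_idx_z. A tiny worked sketch with placeholder numbers:

#include <cstdio>

int main()
{
    const int block_start  = 16; // hypothetical first block of this group
    const int grid_size_2d = 8;  // hypothetical number of 2D tiles covering this group's output
    const int block_id     = 27;

    const int local_id  = block_id - block_start;  // 11
    const int tile_1d   = local_id % grid_size_2d; // 3 -> mapped to (iM, iN) by the partitioner
    const int ksplit_id = local_id / grid_size_2d; // 1 -> block_idx_z for the split-K offset
    std::printf("tile=%d ksplit=%d\n", tile_1d, ksplit_id);
    return 0;
}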

    // For persistent kernels
    template <bool U = UsePersistentKernel,
              typename = std::enable_if_t<U>,
              typename = void> // extra template parameter to avoid redefinition
    CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                   const index_t group_count) const
    {
        const index_t grid_size  = ck_tile::get_grid_size();
        const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg<NumDTensor_>*>(
            cast_pointer_to_generic_address_space(gemm_descs_const));
        index_t block_id      = ck_tile::get_block_1d_id(); // initial block_id
        index_t cum_grid_size = 0;
        for(index_t group_id = 0; group_id < group_count; ++group_id)
        {
            const auto& kargs      = gemm_desc_ptr[group_id].group_karg;
            const auto& k_batch    = kargs.k_batch;
            const auto block_start = cum_grid_size;
            cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
            while(block_id < cum_grid_size)
            {
                const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
                const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
                    0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
                Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
                block_id = block_id + grid_size; // advance to next block
                // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
                if(block_id >= cum_grid_size)
                {
                    break; // exit the loop if all blocks are processed
                }
            }
        }
    }
};

} // namespace ck_tile
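
Putting the host-side pieces together: a non-persistent launch sizes its grid with GridSize over the descriptors, while a persistent launch uses MaxOccupancyGridSize and lets each workgroup walk the cumulative grid; BlockSize supplies the workgroup shape either way. Below is a hedged sketch of just the sizing step, again assuming a kernel without extra D tensors; the actual launch goes through the library's kernel-launch utilities, which live outside this header.

#include <hip/hip_runtime.h>
#include <vector>
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"

// `Kernel` is assumed to be a concrete ck_tile::GroupedGemmKernel instantiation and
// `s` a ck_tile::stream_config; both are placeholders for whatever the caller already has.
template <typename Kernel>
void pick_launch_dims(const std::vector<ck_tile::GroupedGemmHostArgs<>>& gemm_descs,
                      const ck_tile::stream_config& s,
                      dim3& grids,
                      dim3& blocks)
{
    blocks = Kernel::BlockSize(); // halved automatically on wave32 devices
    grids  = Kernel::UsePersistentKernel ? Kernel::MaxOccupancyGridSize(s)
                                         : Kernel::GridSize(gemm_descs);
}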