Composable Kernel: include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp Source File

grouped_gemm_quant_kernel.hpp
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
14 #include "ck_tile/host.hpp"
15 
16 #include <hip/hip_runtime.h>
17 
18 namespace ck_tile {
19 
27 struct QuantGroupedGemmHostArgs
28 {
29  CK_TILE_HOST QuantGroupedGemmHostArgs(const void* a_ptr_,
30  const void* b_ptr_,
31  void* e_ptr_,
32  const void* aq_ptr_,
33  const void* bq_ptr_,
34  index_t k_batch_,
35  index_t M_,
36  index_t N_,
37  index_t K_,
38  index_t QK_A_,
39  index_t QK_B_,
40  index_t stride_A_,
41  index_t stride_B_,
42  index_t stride_E_,
43  index_t stride_AQ_,
44  index_t stride_BQ_)
45  : a_ptr(a_ptr_),
46  b_ptr(b_ptr_),
47  aq_ptr(aq_ptr_),
48  bq_ptr(bq_ptr_),
49  e_ptr(e_ptr_),
50  M(M_),
51  N(N_),
52  K(K_),
53  QK_A(QK_A_),
54  QK_B(QK_B_),
55  stride_A(stride_A_),
56  stride_B(stride_B_),
57  stride_AQ(stride_AQ_),
58  stride_BQ(stride_BQ_),
59  stride_E(stride_E_),
60  k_batch(k_batch_)
61  {
62  }
63 
64  const void* a_ptr;
65  const void* b_ptr;
66  const void* aq_ptr;
67  const void* bq_ptr;
68  union
69  {
70  void* e_ptr;
71  void* c_ptr;
72  };
73 
74  index_t M;
75  index_t N;
76  index_t K;
77  index_t QK_A;
78  index_t QK_B;
79  index_t stride_A;
80  index_t stride_B;
81  index_t stride_AQ;
82  index_t stride_BQ;
83 
84  union
85  {
86  index_t stride_E;
87  index_t stride_C;
88  };
89 
90  index_t k_batch;
91 };
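// --------------------------------------------------------------------------------------
// Illustrative only (not part of this header): a minimal host-side sketch of filling one
// group's arguments via the constructor above. The device pointers (a_dev, b_dev, ...) are
// assumed to come from the caller's own allocations, and the stride / quant-group values
// are placeholders whose meaning depends on the chosen layouts and quantization pipeline.
//
//   std::vector<ck_tile::QuantGroupedGemmHostArgs> gemm_descs;
//   gemm_descs.emplace_back(a_dev,  // const void* a_ptr_
//                           b_dev,  // const void* b_ptr_
//                           e_dev,  // void* e_ptr_
//                           aq_dev, // const void* aq_ptr_
//                           bq_dev, // const void* bq_ptr_
//                           /*k_batch_*/ 1,
//                           /*M_*/ 1024, /*N_*/ 1024, /*K_*/ 4096,
//                           /*QK_A_*/ 128, /*QK_B_*/ 128,
//                           /*stride_A_*/ 4096, /*stride_B_*/ 1024, /*stride_E_*/ 1024,
//                           /*stride_AQ_*/ 32, /*stride_BQ_*/ 8);
// --------------------------------------------------------------------------------------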
92 
94 
95 struct QuantGemmTransKernelArg
96 {
97  QuantGroupedGemmKernelArgs group_karg;
98  ck_tile::index_t block_start;
99  ck_tile::index_t block_end;
100 
102  QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
103  : group_karg{karg}, block_start{bl_start}, block_end{bl_end}
104  {
105  }
106 
107  QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs&& karg)
108  : group_karg{karg}, block_start{0}, block_end{0}
109  {
110  }
111 };
112 
113 template <typename TilePartitioner_,
114  typename GemmPipeline_,
115  typename EpiloguePipeline_,
116  QuantType QuantType_>
118 {
122 
123  using TilePartitioner = remove_cvref_t<TilePartitioner_>;
124  using GemmPipeline = remove_cvref_t<GemmPipeline_>;
125  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
126 
128  using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
129  using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
130  using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
131 
133  using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
134  using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
135  using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
136  using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
137 
138  using AQDataType =
139  remove_cvref_t<typename detail::get_aq_data_type_or<GemmPipeline, AccDataType>::type>;
140  using BQDataType =
141  remove_cvref_t<typename detail::get_bq_data_type_or<GemmPipeline, AccDataType>::type>;
142 
143  static constexpr auto kQuantType = QuantType_;
144 
146  static_assert(
147  !is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
148  "ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
149 
151  static_assert(
152  !is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
153  "BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
154 
156  static_assert(!is_detected<is_tuple, CLayout>::value &&
157  !is_detected<is_tuple, CDataType>::value,
158  "C/ELayout and C/EDataType must be scalars.");
159 
161  using Kernel =
163 
164  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
165  static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
166 
167  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
168  {
169  // clang-format off
170  using P_ = GemmPipeline;
171 
172  return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
173  concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
174  concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
175  concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
176  (UsePersistentKernel ? "Persistent" : "NonPersistent"));
177  // clang-format on
178  }
179 
180  CK_TILE_HOST static auto
181  GetWorkSpaceSize(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs) -> std::size_t
182  {
183  return gemm_descs.size() * sizeof(QuantGemmTransKernelArg);
184  }
185 
186  CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
187  {
188  return group_count * sizeof(QuantGemmTransKernelArg);
189  }
190 
191  CK_TILE_HOST static auto BlockSize() -> dim3
192  {
193  if(is_wave32())
194  {
195  return dim3(kBlockSize / 2);
196  }
197  else
198  {
199  return dim3(kBlockSize);
200  }
201  }
202 
209  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
210  {
211  using ConstantPointer = const void CK_TILE_CONSTANT_ADDRESS_SPACE*;
212  const auto kernel_func = kentry<1, Kernel, ConstantPointer, index_t>;
213  int occupancy;
214  HIP_CHECK_ERROR(
215  hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel_func, kBlockSize, 0));
216  const int grid_size = get_available_compute_units(s) * occupancy;
217  return dim3(grid_size, 1, 1);
218  }
219 
220  CK_TILE_HOST static auto GridSize(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs)
221  {
222  index_t grid_size = 0;
223  for(const auto& it_desc : gemm_descs)
224  {
225  const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
226  grid_size += local_grid_size * it_desc.k_batch;
227  }
228  return dim3(grid_size, 1, 1);
229  }
230 
231  CK_TILE_HOST static auto MakeKargs(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs)
232  -> std::vector<QuantGemmTransKernelArg>
233  {
234  std::vector<QuantGemmTransKernelArg> gemm_kernel_args_;
235  index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
236  index_t grid_size = 0;
237  gemm_kernel_args_.reserve(group_count);
238 
239  for(std::size_t i = 0; i < gemm_descs.size(); ++i)
240  {
241  const index_t M = gemm_descs[i].M;
242  const index_t N = gemm_descs[i].N;
243  const index_t K = gemm_descs[i].K;
244 
245  if(M == 0 || N == 0 || K == 0)
246  {
247  continue;
248  }
249 
250  const index_t stride_a = gemm_descs[i].stride_A;
251  const index_t stride_b = gemm_descs[i].stride_B;
252  const index_t stride_e = gemm_descs[i].stride_C;
253 
254  const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;
255 
256  const index_t block_start = grid_size;
257  const index_t block_end = grid_size + grid_size_grp;
258 
259  grid_size += grid_size_grp;
260 
261  auto karg =
262  QuantGroupedGemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
263  type_convert<const BDataType*>(gemm_descs[i].b_ptr),
264  type_convert<const AQDataType*>(gemm_descs[i].aq_ptr),
265  type_convert<const BQDataType*>(gemm_descs[i].bq_ptr),
266  type_convert<CDataType*>(gemm_descs[i].e_ptr),
267  M,
268  N,
269  K,
270  gemm_descs[i].QK_A,
271  gemm_descs[i].QK_B,
272  stride_a,
273  stride_b,
274  stride_e,
275  gemm_descs[i].stride_AQ,
276  gemm_descs[i].stride_BQ,
277  gemm_descs[i].k_batch};
278 
279  gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
280  }
281 
282  return gemm_kernel_args_;
283  }
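// --------------------------------------------------------------------------------------
// Illustrative only (not part of this header): a hedged sketch of staging the transformed
// per-group descriptors in device memory so they can be handed to the kernel entry point.
// `Kernel` stands for an instantiation of this class; error handling beyond HIP_CHECK_ERROR
// is omitted.
//
//   template <typename Kernel>
//   void* stage_group_descriptors(const std::vector<ck_tile::QuantGroupedGemmHostArgs>& descs)
//   {
//       const auto kargs = Kernel::MakeKargs(descs); // skips groups with M/N/K == 0
//       void* kargs_dev  = nullptr;
//       HIP_CHECK_ERROR(hipMalloc(&kargs_dev, Kernel::GetWorkSpaceSize(descs)));
//       HIP_CHECK_ERROR(hipMemcpy(kargs_dev,
//                                 kargs.data(),
//                                 kargs.size() * sizeof(kargs[0]),
//                                 hipMemcpyHostToDevice));
//       return kargs_dev; // pass together with the group count when launching the kernel
//   }
// --------------------------------------------------------------------------------------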
284 
285  CK_TILE_HOST static bool IsSupportedArgument(const std::vector<QuantGemmTransKernelArg>& kargs)
286  {
287  for(const auto& karg : kargs)
288  {
289  if(!Base::IsSupportedArgument(karg.group_karg))
290  {
291  return false;
292  }
293  }
294  return true;
295  }
296 
297  CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
298  {
299  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
300  }
301 
302  CK_TILE_DEVICE void Run(const QuantGroupedGemmKernelArgs& kargs,
303  const tuple<index_t, index_t>& block_idx_2d,
304  const index_t block_idx_z) const
305  {
306  const auto [iM, iN] = block_idx_2d;
307 
308  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
309  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
310 
311  const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);
312 
313  // Resolve typed pointers for this group's buffers
314  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
315  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
316  const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
317  const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
318  CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
319 
320  // allocate LDS
321  __shared__ char smem_ptr[GetSmemSize()];
322 
323  // DoubleSmemBuffer is only supported for BQuantGrouped
324  if constexpr(GemmPipeline::DoubleSmemBuffer == true &&
325  kQuantType == QuantType::BQuantGrouped)
326  {
327  RunGemmWithPipelineSelection2LDS(a_ptr,
328  b_ptr,
329  aq_ptr,
330  bq_ptr,
331  c_ptr,
332  smem_ptr,
333  kargs,
334  splitk_batch_offset,
335  i_m,
336  i_n);
337  }
338  else
339  {
340 
341  if constexpr(UsePersistentKernel)
342  {
343  RunGemmWithPipelineSelection(a_ptr,
344  b_ptr,
345  aq_ptr,
346  bq_ptr,
347  c_ptr,
348  smem_ptr,
349  kargs,
350  splitk_batch_offset,
351  i_m,
352  i_n);
353  }
354  else // Non-persistent kernel
355  {
356  Base::RunGemm({a_ptr},
357  {b_ptr},
358  aq_ptr,
359  bq_ptr,
360  c_ptr,
361  smem_ptr,
362  kargs,
363  splitk_batch_offset,
364  i_m,
365  i_n);
366  }
367  }
368  }
369 
370  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
371  CK_TILE_DEVICE static void
372  RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
373  const BDataType* b_ptr,
374  [[maybe_unused]] const AQDataType* aq_ptr,
375  const BQDataType* bq_ptr,
376  CDataType* c_ptr,
377  void* smem_ptr,
378  const QuantGroupedGemmKernelArgs& kargs,
379  const typename Base::SplitKBatchOffset& splitk_batch_offset,
380  const index_t block_idx_m,
381  const index_t block_idx_n)
382  {
383  static_assert(kQuantType == QuantType::BQuantGrouped, "kQuantType must be BQuantGrouped");
384 
385  // Create block windows using specialized methods
386  const auto& a_block_window =
387  Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
388  const auto& b_block_window =
389  Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
390  const auto& bq_block_window =
391  Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
392 
393  const index_t num_loop = __builtin_amdgcn_readfirstlane(
394  TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
395  const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
396 
397  // Run GEMM cooperatively by whole workgroup
398  const auto& c_block_tile = GemmPipeline{}.template operator()(
399  a_block_window, b_block_window, bq_block_window, num_loop, tail_num, smem_ptr);
400 
401  // Run Epilogue Pipeline with split_k dispatch
402  if(kargs.k_batch == 1)
403  {
404  auto c_block_window = Base::template MakeCBlockWindow<memory_operation_enum::set>(
405  c_ptr, kargs, block_idx_m, block_idx_n);
406  EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
407  }
408  else
409  {
410  auto c_block_window =
411  Base::template MakeCBlockWindow<memory_operation_enum::atomic_add>(
412  c_ptr, kargs, block_idx_m, block_idx_n);
413  EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
414  }
415  }
416 
437  CK_TILE_DEVICE static void
438  RunGemmWithPipelineSelection(const ADataType* a_ptr,
439  const BDataType* b_ptr,
440  const AQDataType* aq_ptr,
441  const BQDataType* bq_ptr,
442  CDataType* c_ptr,
443  void* smem_ptr,
444  const QuantGroupedGemmKernelArgs& kargs,
445  const typename Base::SplitKBatchOffset& splitk_batch_offset,
446  const index_t block_idx_m,
447  const index_t block_idx_n)
448  {
449  // Create block windows using specialized methods
450  const auto& a_block_window =
451  Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
452  const auto& b_block_window =
453  Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
454  const auto& aq_block_window =
455  Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
456  const auto& bq_block_window =
457  Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
458 
459  // Get hot-loop and tail configuration
460  const index_t num_loop = __builtin_amdgcn_readfirstlane(
461  TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
462  const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
463  const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
464 
465  // Run GEMM cooperatively by whole workgroup
466  const auto& c_block_tile = [&]() {
467  if constexpr(kQuantType == QuantType::AQuantGrouped)
468  {
469  return GemmPipeline{}.template operator()(a_block_window,
470  b_block_window,
471  aq_block_window,
472  num_loop,
473  has_hot_loop,
474  tail_num,
475  smem_ptr);
476  }
477  else if constexpr(kQuantType == QuantType::BQuantGrouped)
478  {
479  return GemmPipeline{}.template operator()(a_block_window,
480  b_block_window,
481  bq_block_window,
482  num_loop,
483  has_hot_loop,
484  tail_num,
485  smem_ptr);
486  }
487  else if constexpr(kQuantType == QuantType::ABQuantGrouped)
488  {
489  return GemmPipeline{}.template operator()(a_block_window,
490  b_block_window,
491  aq_block_window,
492  bq_block_window,
493  num_loop,
494  has_hot_loop,
495  tail_num,
496  smem_ptr);
497  }
498  else if constexpr(kQuantType == QuantType::RowColQuant ||
499  kQuantType == QuantType::TensorQuant)
500  {
501  return GemmPipeline{}.template operator()(
502  a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr);
503  }
504  }();
505 
506  // Run Epilogue Pipeline with split_k dispatch
507  if(kargs.k_batch == 1)
508  {
509  auto c_block_window = Base::template MakeCBlockWindow<memory_operation_enum::set>(
510  c_ptr, kargs, block_idx_m, block_idx_n);
511 
512  if constexpr(kQuantType == QuantType::AQuantGrouped ||
513  kQuantType == QuantType::BQuantGrouped ||
514  kQuantType == QuantType::ABQuantGrouped)
515  {
516  EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
517  }
518  else if constexpr(kQuantType == QuantType::RowColQuant)
519  {
520  EpiloguePipeline{}(c_block_window,
521  c_block_tile,
522  c_block_window,
523  smem_ptr,
524  aq_block_window,
525  bq_block_window);
526  }
527  else if constexpr(kQuantType == QuantType::TensorQuant)
528  {
529  const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
530  const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
531  EpiloguePipeline{}(
532  c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
533  }
534  }
535  else
536  {
537  auto c_block_window =
538  Base::template MakeCBlockWindow<memory_operation_enum::atomic_add>(
539  c_ptr, kargs, block_idx_m, block_idx_n);
540 
541  if constexpr(kQuantType == QuantType::AQuantGrouped ||
542  kQuantType == QuantType::BQuantGrouped ||
543  kQuantType == QuantType::ABQuantGrouped)
544  {
545  EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
546  }
547  else if constexpr(kQuantType == QuantType::RowColQuant)
548  {
549  EpiloguePipeline{}(c_block_window,
550  c_block_tile,
551  c_block_window,
552  smem_ptr,
553  aq_block_window,
554  bq_block_window);
555  }
556  else if constexpr(kQuantType == QuantType::TensorQuant)
557  {
558  const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
559  const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
560  EpiloguePipeline{}(
561  c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
562  }
563  }
564  }
565 
566  CK_TILE_DEVICE index_t FindGroupId(const QuantGemmTransKernelArg* gemm_desc_ptr,
567  index_t block_id,
568  index_t group_count) const
569  {
570  index_t left = 0;
571  index_t right = group_count;
572  index_t group_id = index_t((left + right) >> 1);
573 
574  while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
575  block_id < gemm_desc_ptr[group_id].block_end)) &&
576  left <= right)
577  {
578  if(block_id < gemm_desc_ptr[group_id].block_start)
579  {
580  right = group_id;
581  }
582  else
583  {
584  left = group_id;
585  }
586  group_id = index_t((left + right) >> 1);
587  }
588 
589  return group_id;
590  }
591 
592  // For non-persistent kernels
593  template <bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
594  CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
595  index_t group_count) const
596  {
597  const index_t block_id = ck_tile::get_block_1d_id();
598  const auto gemm_desc_ptr = reinterpret_cast<const QuantGemmTransKernelArg*>(
599  cast_pointer_to_generic_address_space(gemm_descs_const));
600 
601  const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
602  const auto& kargs = gemm_desc_ptr[group_id];
603 
604  const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
605  const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
606  0,
607  kargs.group_karg.M,
608  kargs.group_karg.N,
609  (block_id - kargs.block_start) % grid_size_2d);
610  Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d);
611  }
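// --------------------------------------------------------------------------------------
// Illustrative only (not part of this header): the mapping performed above, written out as
// a plain host-side analogue for intuition. A flattened block id is first matched to the
// group whose [block_start, block_end) range contains it; the remainder then splits into a
// 2D tile index (modulo the group's 2D grid size) and a split-K slice (division).
//
//   // given: block_id, the group's block_start, and its 2D grid size grid_size_2d
//   const int local   = block_id - block_start;
//   const int tile_1d = local % grid_size_2d; // fed to GetOffsetedTileIndex -> (iM, iN)
//   const int k_slice = local / grid_size_2d; // block_idx_z for SplitKBatchOffset
// --------------------------------------------------------------------------------------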
612 
613  // For persistent kernels
614  template <bool U = UsePersistentKernel,
615  typename = std::enable_if_t<U>,
616  typename = void> // extra template parameter to avoid redefinition
617  CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
618  const index_t group_count) const
619  {
620  const index_t grid_size = ck_tile::get_grid_size();
621  const auto gemm_desc_ptr = reinterpret_cast<const QuantGemmTransKernelArg*>(
622  cast_pointer_to_generic_address_space(gemm_descs_const));
623  index_t block_id = ck_tile::get_block_1d_id(); // initial block_id
624  index_t cum_grid_size = 0;
625  for(index_t group_id = 0; group_id < group_count; ++group_id)
626  {
627  const auto& kargs = gemm_desc_ptr[group_id].group_karg;
628  const auto& k_batch = kargs.k_batch;
629  const auto block_start = cum_grid_size;
630  cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
631  while(block_id < cum_grid_size)
632  {
633  const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
634  const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
635  0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
636  Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
637  block_id = block_id + grid_size; // advance to next block
638  // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
639  if(block_id >= cum_grid_size)
640  {
641  break; // exit the loop if all blocks are processed
642  }
643  }
644  }
645  }
646 };
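// --------------------------------------------------------------------------------------
// Illustrative only (not part of this header): a hedged sketch of choosing launch
// dimensions for an instantiation `Kernel` of this class, where `stream_cfg` and
// `gemm_descs` are assumed to be the caller's ck_tile::stream_config and host arguments.
//
//   const dim3 blocks = Kernel::BlockSize();
//   const dim3 grids  = Kernel::UsePersistentKernel
//                           ? Kernel::MaxOccupancyGridSize(stream_cfg) // persistent loop
//                           : Kernel::GridSize(gemm_descs); // one block per tile per k-slice
// --------------------------------------------------------------------------------------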
647 
648 } // namespace ck_tile