Composable Kernel: include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp Source File

gemm_quant_kernel.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <string>
7 
8 #include "ck_tile/core.hpp"
14 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
19 namespace detail {
20 // Helper templates for safe type extraction
21 template <typename, typename Default, typename = void>
22 struct get_aq_layout_or
23 {
24  using type = Default;
25 };
26 
27 template <typename T, typename Default>
28 struct get_aq_layout_or<T, Default, std::void_t<typename T::AQLayout>>
29 {
30  using type = typename T::AQLayout;
31 };
32 
33 template <typename, typename Default, typename = void>
34 struct get_bq_layout_or
35 {
36  using type = Default;
37 };
38 
39 template <typename T, typename Default>
40 struct get_bq_layout_or<T, Default, std::void_t<typename T::BQLayout>>
41 {
42  using type = typename T::BQLayout;
43 };
44 
45 template <typename, typename Default, typename = void>
46 struct get_aq_data_type_or
47 {
48  using type = Default;
49 };
50 
51 template <typename T, typename Default>
52 struct get_aq_data_type_or<T, Default, std::void_t<typename T::AQDataType>>
53 {
54  using type = typename T::AQDataType;
55 };
56 
57 template <typename, typename Default, typename = void>
58 struct get_bq_data_type_or
59 {
60  using type = Default;
61 };
62 
63 template <typename T, typename Default>
64 struct get_bq_data_type_or<T, Default, std::void_t<typename T::BQDataType>>
65 {
66  using type = typename T::BQDataType;
67 };
68 
69 template <typename, typename = void>
70 struct is_quantpreshuffle_enabled
71 {
72  static constexpr bool value = false;
73 };
74 
75 template <typename T>
76 struct is_quantpreshuffle_enabled<T, std::void_t<decltype(T::PreshuffleQuant)>>
77 {
78  static constexpr bool value = T::PreshuffleQuant;
79 };
80 
81 template <typename, typename = void>
82 struct is_preshuffleB_enabled
83 {
84  static constexpr bool value = false;
85 };
86 
87 template <typename T>
88 struct is_preshuffleB_enabled<T, std::void_t<decltype(T::PreshuffleB)>>
89 {
90  static constexpr bool value = T::PreshuffleB;
91 };
92 } // namespace detail
93 
94 struct QuantGemmProblem
95 {
96  CK_TILE_HOST QuantGemmProblem() = default;
97  CK_TILE_HOST QuantGemmProblem(index_t M_,
98  index_t N_,
99  index_t K_,
100  index_t QK_A_,
101  index_t QK_B_,
102  index_t stride_A_,
103  index_t stride_B_,
104  index_t stride_C_,
105  index_t stride_AQ_,
106  index_t stride_BQ_)
107  : M(M_),
108  N(N_),
109  K(K_),
110  QK_A(QK_A_),
111  QK_B(QK_B_),
112  stride_A(stride_A_),
113  stride_B(stride_B_),
114  stride_C(stride_C_),
115  stride_AQ(stride_AQ_),
116  stride_BQ(stride_BQ_)
117  {
118  }
119 
120  index_t M;
121  index_t N;
122  index_t K;
123  index_t QK_A;
124  index_t QK_B;
125  index_t stride_A;
126  index_t stride_B;
127  index_t stride_C;
128  index_t stride_AQ;
129  index_t stride_BQ;
130 };
131 
132 struct QuantGemmHostArgs : public QuantGemmProblem
133 {
134  CK_TILE_HOST QuantGemmHostArgs() = default;
135  CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_,
136  const void* b_ptr_,
137  void* c_ptr_,
138  const void* aq_ptr_,
139  const void* bq_ptr_,
140  index_t k_batch_,
141  index_t M_,
142  index_t N_,
143  index_t K_,
144  index_t QK_A_,
145  index_t QK_B_,
146  index_t stride_A_,
147  index_t stride_B_,
148  index_t stride_C_,
149  index_t stride_AQ_,
150  index_t stride_BQ_)
151  : QuantGemmProblem(
152  M_, N_, K_, QK_A_, QK_B_, stride_A_, stride_B_, stride_C_, stride_AQ_, stride_BQ_),
153  a_ptr(a_ptr_),
154  b_ptr(b_ptr_),
155  aq_ptr(aq_ptr_),
156  bq_ptr(bq_ptr_),
157  c_ptr(c_ptr_),
158  k_batch(k_batch_)
159  {
160  }
161 
162  const void* a_ptr = nullptr;
163  const void* b_ptr = nullptr;
164  const void* aq_ptr = nullptr;
165  const void* bq_ptr = nullptr;
166  void* c_ptr = nullptr;
167  index_t k_batch = 1;
168 };
169 
170 struct QuantGemmKernelArgs
171 {
172  const void* a_ptr;
173  const void* b_ptr;
174  const void* aq_ptr;
175  const void* bq_ptr;
176  void* c_ptr;
177  index_t M;
178  index_t N;
179  index_t K;
180  index_t QK_A;
181  index_t QK_B;
182  index_t stride_A;
183  index_t stride_B;
184  index_t stride_C;
185  index_t stride_AQ;
186  index_t stride_BQ;
187  index_t k_batch;
188 };
189 
190 template <typename TilePartitioner_,
191  typename GemmPipeline_,
192  typename EpiloguePipeline_,
193  QuantType QuantType_>
194 struct QuantGemmKernel
195 {
196  using TilePartitioner = remove_cvref_t<TilePartitioner_>;
197  using GemmPipeline = remove_cvref_t<GemmPipeline_>;
198  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
199  using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
200  using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
201  using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
202 
203  using AQLayout = remove_cvref_t<
204  typename detail::get_aq_layout_or<GemmPipeline, typename GemmPipeline::ALayout>::type>;
205  using BQLayout = remove_cvref_t<
206  typename detail::get_bq_layout_or<GemmPipeline, typename GemmPipeline::BLayout>::type>;
207 
208  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
209  static constexpr bool PreshuffleQuant =
210  detail::is_quantpreshuffle_enabled<GemmPipeline>::value;
211  static constexpr bool PreshuffleB = detail::is_preshuffleB_enabled<GemmPipeline>::value;
212 
213  using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
214  using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
215  using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
216  using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
217 
218  using AQDataType =
219  remove_cvref_t<typename detail::get_aq_data_type_or<GemmPipeline, AccDataType>::type>;
220  using BQDataType =
221  remove_cvref_t<typename detail::get_bq_data_type_or<GemmPipeline, AccDataType>::type>;
222 
223  static constexpr auto I0 = number<0>(); // A Tensor
224  static constexpr auto I1 = number<1>(); // AQ Tensor
225  static constexpr auto I2 = number<2>(); // B Tensor
226  static constexpr auto I3 = number<3>(); // BQ Tensor
227  static constexpr auto I4 = number<4>(); // C Tensor
228 
229  static constexpr auto kQuantType = QuantType_;
230 
231  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
232  {
233  // clang-format off
234  return concat('_', "gemm_quant", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
235  // clang-format on
236  }
237 
238  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
239  {
240  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
241  }
242 
243  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
244 
245  CK_TILE_HOST static constexpr QuantGemmKernelArgs
246  MakeKernelArgs(const QuantGemmHostArgs& hostArgs)
247  {
248  return QuantGemmKernelArgs{hostArgs.a_ptr,
249  hostArgs.b_ptr,
250  hostArgs.aq_ptr,
251  hostArgs.bq_ptr,
252  hostArgs.c_ptr,
253  hostArgs.M,
254  hostArgs.N,
255  hostArgs.K,
256  hostArgs.QK_A,
257  hostArgs.QK_B,
258  hostArgs.stride_A,
259  hostArgs.stride_B,
260  hostArgs.stride_C,
261  hostArgs.stride_AQ,
262  hostArgs.stride_BQ,
263  hostArgs.k_batch};
264  }
265 
266  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
267  {
268  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
269  }
270 
271  struct SplitKBatchOffset
272  {
273  __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
274  const std::size_t k_id = blockIdx.z)
275  {
276  constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(I2);
277  const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
278  const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
279 
280  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
281  {
282  a_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
283  }
284  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
285  {
286  a_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_A);
287  }
288 
289  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
290  {
291  b_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_B);
292  }
293  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
294  {
295  b_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
296  }
297 
298  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
299  {
300  splitted_k = amd_wave_read_first_lane(KRead);
301  }
302  else
303  {
304  splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
305  }
306  }
307 
308  index_t a_k_split_offset;
309  index_t b_k_split_offset;
310  index_t splitted_k;
311  };
312 
313  CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
314  {
315  if(kargs.k_batch != 1)
316  {
317  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
318  {
319  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
320  }
321  return false;
322  }
323 
324  if constexpr(kQuantType == QuantType::AQuantGrouped)
325  {
326  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
327  if(kargs.QK_A % GemmPipeline::GetVectorSizeAQ() != 0)
328  {
329  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
330  {
331  CK_TILE_ERROR("K_A is not a multiple of vector load size for A tensor!");
332  }
333  return false;
334  }
335  }
336 
337  if constexpr(kQuantType == QuantType::BQuantGrouped)
338  {
339  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
340  if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0)
341  {
342  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
343  {
344  CK_TILE_ERROR("K_B is not a multiple of vector load size for B tensor!");
345  }
346  return false;
347  }
348  }
349 
350  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
351  {
352  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
353  GemmPipeline::kPadK == false)
354  {
355  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
356  {
357  CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
358  "without padding!");
359  }
360  return false;
361  }
362  if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
363  {
364  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
365  {
366  CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
367  }
368  return false;
369  }
370  }
371  else
372  {
373  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
374  {
375  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
376  {
378  "Can't support M that is not a multiple of MPerBlock without padding!");
379  }
380  return false;
381  }
382  if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
383  {
384  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
385  {
386  CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
387  }
388  return false;
389  }
390  }
391 
392  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
393  {
394  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
395  {
396  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
397  {
399  "Can't support N that is not a multiple of NPerBlock without padding!");
400  }
401  return false;
402  }
403  if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
404  {
405  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
406  {
407  CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
408  }
409  return false;
410  }
411  }
412  else
413  {
414  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
415  GemmPipeline::kPadK == false)
416  {
417  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
418  {
419  CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
420  "without padding!");
421  }
422  return false;
423  }
424  if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
425  {
426  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
427  {
428  CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
429  }
430  return false;
431  }
432  }
433 
434  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
435  {
436  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
437  {
438  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
439  {
441  "Can't support N that is not a multiple of NPerBlock without padding!");
442  }
443  return false;
444  }
445  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
446  {
447  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
448  {
449  CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
450  }
451  return false;
452  }
453  }
454  else
455  {
456  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
457  {
458  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
459  {
461  "Can't support M that is not a multiple of MPerBlock without padding!");
462  }
463  return false;
464  }
465  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
466  {
467  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
468  {
469  CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
470  }
471  return false;
472  }
473  }
474  return true;
475  }
476 
477  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
478  CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
479  const BDataType* b_ptr,
480  const AQDataType* aq_ptr,
481  const BQDataType* bq_ptr,
482  CDataType* c_ptr,
483  const QuantGemmKernelArgs& kargs,
484  const SplitKBatchOffset& splitk_batch_offset)
485  {
486  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
487  const auto& a_tensor_view = [&]() {
488  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
489  {
490  return make_naive_tensor_view<address_space_enum::global>(
491  a_ptr,
492  make_tuple(kargs.M, splitk_batch_offset.splitted_k),
493  make_tuple(kargs.stride_A, 1),
494  number<GemmPipeline::GetVectorSizeA()>{},
495  number<1>{});
496  }
497  else
498  {
499  return make_naive_tensor_view<address_space_enum::global>(
500  a_ptr,
501  make_tuple(splitk_batch_offset.splitted_k, kargs.M),
502  make_tuple(kargs.stride_A, 1),
503  number<GemmPipeline::GetVectorSizeA()>{},
504  number<1>{});
505  }
506  }();
507 
508  const auto get_padding_size = [](index_t length, index_t alignment) {
509  return ck_tile::integer_least_multiple(length, alignment) - length;
510  };
511 
512  const auto& aq_tensor_view = [&]() {
513  if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
514  {
515  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
516  const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
517  const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
518 
519  const auto aq_desc =
520  make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
521  make_tuple(aq_x, 1),
522  number<GemmPipeline::GetVectorSizeAQ()>{},
523  number<1>{});
524 
525  const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
526  const auto aq_pad0_desc = transform_tensor_descriptor(
527  aq_desc,
528  make_tuple(
529  make_pass_through_transform(aq_y),
530  make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
531  make_tuple(sequence<0>{}, sequence<1>{}),
532  make_tuple(sequence<0>{}, sequence<1>{}));
533 
534  const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
535  const auto wave_tile_size =
536  TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
537  const auto wave_tile_count_x =
538  ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
539  const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
540  aq_pad0_desc,
541  make_tuple(
542  make_pass_through_transform(aq_y),
543  make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
544  make_tuple(sequence<0>{}, sequence<1>{}),
545  make_tuple(sequence<0>{}, sequence<1, 2>{}));
546 
547  const auto aq_pad1_desc = transform_tensor_descriptor(
548  aq_unmerge_pad0_desc,
549  make_tuple(
550  make_pass_through_transform(aq_y),
551  make_pass_through_transform(wave_tile_count_x),
552  make_right_pad_transform(
553  wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
554  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
555  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
556 
557  const auto pad_wave_size =
558  ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
559  const auto aq_merge_pad1_desc = transform_tensor_descriptor(
560  aq_pad1_desc,
561  make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
562  make_pass_through_transform(pad_wave_size)),
563  make_tuple(sequence<0, 1>{}, sequence<2>{}),
564  make_tuple(sequence<0>{}, sequence<1>{}));
565 
566  return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
567  }
568  else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
569  {
570  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
571  return make_naive_tensor_view<address_space_enum::global>(
572  aq_ptr,
573  make_tuple(kargs.M, kargs.QK_A),
574  make_tuple(kargs.stride_AQ, 1),
575  number<GemmPipeline::GetVectorSizeAQ()>{},
576  number<1>{});
577  }
578  else if constexpr(kQuantType == QuantType::RowColQuant)
579  {
580  return make_naive_tensor_view<address_space_enum::global>(
581  aq_ptr,
582  make_tuple(kargs.M, kargs.N),
583  make_tuple(1, 0), // broadcasting over n
584  number<1>{},
585  number<1>{});
586  }
587  else
588  {
589  return nullptr; // TODO: use some other "empty" type for this
590  }
591  }();
592 
593  const auto& b_tensor_view = [&]() {
594  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
595  {
596  if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
597  {
598  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
599  const index_t K0 = splitk_batch_offset.splitted_k / K1;
600  constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
601  const auto b_k0_n_k1_desc =
602  make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
603  make_tuple(kargs.N * K1, K1, I1),
604  number<VectorSizeB>{},
605  number<1>{});
606  const auto b_n_k_desc = transform_tensor_descriptor(
607  b_k0_n_k1_desc,
612  return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
613  }
614  else
615  {
616  return make_naive_tensor_view<address_space_enum::global>(
617  b_ptr,
618  make_tuple(splitk_batch_offset.splitted_k, kargs.N),
619  make_tuple(kargs.stride_B, 1),
620  number<GemmPipeline::GetVectorSizeB()>{},
621  number<1>{});
622  }
623  }
624  else
625  {
626  if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
627  {
628  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
629  const index_t K0 = splitk_batch_offset.splitted_k / K1;
630  constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
631  const auto b_k0_n_k1_desc =
632  make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
633  make_tuple(kargs.N * K1, K1, I1),
634  number<VectorSizeB>{},
635  number<1>{});
636  const auto b_n_k_desc = transform_tensor_descriptor(
637  b_k0_n_k1_desc,
642  return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
643  }
644  else
645  {
646  if constexpr(PreshuffleB)
647  {
648  index_t kFlatK =
649  GemmPipeline::flatKPerWarp *
650  (splitk_batch_offset.splitted_k /
651  TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
652  index_t kFlatN = kargs.N * kargs.K / kFlatK;
653 
654  return make_naive_tensor_view<address_space_enum::global>(
655  b_ptr,
656  make_tuple(kFlatN, kFlatK),
657  make_tuple(kFlatK, 1),
658  number<GemmPipeline::GetVectorSizeB()>{},
659  number<1>{});
660  }
661  else
662  {
663  return make_naive_tensor_view<address_space_enum::global>(
664  b_ptr,
665  make_tuple(kargs.N, splitk_batch_offset.splitted_k),
666  make_tuple(kargs.stride_B, 1),
667  number<GemmPipeline::GetVectorSizeB()>{},
668  number<1>{});
669  }
670  }
671  }
672  }();
673 
674  const auto& bq_tensor_view = [&]() {
675  if constexpr(kQuantType == QuantType::RowColQuant)
676  {
677  return make_naive_tensor_view<address_space_enum::global>(
678  bq_ptr,
679  make_tuple(kargs.M, kargs.N),
680  make_tuple(0, 1), // broadcasting over m
681  number<1>{},
682  number<1>{});
683  }
684  else if constexpr(kQuantType == QuantType::BQuantGrouped)
685  {
686  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
687  return make_naive_tensor_view<address_space_enum::global>(
688  bq_ptr,
689  make_tuple(kargs.N, kargs.QK_B),
690  make_tuple(kargs.stride_BQ, 1),
691  number<GemmPipeline::GetVectorSizeBQ()>{},
692  number<1>{});
693  }
694  else
695  {
696  return nullptr; // TODO: use some other "empty" type for this
697  }
698  }();
699 
700  // TODO: enable vector write for C in ColMajor
701  const auto& c_tensor_view = [&]() {
702  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
703  {
704  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
705  c_ptr,
706  make_tuple(kargs.M, kargs.N),
707  make_tuple(kargs.stride_C, 1),
708  number<EpiloguePipeline::GetVectorSizeC()>{},
709  number<1>{});
710  }
711  else
712  {
713  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
714  c_ptr,
715  make_tuple(kargs.M, kargs.N),
716  make_tuple(1, kargs.stride_C),
717  number<1>{},
718  number<1>{});
719  }
720  }();
721 
722  return make_tuple(
723  a_tensor_view, aq_tensor_view, b_tensor_view, bq_tensor_view, c_tensor_view);
724  }
725 
726  template <typename TensorView>
727  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
728  {
729  const auto& a_pad_view = [&]() {
730  const auto& a_tensor_view = views.at(I0);
731  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
732  {
733  return pad_tensor_view(a_tensor_view,
734  make_tuple(number<TilePartitioner::MPerBlock>{},
735  number<TilePartitioner::KPerBlock>{}),
736  sequence<false, GemmPipeline::kPadK>{});
737  }
738  else
739  {
740  return pad_tensor_view(a_tensor_view,
741  make_tuple(number<TilePartitioner::KPerBlock>{},
742  number<TilePartitioner::MPerBlock>{}),
743  sequence<false, GemmPipeline::kPadM>{});
744  }
745  }();
746 
747  // no padding
748  const auto& aq_pad_view = [&]() { return views.at(I1); }();
749 
750  const auto& b_flat_view = views.at(I2); // not applying any padding to flat B view
751 
752  const auto& b_pad_view = [&]() {
753  const auto& b_tensor_view = views.at(I2);
754  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
755  {
756  return pad_tensor_view(b_tensor_view,
757  make_tuple(number<TilePartitioner::NPerBlock>{},
758  number<TilePartitioner::KPerBlock>{}),
759  sequence<false, GemmPipeline::kPadK>{});
760  }
761  else
762  {
763  return pad_tensor_view(b_tensor_view,
764  make_tuple(number<TilePartitioner::KPerBlock>{},
765  number<TilePartitioner::NPerBlock>{}),
766  sequence<false, GemmPipeline::kPadN>{});
767  }
768  }();
769 
770  // no padding
771  const auto& bq_pad_view = [&]() { return views.at(I3); }();
772 
773  // TODO vector write in for C in ColMajor
774  const auto& c_pad_view = [&]() {
775  const auto& c_tensor_view = views.at(I4);
776  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
777  {
778  return pad_tensor_view(c_tensor_view,
779  make_tuple(number<TilePartitioner::MPerBlock>{},
780  number<TilePartitioner::NPerBlock>{}),
781  sequence<false, GemmPipeline::kPadN>{});
782  }
783  else
784  {
785  return pad_tensor_view(c_tensor_view,
786  make_tuple(number<TilePartitioner::MPerBlock>{},
787  number<TilePartitioner::NPerBlock>{}),
788  sequence<GemmPipeline::kPadM, false>{});
789  }
790  }();
791  if constexpr(PreshuffleB)
792  {
793  return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view);
794  }
795  else
796  {
797  return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view);
798  }
799  }
800 
801  template <typename PadView>
802  CK_TILE_DEVICE static auto
803  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
804  {
805  const auto& a_pad_view = views.at(I0);
806  const auto& aq_pad_view = views.at(I1);
807  const auto& b_pad_view = views.at(I2);
808  const auto& bq_pad_view = views.at(I3);
809  const auto& c_pad_view = views.at(I4);
810  const auto& a_block_window = [&]() {
811  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
812  {
813  return make_tile_window(a_pad_view,
814  make_tuple(number<TilePartitioner::MPerBlock>{},
815  number<TilePartitioner::KPerBlock>{}),
816  {i_m, 0});
817  }
818  else
819  {
820  return make_tile_window(a_pad_view,
821  make_tuple(number<TilePartitioner::KPerBlock>{},
822  number<TilePartitioner::MPerBlock>{}),
823  {0, i_m});
824  }
825  }();
826 
827  const auto& aq_block_window = [&]() {
828  if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
829  {
830  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
831  constexpr auto block_m = TilePartitioner::MPerBlock;
832  constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
833  constexpr auto aqk_per_block =
834  TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
835  constexpr auto tile_window_width =
836  ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
837  constexpr auto tile_window_height = block_m / warp_m;
838  auto block_m_idx = i_m / block_m;
839  return make_tile_window(
840  aq_pad_view,
841  make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
842  {block_m_idx * tile_window_height, 0});
843  }
844  else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
845  {
846  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
847  constexpr auto block_m = TilePartitioner::MPerBlock;
848  constexpr auto block_k = TilePartitioner::KPerBlock;
849  return make_tile_window(
850  aq_pad_view,
851  make_tuple(number<block_m>{}, number<block_k / GemmPipeline::QuantGroupSize>{}),
852  {i_m, 0});
853  }
854  else if constexpr(kQuantType == QuantType::RowColQuant)
855  {
856  return make_tile_window(aq_pad_view,
857  make_tuple(number<TilePartitioner::MPerBlock>{},
858  number<TilePartitioner::NPerBlock>{}),
859  {i_m, i_n});
860  }
861  else
862  {
863  return nullptr; // TODO: use some other "empty" type?
864  }
865  }();
866 
867  const auto& b_block_window = [&]() {
868  if constexpr(PreshuffleB)
869  {
870  return make_tile_window(
871  b_pad_view,
872  make_tuple(number<GemmPipeline::flatNPerWarp>{},
873  number<GemmPipeline::flatKPerWarp>{}),
874  {static_cast<int>(i_n / TilePartitioner::BlockGemmShape::WarpTile::at(I1)), 0});
875  }
876  else
877  {
878  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
879  {
880  return make_tile_window(b_pad_view,
881  make_tuple(number<TilePartitioner::NPerBlock>{},
882  number<TilePartitioner::KPerBlock>{}),
883  {i_n, 0});
884  }
885  else
886  {
887  return make_tile_window(b_pad_view,
888  make_tuple(number<TilePartitioner::KPerBlock>{},
889  number<TilePartitioner::NPerBlock>{}),
890  {0, i_n});
891  }
892  }
893  }();
894 
895  const auto& bq_block_window = [&]() {
896  if constexpr(kQuantType == QuantType::RowColQuant)
897  {
898  return make_tile_window(bq_pad_view,
899  make_tuple(number<TilePartitioner::MPerBlock>{},
900  number<TilePartitioner::NPerBlock>{}),
901  {i_m, i_n});
902  }
903  else if constexpr(kQuantType == QuantType::BQuantGrouped)
904  {
905  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
906  return make_tile_window(
907  bq_pad_view,
908  make_tuple(number<TilePartitioner::NPerBlock>{},
909  number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
910  {i_n, 0});
911  }
912  else
913  {
914  return nullptr; // TODO: use some other "empty" type here
915  }
916  }();
917 
918  auto c_block_window = make_tile_window(
919  c_pad_view,
920  make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
921  {i_m, i_n});
922 
923  return make_tuple(
924  a_block_window, aq_block_window, b_block_window, bq_block_window, c_block_window);
925  }
926 
942  /// @brief Runs single GEMM problem cooperatively by whole workgroup.
943  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
944  CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
945  const BDataType* b_ptr,
946  const AQDataType* aq_ptr,
947  const BQDataType* bq_ptr,
948  CDataType* c_ptr,
949  void* smem_ptr_0,
950  const QuantGemmKernelArgs& kargs,
951  const SplitKBatchOffset& splitk_batch_offset,
952  const index_t block_idx_m,
953  const index_t block_idx_n)
954  {
955  // Create Gemm tensor views, pad views and tile windows
956  const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
957  a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
958 
959  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
960  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
961 
962  const index_t num_loop =
963  amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
964 
965  // Run GEMM cooperatively by whole workgroup.
966  const auto& a_block_window = gemm_tile_windows.at(I0);
967  const auto& b_block_window = gemm_tile_windows.at(I2);
968 
969  const auto& c_block_tile = [&]() {
970  if constexpr(kQuantType == QuantType::AQuantGrouped)
971  {
972  const auto& aq_block_window = gemm_tile_windows.at(I1);
973  return GemmPipeline{}.template operator()(
974  a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0);
975  }
976  else if constexpr(kQuantType == QuantType::BQuantGrouped)
977  {
978  const auto& bq_block_window = gemm_tile_windows.at(I3);
979  return GemmPipeline{}.template operator()(
980  a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
981  }
982  else if constexpr(kQuantType == QuantType::RowColQuant ||
983  kQuantType == QuantType::TensorQuant)
984  {
985  return GemmPipeline{}.template operator()(
986  a_block_window, b_block_window, num_loop, smem_ptr_0);
987  }
988  }();
989 
990  // Run Epilogue Pipeline
991  auto& c_block_window = gemm_tile_windows.at(I4);
992 
993  if constexpr(kQuantType == QuantType::AQuantGrouped ||
994  kQuantType == QuantType::BQuantGrouped)
995  {
996  EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
997  }
998  else if constexpr(kQuantType == QuantType::RowColQuant)
999  {
1000  const auto& aq_block_window = gemm_tile_windows.at(I1);
1001  const auto& bq_block_window = gemm_tile_windows.at(I3);
1002  EpiloguePipeline{}(c_block_window,
1003  c_block_tile,
1004  c_block_window,
1005  smem_ptr_0,
1006  aq_block_window,
1007  bq_block_window);
1008  }
1009  else if constexpr(kQuantType == QuantType::TensorQuant)
1010  {
1011  // TODO: why doesn't readfirstlane work here?
1012  // const AccDataType aq_scale =
1013  // __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*aq_ptr));
1014  // const AccDataType bq_scale =
1015  // __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*bq_ptr));
1016  const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
1017  const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
1018  EpiloguePipeline{}(
1019  c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
1020  }
1021  }
1036  /// @brief Runs single GEMM problem cooperatively by whole workgroup.
1037  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
1038  CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
1039  const BDataType* b_ptr,
1040  const AQDataType* aq_ptr,
1041  const BQDataType* bq_ptr,
1042  CDataType* c_ptr,
1043  void* smem_ptr_0,
1044  void* smem_ptr_1,
1045  const QuantGemmKernelArgs& kargs,
1046  const SplitKBatchOffset& splitk_batch_offset,
1047  const index_t block_idx_m,
1048  const index_t block_idx_n)
1049  {
1050  // Create Gemm tensor views, pad views and tile windows
1051  const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
1052  a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
1053 
1054  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
1055  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
1056 
1057  const index_t num_loop = __builtin_amdgcn_readfirstlane(
1058  TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
1059 
1060  // Run GEMM cooperatively by whole workgroup.
1061  const auto& a_block_window = gemm_tile_windows.at(I0);
1062  const auto& b_block_window = gemm_tile_windows.at(I2);
1063 
1064  const auto& c_block_tile = [&]() {
1065  if constexpr(kQuantType == QuantType::BQuantGrouped)
1066  {
1067  const auto& bq_block_window = gemm_tile_windows.at(I3);
1068  return GemmPipeline{}.template operator()(a_block_window,
1069  b_block_window,
1070  bq_block_window,
1071  num_loop,
1072  smem_ptr_0,
1073  smem_ptr_1);
1074  }
1075  else
1076  {
1077  return nullptr;
1078  }
1079  }();
1080 
1081  // Run Epilogue Pipeline
1082  auto& c_block_window = gemm_tile_windows.at(I4);
1083 
1084  if constexpr(kQuantType == QuantType::BQuantGrouped)
1085  {
1086  EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
1087  }
1088  else
1089  {
1090  return;
1091  // throw std::runtime_error("DoubleSmemBuffer Not implemented for AQuantGrouped or
1092  // RowColQuant"); static_assert(kQuantType == QuantType::BQuantGrouped,
1093  // "DoubleSmemBuffer Not implemented");
1094  }
1095  }
1096 
1097  CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
1098  {
1099  const auto blockId = amd_wave_read_first_lane(blockIdx.x);
1100  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1101  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1102  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1103 
1104  const SplitKBatchOffset splitk_batch_offset(kargs);
1105  // options
1106  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
1107  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
1108  const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
1109  const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
1110  CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
1111 
1112  // allocate LDS
1113  __shared__ char smem_ptr_0[GetSmemSize()];
1114 
1115  assert(kargs.k_batch == 1);
1116  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1117  {
1118  __shared__ char smem_ptr_1[GetSmemSize()];
1119 
1120  RunGemm2LDS(a_ptr,
1121  b_ptr,
1122  aq_ptr,
1123  bq_ptr,
1124  c_ptr,
1125  smem_ptr_0,
1126  smem_ptr_1,
1127  kargs,
1128  splitk_batch_offset,
1129  i_m,
1130  i_n);
1131  }
1132  else
1133  {
1134  RunGemm(a_ptr,
1135  b_ptr,
1136  aq_ptr,
1137  bq_ptr,
1138  c_ptr,
1139  smem_ptr_0,
1140  kargs,
1141  splitk_batch_offset,
1142  i_m,
1143  i_n);
1144  }
1145  }
1146 };
1147 
1148 } // namespace ck_tile
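
A note on usage: host code is expected to size the launch with GridSize/BlockSize, convert QuantGemmHostArgs into QuantGemmKernelArgs via MakeKernelArgs, and validate the result with IsSupportedArgument before dispatch. The sketch below illustrates that flow under assumed placeholder types; MyTilePartitioner, MyGemmPipeline, MyEpiloguePipeline, the device pointers, and the final launch mechanism are not defined in this header and stand in for whatever the surrounding application provides.

// Host-side sketch (placeholder types and pointers are assumptions, not part of this header).
using Kernel = ck_tile::QuantGemmKernel<MyTilePartitioner,
                                        MyGemmPipeline,
                                        MyEpiloguePipeline,
                                        ck_tile::QuantType::BQuantGrouped>;

// Describe one quantized GEMM problem: C[M, N] = dequant(A[M, K]) * dequant(B[K, N]).
ck_tile::QuantGemmHostArgs host_args(a_dev, b_dev, c_dev, aq_dev, bq_dev,
                                     /*k_batch=*/1,
                                     M, N, K, QK_A, QK_B,
                                     stride_A, stride_B, stride_C, stride_AQ, stride_BQ);

auto kargs = Kernel::MakeKernelArgs(host_args);
if(!Kernel::IsSupportedArgument(kargs))
{
    throw std::runtime_error(Kernel::GetName() + ": unsupported argument");
}

const dim3 grids  = Kernel::GridSize(host_args.M, host_args.N, host_args.k_batch);
const dim3 blocks = Kernel::BlockSize();
// Dispatch with the application's usual launch mechanism (for example ck_tile::launch_kernel
// together with ck_tile::make_kernel), passing kargs by value to Kernel::operator().

Note that IsSupportedArgument currently rejects any k_batch other than 1, so the sketch keeps a single K split.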