Composable Kernel: include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp Source File
gemm_quant_kernel.hpp
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <string>
7 
8 #include "ck_tile/core.hpp"
14 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
19 namespace detail {
20 // Helper templates for safe type extraction
21 template <typename, typename Default, typename = void>
22 struct get_aq_layout_or
23 {
24  using type = Default;
25 };
26 
27 template <typename T, typename Default>
28 struct get_aq_layout_or<T, Default, std::void_t<typename T::AQLayout>>
29 {
30  using type = typename T::AQLayout;
31 };
32 
33 template <typename, typename Default, typename = void>
34 struct get_bq_layout_or
35 {
36  using type = Default;
37 };
38 
39 template <typename T, typename Default>
40 struct get_bq_layout_or<T, Default, std::void_t<typename T::BQLayout>>
41 {
42  using type = typename T::BQLayout;
43 };
44 
45 template <typename, typename Default, typename = void>
46 struct get_aq_data_type_or
47 {
48  using type = Default;
49 };
50 
51 template <typename T, typename Default>
52 struct get_aq_data_type_or<T, Default, std::void_t<typename T::AQDataType>>
53 {
54  using type = typename T::AQDataType;
55 };
56 
57 template <typename, typename Default, typename = void>
58 struct get_bq_data_type_or
59 {
60  using type = Default;
61 };
62 
63 template <typename T, typename Default>
64 struct get_bq_data_type_or<T, Default, std::void_t<typename T::BQDataType>>
65 {
66  using type = typename T::BQDataType;
67 };
68 
69 template <typename, typename = void>
70 struct is_quantpreshuffle_enabled
71 {
72  static constexpr bool value = false;
73 };
74 
75 template <typename T>
76 struct is_quantpreshuffle_enabled<T, std::void_t<decltype(T::PreshuffleQuant)>>
77 {
78  static constexpr bool value = T::PreshuffleQuant;
79 };
80 
81 template <typename, typename = void>
82 struct is_preshuffleB_enabled
83 {
84  static constexpr bool value = false;
85 };
86 
87 template <typename T>
88 struct is_preshuffleB_enabled<T, std::void_t<decltype(T::PreshuffleB)>>
89 {
90  static constexpr bool value = T::PreshuffleB;
91 };
92 } // namespace detail
93 
94 struct QuantGemmProblem
95 {
96  CK_TILE_HOST QuantGemmProblem() = default;
97  CK_TILE_HOST QuantGemmProblem(index_t M_,
98  index_t N_,
99  index_t K_,
100  index_t QK_A_,
101  index_t QK_B_,
102  index_t stride_A_,
103  index_t stride_B_,
104  index_t stride_C_,
105  index_t stride_AQ_,
106  index_t stride_BQ_)
107  : M(M_),
108  N(N_),
109  K(K_),
110  QK_A(QK_A_),
111  QK_B(QK_B_),
112  stride_A(stride_A_),
113  stride_B(stride_B_),
114  stride_C(stride_C_),
115  stride_AQ(stride_AQ_),
116  stride_BQ(stride_BQ_)
117  {
118  }
119 
120  index_t M;
121  index_t N;
122  index_t K;
123  index_t QK_A;
124  index_t QK_B;
125  index_t stride_A;
126  index_t stride_B;
127  index_t stride_C;
128  index_t stride_AQ;
129  index_t stride_BQ;
130 };
131 
132 struct QuantGemmHostArgs : public QuantGemmProblem
133 {
134  CK_TILE_HOST QuantGemmHostArgs() = default;
135  CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_,
136  const void* b_ptr_,
137  void* c_ptr_,
138  const void* aq_ptr_,
139  const void* bq_ptr_,
140  index_t k_batch_,
141  index_t M_,
142  index_t N_,
143  index_t K_,
144  index_t QK_A_,
145  index_t QK_B_,
146  index_t stride_A_,
147  index_t stride_B_,
148  index_t stride_C_,
149  index_t stride_AQ_,
150  index_t stride_BQ_)
151  : QuantGemmProblem(
152  M_, N_, K_, QK_A_, QK_B_, stride_A_, stride_B_, stride_C_, stride_AQ_, stride_BQ_),
153  a_ptr(a_ptr_),
154  b_ptr(b_ptr_),
155  aq_ptr(aq_ptr_),
156  bq_ptr(bq_ptr_),
157  c_ptr(c_ptr_),
158  k_batch(k_batch_)
159  {
160  }
161 
162  const void* a_ptr = nullptr;
163  const void* b_ptr = nullptr;
164  const void* aq_ptr = nullptr;
165  const void* bq_ptr = nullptr;
166  void* c_ptr = nullptr;
167  index_t k_batch;
168 };
169 
170 struct QuantGemmKernelArgs
171 {
172  const void* a_ptr;
173  const void* b_ptr;
174  const void* aq_ptr;
175  const void* bq_ptr;
176  void* c_ptr;
177  index_t M;
178  index_t N;
179  index_t K;
180  index_t QK_A;
181  index_t QK_B;
182  index_t stride_A;
183  index_t stride_B;
184  index_t stride_C;
185  index_t stride_AQ;
186  index_t stride_BQ;
187  index_t k_batch;
188 };
189 
190 template <typename TilePartitioner_,
191  typename GemmPipeline_,
192  typename EpiloguePipeline_,
193  QuantType QuantType_>
194 struct QuantGemmKernel
195 {
196  using TilePartitioner = remove_cvref_t<TilePartitioner_>;
197  using GemmPipeline = remove_cvref_t<GemmPipeline_>;
198  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
199  using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
200  using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
201  using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
202 
203  using AQLayout =
204  remove_cvref_t<typename detail::get_aq_layout_or<GemmPipeline, typename GemmPipeline::ALayout>::type>;
205  using BQLayout =
206  remove_cvref_t<typename detail::get_bq_layout_or<GemmPipeline, typename GemmPipeline::BLayout>::type>;
207 
208  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
209  static constexpr bool PreshuffleQuant =
210  detail::is_quantpreshuffle_enabled<GemmPipeline>::value;
211  static constexpr bool PreshuffleB = detail::is_preshuffleB_enabled<GemmPipeline>::value;
212 
213  using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
214  using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
215  using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
216  using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
217 
218  using AQDataType =
219  remove_cvref_t<typename detail::get_aq_data_type_or<GemmPipeline, AccDataType>::type>;
220  using BQDataType =
221  remove_cvref_t<typename detail::get_bq_data_type_or<GemmPipeline, AccDataType>::type>;
222 
223  static constexpr auto I0 = number<0>(); // A Tensor
224  static constexpr auto I1 = number<1>(); // AQ Tensor
225  static constexpr auto I2 = number<2>(); // B Tensor
226  static constexpr auto I3 = number<3>(); // BQ Tensor
227  static constexpr auto I4 = number<4>(); // C Tensor
228 
229  static constexpr auto kQuantType = QuantType_;
230 
231  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
232  {
233  // clang-format off
234  return concat('_', "gemm_quant", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
235  // clang-format on
236  }
237 
238  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
239  {
240  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
241  }
242 
243  CK_TILE_HOST static auto BlockSize()
244  {
245  return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
246  }
247 
248  CK_TILE_HOST static constexpr QuantGemmKernelArgs
249  MakeKernelArgs(const QuantGemmHostArgs& hostArgs)
250  {
251  return QuantGemmKernelArgs{hostArgs.a_ptr,
252  hostArgs.b_ptr,
253  hostArgs.aq_ptr,
254  hostArgs.bq_ptr,
255  hostArgs.c_ptr,
256  hostArgs.M,
257  hostArgs.N,
258  hostArgs.K,
259  hostArgs.QK_A,
260  hostArgs.QK_B,
261  hostArgs.stride_A,
262  hostArgs.stride_B,
263  hostArgs.stride_C,
264  hostArgs.stride_AQ,
265  hostArgs.stride_BQ,
266  hostArgs.k_batch};
267  }
268 
269  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
270  {
271  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
272  }
273 
274  private:
275  CK_TILE_DEVICE static constexpr index_t get_padding_size(index_t length, index_t alignment)
276  {
277  return ck_tile::integer_least_multiple(length, alignment) - length;
278  };
279  // ===================================================================
280  // Helper: Create Pre-shuffled Quantization Tensor Descriptor
281  // ===================================================================
282  template <index_t KPerBlockBQ,
283  index_t NPerBlockBQ,
284  index_t NPerBlock,
285  index_t WarpTileN,
286  index_t GetVectorSizeBQ,
287  typename BQDataType_>
288  CK_TILE_DEVICE static auto
289  MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QN_B, index_t QK_B)
290  {
291  // Step 1: Calculate base BQ tensor dimensions
292  // ----------------------------------------------------------
293  // bq_x: Number of quantization groups in N dimension
294  // = N * KPerBlockBQ, where KPerBlockBQ is the number of
295  // K-dimension groups per block
296  // bq_y: Number of quantization groups in K dimension
297  // = Total K groups (QK_B) / groups per block
298  const auto bq_x = N * KPerBlockBQ;
299  const auto bq_y = QK_B / KPerBlockBQ;
300 
301  const auto bq_desc = make_naive_tensor_descriptor(
302  make_tuple(bq_y, bq_x), make_tuple(bq_x, 1), number<GetVectorSizeBQ>{}, number<1>{});
303 
304  // Step 2: First padding transformation (block-level alignment)
305  // ----------------------------------------------------------
306  // Pad the X dimension to be a multiple of block_tile_size to ensure
307  // each thread block can process complete tiles without edge cases
308  const auto block_tile_size = NPerBlockBQ * KPerBlockBQ;
309 
310  const auto bq_pad0_desc = transform_tensor_descriptor(
311  bq_desc,
312  make_tuple(make_pass_through_transform(bq_y),
313  make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))),
314  make_tuple(sequence<0>{}, sequence<1>{}),
315  make_tuple(sequence<0>{}, sequence<1>{}));
316 
317  // Step 3: Unmerge transformation (wave-level decomposition)
318  // ----------------------------------------------------------
319  // Split the X dimension into [wave_tile_count_x, wave_tile_size]
320  // This separates the work into tiles that can be processed by
321  // individual warps/waves
322  const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1];
323  const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ;
324  const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size);
325 
326  const auto bq_unmerge_pad0_desc = transform_tensor_descriptor(
327  bq_pad0_desc,
328  make_tuple(make_pass_through_transform(bq_y),
329  make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
330  make_tuple(sequence<0>{}, sequence<1>{}),
331  make_tuple(sequence<0>{}, sequence<1, 2>{}));
332 
333  // Step 4: Second padding transformation (warp-level alignment)
334  // ----------------------------------------------------------
335  // Pad wave_tile_size to be a multiple of warp_size (typically 32 or 64)
336  // This ensures coalesced memory accesses within each warp
337  const auto bq_pad1_desc = transform_tensor_descriptor(
338  bq_unmerge_pad0_desc,
339  make_tuple(make_pass_through_transform(bq_y),
340  make_pass_through_transform(wave_tile_count_x),
341  make_right_pad_transform(wave_tile_size,
342  get_padding_size(wave_tile_size, get_warp_size()))),
343  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
344  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
345 
346  // Step 5: Final merge transformation (prepare for indexing)
347  // ----------------------------------------------------------
348  // Merge [bq_y, wave_tile_count_x] into a single outer dimension
349  // This creates a 2D layout: [merged_outer_dim, pad_wave_size]
350  // where merged_outer_dim = bq_y * wave_tile_count_x
351  // This layout facilitates efficient block-to-data mapping
352  const auto pad_wave_size = ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
353  const auto bq_merge_pad1_desc = transform_tensor_descriptor(
354  bq_pad1_desc,
355  make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)),
356  make_pass_through_transform(pad_wave_size)),
357  make_tuple(sequence<0, 1>{}, sequence<2>{}),
358  make_tuple(sequence<0>{}, sequence<1>{}));
359 
360  return make_tensor_view<address_space_enum::global>(bq_ptr, bq_merge_pad1_desc);
361  }
362 
363  public:
364  struct SplitKBatchOffset
365  {
366  __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
367  const std::size_t k_id = blockIdx.z)
368  {
369  constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
370  const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
371  const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
372 
373  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
374  {
375  a_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
376  }
377  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
378  {
379  a_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_A);
380  }
381 
382  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
383  {
384  b_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_B);
385  }
386  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
387  {
388  b_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
389  }
390 
391  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
392  {
393  splitted_k = amd_wave_read_first_lane(KRead);
394  }
395  else
396  {
397  splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
398  }
399  }
400 
401  index_t a_k_split_offset;
402  index_t b_k_split_offset;
403  index_t splitted_k;
404  };
405 
406  CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
407  const QuantGemmKernelArgs& kargs,
408  const index_t k_size,
409  const index_t i_m)
410  {
411  // Step 1: Create tensor view for A
412  const auto& a_tensor_view = [&]() {
413  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
414  {
415  return make_naive_tensor_view<address_space_enum::global>(
416  a_ptr,
417  make_tuple(kargs.M, k_size),
418  make_tuple(kargs.stride_A, 1),
419  number<GemmPipeline::GetVectorSizeA()>{},
420  number<1>{});
421  }
422  else
423  {
424  return make_naive_tensor_view<address_space_enum::global>(
425  a_ptr,
426  make_tuple(k_size, kargs.M),
427  make_tuple(kargs.stride_A, 1),
428  number<GemmPipeline::GetVectorSizeA()>{},
429  number<1>{});
430  }
431  }();
432 
433  // Step 2: Create padded view
434  const auto& a_pad_view = [&]() {
435  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
436  {
437  return pad_tensor_view(a_tensor_view,
441  }
442  else
443  {
444  return pad_tensor_view(a_tensor_view,
448  }
449  }();
450 
451  // Step 3: Create tile window
452  const auto& a_block_window = [&]() {
453  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
454  {
455  return make_tile_window(a_pad_view,
458  {i_m, 0});
459  }
460  else
461  {
462  return make_tile_window(a_pad_view,
465  {0, i_m});
466  }
467  }();
468 
469  return a_block_window;
470  }
471 
472  CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr,
473  const QuantGemmKernelArgs& kargs,
474  const index_t i_m,
475  const index_t i_n)
476  {
477  // Step 1: Create tensor view for AQ
478  const auto& aq_tensor_view = [&]() {
480  {
481  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
482  const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
483  const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
484  const auto aq_desc =
485  make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
486  make_tuple(aq_x, 1),
487  number<GemmPipeline::GetVectorSizeAQ()>{},
488  number<1>{});
489 
490  const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
491  const auto aq_pad0_desc = transform_tensor_descriptor(
492  aq_desc,
493  make_tuple(
495  make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
498 
499  const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
500  const auto wave_tile_size =
501  GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
502  const auto wave_tile_count_x =
503  ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
504 
505  const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
506  aq_pad0_desc,
507  make_tuple(
509  make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
512 
513  const auto aq_pad1_desc = transform_tensor_descriptor(
514  aq_unmerge_pad0_desc,
515  make_tuple(
517  make_pass_through_transform(wave_tile_count_x),
519  wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
522 
523  const auto pad_wave_size =
524  ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
525  const auto aq_merge_pad1_desc = transform_tensor_descriptor(
526  aq_pad1_desc,
527  make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
528  make_pass_through_transform(pad_wave_size)),
531 
532  return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
533  }
534  else if constexpr((kQuantType == QuantType::AQuantGrouped ||
537  {
538  if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
539  {
540  return make_naive_tensor_view<address_space_enum::global>(
541  aq_ptr,
542  make_tuple(kargs.M, kargs.QK_A),
543  make_tuple(kargs.stride_AQ, 1),
544  number<GemmPipeline::GetVectorSizeAQ()>{},
545  number<1>{});
546  }
547  else // Column major AQ
548  {
549  return make_naive_tensor_view<address_space_enum::global>(
550  aq_ptr,
551  make_tuple(kargs.QK_A, kargs.M),
552  make_tuple(kargs.stride_AQ, 1),
553  number<GemmPipeline::GetVectorSizeAQ()>{},
554  number<1>{});
555  }
556  }
557  else if constexpr(kQuantType == QuantType::RowColQuant)
558  {
559  return make_naive_tensor_view<address_space_enum::global>(
560  aq_ptr,
561  make_tuple(kargs.M, kargs.N),
562  make_tuple(1, 0), // broadcasting over n
563  number<1>{},
564  number<1>{});
565  }
566  else
567  {
568  return nullptr;
569  }
570  }();
571 
572  // Step 2: Create tile window (no padding for AQ)
573  const auto& aq_block_window = [&]() {
575  {
576  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
578  constexpr auto block_m = TilePartitioner::MPerBlock;
579  constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
580  constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
581  constexpr auto tile_window_width =
582  ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
583  constexpr auto tile_window_height = block_m / warp_m;
584  auto block_m_idx = i_m / block_m;
585  return make_tile_window(
586  aq_tensor_view,
587  make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
588  {block_m_idx * tile_window_height, 0});
589  }
590  else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
591  {
593  constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
594  constexpr auto block_m = TilePartitioner::MPerBlock;
595  if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
596  {
597  return make_tile_window(aq_tensor_view,
599  {i_m, 0});
600  }
601  else // Column major AQ
602  {
603  return make_tile_window(aq_tensor_view,
605  {0, i_m});
606  }
607  }
608  else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant)
609  {
610  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
612  constexpr auto block_m = TilePartitioner::MPerBlock;
613  constexpr auto block_k = TilePartitioner::KPerBlock;
614  return make_tile_window(
615  aq_tensor_view,
616  make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
617  {i_m, 0});
618  }
619  else if constexpr(kQuantType == QuantType::RowColQuant)
620  {
621  return make_tile_window(aq_tensor_view,
624  {i_m, i_n});
625  }
626  else
627  {
628  return nullptr;
629  }
630  }();
631 
632  return aq_block_window;
633  }
634 
635  CK_TILE_DEVICE static auto MakeBBlockWindow(const BDataType* b_ptr,
636  const QuantGemmKernelArgs& kargs,
637  const index_t k_size,
638  const index_t i_n)
639  {
640  // Step 1: Create tensor view for B
641  const auto& b_tensor_view = [&]() {
642  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
643  {
644  if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
645  {
646  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
647  const index_t K0 = k_size / K1;
648  constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
649  const auto b_k0_n_k1_desc =
651  make_tuple(kargs.N * K1, K1, I1),
653  number<1>{});
654  const auto b_n_k_desc = transform_tensor_descriptor(
655  b_k0_n_k1_desc,
660  return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
661  }
662  else
663  {
664  return make_naive_tensor_view<address_space_enum::global>(
665  b_ptr,
666  make_tuple(k_size, kargs.N),
667  make_tuple(kargs.stride_B, 1),
668  number<GemmPipeline::GetVectorSizeB()>{},
669  number<1>{});
670  }
671  }
672  else
673  {
674  if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
675  {
676  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
677  const index_t K0 = k_size / K1;
678  constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
679  const auto b_k0_n_k1_desc =
681  make_tuple(kargs.N * K1, K1, I1),
683  number<1>{});
684  const auto b_n_k_desc = transform_tensor_descriptor(
685  b_k0_n_k1_desc,
690  return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
691  }
692  else
693  {
694  if constexpr(PreshuffleB)
695  {
696  index_t kFlatK =
697  GemmPipeline::flatKPerWarp *
698  (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
699  index_t kFlatN = kargs.N * kargs.K / kFlatK;
700  return make_naive_tensor_view<address_space_enum::global>(
701  b_ptr,
702  make_tuple(kFlatN, kFlatK),
703  make_tuple(kFlatK, 1),
704  number<GemmPipeline::GetVectorSizeB()>{},
705  number<1>{});
706  }
707  else
708  {
709  if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
710  return make_naive_tensor_view<address_space_enum::global>(
711  b_ptr,
712  make_tuple(kargs.N, k_size / 2),
713  make_tuple(kargs.stride_B, 1),
714  number<GemmPipeline::GetVectorSizeB()>{},
715  number<1>{});
716  else
717  return make_naive_tensor_view<address_space_enum::global>(
718  b_ptr,
719  make_tuple(kargs.N, k_size),
720  make_tuple(kargs.stride_B, 1),
721  number<GemmPipeline::GetVectorSizeB()>{},
722  number<1>{});
723  }
724  }
725  }
726  }();
727 
728  // Step 2: Create padded view (or flat view for PreshuffleB)
729  const auto& b_pad_view = [&]() {
730  if constexpr(PreshuffleB)
731  {
732  return b_tensor_view; // no padding for preshuffle
733  }
734  else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
735  {
736  if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
737  return pad_tensor_view(b_tensor_view,
739  number<TilePartitioner::KPerBlock / 2>{}),
741  else
742  return pad_tensor_view(b_tensor_view,
746  }
747  else
748  {
749  return pad_tensor_view(b_tensor_view,
753  }
754  }();
755 
756  // Step 3: Create tile window
757  const auto& b_block_window = [&]() {
758  if constexpr(PreshuffleB)
759  {
760  return make_tile_window(
761  b_pad_view,
764  {static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0});
765  }
766  else
767  {
768  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
769  {
770  if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
771  return make_tile_window(
772  b_pad_view,
773  make_tuple(number<TilePartitioner::NPerBlock>{},
774  number<TilePartitioner::KPerBlock / 2>{}),
775  {i_n, 0});
776  else
777  return make_tile_window(b_pad_view,
780  {i_n, 0});
781  }
782  else
783  {
784  return make_tile_window(b_pad_view,
787  {0, i_n});
788  }
789  }
790  }();
791 
792  return b_block_window;
793  }
794 
795  CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr,
796  const QuantGemmKernelArgs& kargs,
797  const index_t i_m,
798  const index_t i_n)
799  {
800  // Step 1: Create tensor view for BQ
801  const auto& bq_tensor_view = [&]() {
802  if constexpr(kQuantType == QuantType::RowColQuant)
803  {
804  return make_naive_tensor_view<address_space_enum::global>(
805  bq_ptr,
806  make_tuple(kargs.M, kargs.N),
807  make_tuple(0, 1), // broadcasting over m
808  number<1>{},
809  number<1>{});
810  }
811  else if constexpr(kQuantType == QuantType::BQuantGrouped)
812  {
813  if constexpr(PreshuffleQuant)
814  {
815  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
816  "PreshuffleQuant with BQuantGrouped currently only supports "
817  "ColumnMajor BQ layout");
819 
820  return MakePreshuffledQuantTensorView<
821  GemmPipeline::KPerBlockBQ,
822  GemmPipeline::NPerBlockBQ,
823  GemmPipeline::NPerBlock,
824  TilePartitioner::BlockGemmShape::WarpTile::at(I1),
825  GemmPipeline::GetVectorSizeBQ()>(
826  bq_ptr,
827  ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN),
828  QuantGroupSize::kN,
829  kargs.QK_B);
830  }
831  else
832  {
834 
835  if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
836  {
837  return make_naive_tensor_view<address_space_enum::global>(
838  bq_ptr,
839  make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
840  integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
841  make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
842  number<GemmPipeline::GetVectorSizeBQ()>{},
843  number<1>{});
844  }
845  else
846  {
847  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
848  return make_naive_tensor_view<address_space_enum::global>(
849  bq_ptr,
850  make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
851  integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
852  make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
853  number<GemmPipeline::GetVectorSizeBQ()>{},
854  number<1>{});
855  }
856  }
857  }
858  else if constexpr(kQuantType == QuantType::ABQuantGrouped)
859  {
860  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
862  return make_naive_tensor_view<address_space_enum::global>(
863  bq_ptr,
864  make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
865  make_tuple(kargs.stride_BQ, 1),
866  number<GemmPipeline::GetVectorSizeBQ()>{},
867  number<1>{});
868  }
869  else
870  {
871  return nullptr;
872  }
873  }();
874 
875  // Step 2: Create tile window (no padding for BQ)
876  const auto& bq_block_window = [&]() {
877  if constexpr(kQuantType == QuantType::RowColQuant)
878  {
879  return make_tile_window(bq_tensor_view,
882  {i_m, i_n});
883  }
884  else if constexpr(kQuantType == QuantType::BQuantGrouped)
885  {
887  if constexpr(PreshuffleQuant)
888  {
889  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
890  constexpr auto block_n =
891  TilePartitioner::NPerBlock /
892  QuantGroupSize::kN; // Number of N-dimension quantization groups per block
893  constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(
894  I1); // Number of N-dimension elements per warp
895  constexpr auto warp_per_group =
896  (QuantGroupSize::kN <
897  warp_n) // Determine how many warps share the same scale in N-dimension
898  ? (warp_n / QuantGroupSize::kN)
899  : (QuantGroupSize::kN / warp_n);
900  constexpr auto bqk_per_block =
901  TilePartitioner::KPerBlock /
902  QuantGroupSize::kK; // Number of K-dimension quantization groups per block
903  constexpr auto
904  tile_window_width = // The pre-shuffled layout flattens warp_n ×
905  // bqk_per_block scales per row, Padded up to warp_size
906  // to ensure coalesced memory access.
907  ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
908 
909  // Adapts based on fine vs coarse quantization granularity:
910  // - Fine-grained (QuantGroupSize::kN < warp_n):
911  // Multiple quant groups per warp → fewer rows needed per block.
912  // height = block_n / warp_per_group
913  //
914  // - Coarse-grained (QuantGroupSize::kN >= warp_n):
915  // Each row represents one quant group.
916  // height = block_n
917  constexpr auto tile_window_height =
918  (QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
919  auto block_n_idx =
920  i_n / TilePartitioner::NPerBlock; // Converts the global N-index (i_n) to a
921  // block index.
922 
923  return make_tile_window(
924  bq_tensor_view,
925  make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
926  {block_n_idx * tile_window_height, 0});
927  }
928  else
929  {
930  if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
931  {
932  return make_tile_window(
933  bq_tensor_view,
934  make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
935  number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
936  {0, i_n / QuantGroupSize::kN});
937  }
938  else
939  {
940  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
941  return make_tile_window(
942  bq_tensor_view,
943  make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
944  number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
945  {i_n / QuantGroupSize::kN, 0});
946  }
947  }
948  }
949  else if constexpr(kQuantType == QuantType::ABQuantGrouped)
950  {
951  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
953  return make_tile_window(
954  bq_tensor_view,
955  make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
956  number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
957  {i_n / QuantGroupSize::kN, 0});
958  }
959  else
960  {
961  return nullptr;
962  }
963  }();
964 
965  return bq_block_window;
966  }
967 
968  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
969  CK_TILE_DEVICE static auto MakeCBlockWindow(CDataType* c_ptr,
970  const QuantGemmKernelArgs& kargs,
971  const index_t i_m,
972  const index_t i_n)
973  {
974  // Step 1: Create tensor view for C
975  const auto& c_tensor_view = [&]() {
976  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
977  {
978  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
979  c_ptr,
980  make_tuple(kargs.M, kargs.N),
981  make_tuple(kargs.stride_C, 1),
982  number<EpiloguePipeline::GetVectorSizeC()>{},
983  number<1>{});
984  }
985  else
986  {
987  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
988  c_ptr,
989  make_tuple(kargs.M, kargs.N),
990  make_tuple(1, kargs.stride_C),
991  number<1>{},
992  number<1>{});
993  }
994  }();
995 
996  // Step 2: Create padded view
997  const auto& c_pad_view = [&]() {
998  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
999  {
1000  return pad_tensor_view(c_tensor_view,
1004  }
1005  else
1006  {
1007  return pad_tensor_view(c_tensor_view,
1011  }
1012  }();
1013 
1014  // Step 3: Create tile window
1015  auto c_block_window = make_tile_window(
1016  c_pad_view,
1017  make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
1018  {i_m, i_n});
1019 
1020  return c_block_window;
1021  }
1022 
1023  CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
1024  {
1025  if(kargs.k_batch != 1)
1026  {
1027  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1028  {
1029  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
1030  }
1031  return false;
1032  }
1033 
1034  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
1035  {
1036  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
1037  GemmPipeline::kPadK == false)
1038  {
1039  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1040  {
1041  CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
1042  "without padding!");
1043  }
1044  return false;
1045  }
1046  if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
1047  {
1048  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1049  {
1050  CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
1051  }
1052  return false;
1053  }
1054  }
1055  else
1056  {
1057  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
1058  {
1059  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1060  {
1061  CK_TILE_ERROR(
1062  "Can't support M that is not a multiple of MPerBlock without padding!");
1063  }
1064  return false;
1065  }
1066  if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
1067  {
1068  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1069  {
1070  CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
1071  }
1072  return false;
1073  }
1074  }
1075 
1076  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
1077  {
1078  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
1079  {
1080  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1081  {
1082  CK_TILE_ERROR(
1083  "Can't support N that is not a multiple of NPerBlock without padding!");
1084  }
1085  return false;
1086  }
1087  if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
1088  {
1089  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1090  {
1091  CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
1092  }
1093  return false;
1094  }
1095  }
1096  else
1097  {
1098  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
1099  GemmPipeline::kPadK == false)
1100  {
1101  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1102  {
1103  CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
1104  "without padding!");
1105  }
1106  return false;
1107  }
1108  if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
1109  {
1110  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1111  {
1112  CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
1113  }
1114  return false;
1115  }
1116  }
1117 
1118  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
1119  {
1120  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
1121  {
1122  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1123  {
1124  CK_TILE_ERROR(
1125  "Can't support N that is not a multiple of NPerBlock without padding!");
1126  }
1127  return false;
1128  }
1129  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
1130  {
1131  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1132  {
1133  CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
1134  }
1135  return false;
1136  }
1137  }
1138  else
1139  {
1140  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
1141  {
1142  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1143  {
1144  CK_TILE_ERROR(
1145  "Can't support M that is not a multiple of MPerBlock without padding!");
1146  }
1147  return false;
1148  }
1149  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
1150  {
1151  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1152  {
1153  CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
1154  }
1155  return false;
1156  }
1157  }
1158  return true;
1159  }
1160 
1176  CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
1177  const BDataType* b_ptr,
1178  const AQDataType* aq_ptr,
1179  const BQDataType* bq_ptr,
1180  CDataType* c_ptr,
1181  void* smem_ptr,
1182  const QuantGemmKernelArgs& kargs,
1183  const SplitKBatchOffset& splitk_batch_offset,
1184  const index_t block_idx_m,
1185  const index_t block_idx_n)
1186  {
1187  // Create block windows using specialized methods
1188  const auto& a_block_window =
1189  MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
1190  const auto& b_block_window =
1191  MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
1192  const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
1193  const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
1194 
1195  const index_t num_loop =
1196  amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
1197 
1198  // Run GEMM cooperatively by whole workgroup.
1199  const auto& c_block_tile = [&]() {
1200  if constexpr(kQuantType == QuantType::AQuantGrouped)
1201  {
1202  index_t m = 0;
1203  if constexpr(PreshuffleQuant)
1204  {
1205  m = kargs.M;
1206  }
1207  return GemmPipeline{}.template operator()(
1208  a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr, m);
1209  }
1210  else if constexpr(kQuantType == QuantType::BQuantGrouped)
1211  {
1212  index_t n = 0;
1213  if constexpr(PreshuffleQuant)
1214  {
1215  n = kargs.N;
1216  }
1217  return GemmPipeline{}.template operator()(
1218  a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr, n);
1219  }
1220  else if constexpr(kQuantType == QuantType::ABQuantGrouped)
1221  {
1222  index_t m = 0;
1223  index_t n = 0;
1224  if constexpr(PreshuffleQuant)
1225  {
1226  m = kargs.M;
1227  n = kargs.N;
1228  }
1229  return GemmPipeline{}.template operator()(a_block_window,
1230  b_block_window,
1231  aq_block_window,
1232  bq_block_window,
1233  num_loop,
1234  smem_ptr,
1235  m,
1236  n);
1237  }
1238  else if constexpr(kQuantType == QuantType::RowColQuant ||
1239  kQuantType == QuantType::TensorQuant)
1240  {
1241  return GemmPipeline{}.template operator()(
1242  a_block_window, b_block_window, num_loop, smem_ptr);
1243  }
1244  }();
1245 
1246  const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
1247 
1248  // Run Epilogue Pipeline with k_batch dispatch
1249  if(k_batch == 1)
1250  {
1251  auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
1252  c_ptr, kargs, block_idx_m, block_idx_n);
1253 
1254  if constexpr(kQuantType == QuantType::ABQuantGrouped ||
1255  kQuantType == QuantType::AQuantGrouped ||
1256  kQuantType == QuantType::BQuantGrouped)
1257  {
1258  EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
1259  }
1260  else if constexpr(kQuantType == QuantType::RowColQuant)
1261  {
1262  EpiloguePipeline{}(c_block_window,
1263  c_block_tile,
1264  c_block_window,
1265  smem_ptr,
1266  aq_block_window,
1267  bq_block_window);
1268  }
1269  else if constexpr(kQuantType == QuantType::TensorQuant)
1270  {
1271  const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
1272  const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
1273  EpiloguePipeline{}(
1274  c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
1275  }
1276  }
1277  else
1278  {
1279  auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
1280  c_ptr, kargs, block_idx_m, block_idx_n);
1281 
1282  if constexpr(kQuantType == QuantType::ABQuantGrouped ||
1283  kQuantType == QuantType::AQuantGrouped ||
1284  kQuantType == QuantType::BQuantGrouped)
1285  {
1286  EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
1287  }
1288  else if constexpr(kQuantType == QuantType::RowColQuant)
1289  {
1290  EpiloguePipeline{}(c_block_window,
1291  c_block_tile,
1292  c_block_window,
1293  smem_ptr,
1294  aq_block_window,
1295  bq_block_window);
1296  }
1297  else if constexpr(kQuantType == QuantType::TensorQuant)
1298  {
1299  const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
1300  const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
1301  EpiloguePipeline{}(
1302  c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
1303  }
1304  }
1305  }
1306 
1307  CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
1308  {
1309  const auto blockId = amd_wave_read_first_lane(blockIdx.x);
1310  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1311  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1312  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1313  const SplitKBatchOffset splitk_batch_offset(kargs);
1314 
1315  // Apply splitk offset to input pointers
1316  const ADataType* a_ptr =
1317  static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
1318  const BDataType* b_ptr =
1319  static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
1320  const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
1321  const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
1322  CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
1323 
1324  // allocate LDS
1325  __shared__ char smem_ptr[GetSmemSize()];
1326 
1327  RunGemm(
1328  a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
1329  }
1330 };
1331 
1332 } // namespace ck_tile
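
The header above only defines the kernel type; a host program still has to build the argument structs and derive the launch shape. Below is a minimal host-side sketch of that flow, assuming a concrete Kernel = QuantGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline, QuantType> instantiation and the ck_tile host launch utilities; the helper name launch_quant_gemm and the commented-out launch call are illustrative and not part of this file.

#include <hip/hip_runtime.h> // dim3
#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp"

// Hypothetical helper: Kernel is an already-instantiated QuantGemmKernel<...>.
template <typename Kernel>
bool launch_quant_gemm(const ck_tile::QuantGemmHostArgs& host_args)
{
    // Flatten host arguments (pointers, sizes, strides, k_batch) into kernel args.
    auto kargs = Kernel::MakeKernelArgs(host_args);

    // Reject shapes/strides the chosen pipelines cannot handle (padding, vector sizes).
    if(!Kernel::IsSupportedArgument(kargs))
        return false;

    // One workgroup per C tile along x, k_batch workgroups along z; wave32 halves the block.
    const dim3 grids  = Kernel::GridSize(host_args.M, host_args.N, host_args.k_batch);
    const dim3 blocks = Kernel::BlockSize();

    // The actual dispatch would go through the ck_tile launch helpers, roughly:
    //   ck_tile::launch_kernel(stream_cfg,
    //                          ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs));
    (void)grids;
    (void)blocks;
    return true;
}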
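For intuition about the alignment arithmetic behind get_padding_size and the Step 4/5 padding in MakePreshuffledQuantTensorView, here is a small standalone sketch that reimplements the two helpers with the same semantics; the numbers are illustrative only and the real helpers live in ck_tile's math utilities.

#include <cassert>

// Same semantics as ck_tile::integer_least_multiple: smallest multiple of y that is >= x.
constexpr int integer_least_multiple(int x, int y) { return ((x + y - 1) / y) * y; }

// Mirrors the kernel's private get_padding_size helper.
constexpr int get_padding_size(int length, int alignment)
{
    return integer_least_multiple(length, alignment) - length;
}

int main()
{
    // A wave tile of 48 scales padded to a 64-lane warp needs 16 extra elements,
    // so pad_wave_size becomes 64 (Steps 4 and 5 of the pre-shuffled BQ descriptor).
    assert(integer_least_multiple(48, 64) == 64);
    assert(get_padding_size(48, 64) == 16);

    // An X extent of 100 quantization groups padded to a 64-wide block tile becomes 128.
    assert(get_padding_size(100, 64) == 28);
    return 0;
}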