/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp Source File
gemm_quant_pipeline_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 #include <string>
11 
12 namespace ck_tile {
13 
// Template head of the quantized-GEMM pipeline problem type defined in this header.
// The parameters visible here are: the A / A-scale (AQ) / B / B-scale (BQ) / C element
// types, the block GEMM shape, a traits bundle, the quantization group size, a
// TransposeC flag, a compute type defaulting to BDataType_, a hot-loop flag
// (default true) and a tail number (default TailNumber::Full).
//
// NOTE(review): this listing is incomplete. Listing line 24 is absent; the body later
// reads `Scheduler_`, so that line presumably declares a scheduler template parameter
// (the page's index mentions GemmPipelineScheduler) -- confirm against the real header.
// Listing line 27 is absent as well; it presumably carries the struct keyword, the
// type's name and the start of the base-class argument list whose remaining entries
// (BDataType_ ... ComputeDataType_) follow below.
14 template <typename ADataType_,
15  typename AQDataType_,
16  typename BDataType_,
17  typename BQDataType_,
18  typename CDataType_,
19  typename BlockGemmShape_,
20  typename Traits_,
21  uint32_t QuantGroupSize_,
22  bool TransposeC_,
23  typename ComputeDataType_ = BDataType_,
25  bool HasHotLoop_ = true,
26  TailNumber TailNum_ = TailNumber::Full>
28  BDataType_,
29  CDataType_,
30  BlockGemmShape_,
31  Traits_,
32  ComputeDataType_>
33 {
// Shorthand for the non-quantized problem type this one builds on. Note that the
// scale types (AQDataType_/BQDataType_), the group size and TransposeC_ are not
// forwarded to the base; only the six arguments listed here are.
34  using Base = GemmPipelineProblemBase<ADataType_,
35  BDataType_,
36  CDataType_,
37  BlockGemmShape_,
38  Traits_,
39  ComputeDataType_>;
40 
// Re-export the traits bundle under the same name the base uses.
41  using Traits = typename Base::Traits;
42 
// Element types taken over unchanged from Base.
43  using typename Base::ADataType;
44  using typename Base::BDataType;
45  using typename Base::CDataType;
46  using typename Base::ComputeDataType;
// NOTE(review): listing lines 47, 48 and 50 are not shown. The cross-reference index at
// the end of this page places AQDataType (:47), BQDataType (:48) and BlockGemmShape (:50)
// on them.
49 
51 
// Tensor layouts taken over unchanged from Base.
52  using typename Base::ALayout;
53  using typename Base::BLayout;
54  using typename Base::CLayout;
55 
// Exposes the TransposeC_ template argument as a member constant.
56  static constexpr bool TransposeC = TransposeC_;
57 
58  using Base::kBlockSize;
59 
// Padding flags taken over unchanged from Base.
60  using Base::kPadK;
61  using Base::kPadM;
62  using Base::kPadN;
63 
// NOTE(review): listing lines 64-68 are not shown. The index at the end of this page
// places AQLayout (:67) and BQLayout (:68) in that gap.
66 
69 
// Template arguments re-published as member constants.
// NOTE(review): `Scheduler_` is not among the template parameters visible in this
// listing; it is presumably declared on the absent listing line 24.
70  static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
71  static constexpr auto Scheduler = Scheduler_;
72  static constexpr auto HasHotLoop = HasHotLoop_;
73  static constexpr auto TailNum = TailNum_;
74 
// Compile-time check: the block shape's kK must be an exact multiple of the
// quantization group size.
75  static_assert(BlockGemmShape::kK % kQuantGroupSize == 0);
77 
// Host-only helper that returns a descriptive name string assembled with concat().
// The visible arguments are the literal "gemm_quant_problem", an inner
// concat('x', kPadM, kPadN, kPadK), the Scheduler constant and the literal
// "QuantGroupSize". Presumably the leading character argument ('_' / 'x') is the
// separator placed between the remaining arguments -- confirm against concat().
//
// NOTE(review): listing lines 82 and 86 are absent, so at least two further argument
// lines and the end of the return statement are not visible here.
// NOTE(review): the return type is `const std::string` by value; the top-level const
// stops callers from moving out of the returned temporary. Consider plain std::string.
78  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
79  {
80  // clang-format off
81  return concat('_', "gemm_quant_problem",
83  concat('x', kPadM, kPadN, kPadK),
84  Scheduler,
85  "QuantGroupSize",
87  // clang-format on
88  }
89 
// Returns VectorLoadSize divided by sizeof(AQDataType), i.e. how many AQ elements make
// up one VectorLoadSize (presumably VectorLoadSize is a byte count -- confirm in Base).
// Instantiating this function requires AQLayout to be row-major.
// NOTE(review): another template later in this file passes `void` where the A-scale
// type is expected ("no AQDataType for BQuant"). sizeof applied to void is not valid
// ISO C++, so this function presumably must not be instantiated for that case.
90  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentAQ()
91  {
92  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
93  return VectorLoadSize / sizeof(AQDataType);
94  }
95 
// Vector size used for AQ, computed by an immediately-invoked lambda:
// 1 when kPadK is set, otherwise the value of GetAlignmentAQ().
// NOTE(review): the assertion here checks ALayout, whereas GetAlignmentAQ() checks
// AQLayout. Confirm whether ALayout is intended or whether this should be AQLayout.
96  static constexpr index_t VectorSizeAQ = []() {
97  static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
98  return kPadK ? 1 : GetAlignmentAQ();
99  }();
100 
// Returns VectorLoadSize divided by sizeof(BQDataType), i.e. how many BQ elements make
// up one VectorLoadSize (presumably VectorLoadSize is a byte count -- confirm in Base).
// Unlike GetAlignmentAQ(), no layout assertion is made here.
// NOTE(review): another template later in this file passes `void` where the B-scale
// type is expected ("no BQDataType for AQuant"). sizeof applied to void is not valid
// ISO C++, so this function presumably must not be instantiated for that case.
101  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBQ()
102  {
103  return VectorLoadSize / sizeof(BQDataType);
104  }
105 
// Vector size used for BQ: 1 when kPadK is set, otherwise the value of GetAlignmentBQ().
106  static constexpr index_t VectorSizeBQ = []() { return kPadK ? 1 : GetAlignmentBQ(); }();
107 };
108 
// Template for the A-quantized ("AQuant") case: it takes an A-scale type but no
// B-scale type, and supplies `void` in the B-scale position of the argument list below.
// ComputeDataType_ defaults to BDataType_.
//
// NOTE(review): this listing is incomplete. Listing line 118 is absent; since
// `Scheduler_` is used below, that line presumably declares it. Listing line 121 is
// absent too; it presumably holds the declared name, the name of the template being
// instantiated and its first argument (the visible list starts at AQDataType_).
// The trailing `>;` with no body suggests an alias declaration -- confirm.
109 template <typename ADataType_,
110  typename AQDataType_,
111  typename BDataType_,
112  typename CDataType_,
113  typename BlockGemmShape_,
114  typename Traits_,
115  uint32_t QuantGroupSize_,
116  bool TransposeC_,
117  typename ComputeDataType_ = BDataType_,
119  bool HasHotLoop_ = true,
120  TailNumber TailNum_ = TailNumber::Full>
122  AQDataType_,
123  BDataType_,
124  void, // no BQDataType for AQuant
125  CDataType_,
126  BlockGemmShape_,
127  Traits_,
128  QuantGroupSize_,
129  TransposeC_,
130  ComputeDataType_,
131  Scheduler_,
132  HasHotLoop_,
133  TailNum_>;
134 
// Template for the B-quantized ("BQuant") case: it takes a B-scale type but no A-scale
// type and no TransposeC_ parameter. The argument list below supplies `void` in the
// A-scale position and a hard-coded `false` for TransposeC. ComputeDataType_ defaults
// to ADataType_ here (the other templates in this file default it to BDataType_).
//
// NOTE(review): this listing is incomplete. Listing line 143 is absent; since
// `Scheduler_` is used below, that line presumably declares it. Listing line 146 is
// absent too; it presumably holds the declared name, the name of the template being
// instantiated and its first argument (the visible list starts at the `void` entry).
// The trailing `>;` with no body suggests an alias declaration -- confirm.
135 template <typename ADataType_,
136  typename BDataType_,
137  typename BQDataType_,
138  typename CDataType_,
139  typename BlockGemmShape_,
140  typename Traits_,
141  uint32_t QuantGroupSize_,
142  typename ComputeDataType_ = ADataType_,
144  bool HasHotLoop_ = true,
145  TailNumber TailNum_ = TailNumber::Full>
147  void, // no AQDataType for BQuant
148  BDataType_,
149  BQDataType_,
150  CDataType_,
151  BlockGemmShape_,
152  Traits_,
153  QuantGroupSize_,
154  false, // no TransposeC
155  ComputeDataType_,
156  Scheduler_,
157  HasHotLoop_,
158  TailNum_>;
159 
// Template for a variant that has no separate scale types and no group size among its
// parameters. The argument list below supplies AccDataType_ in the positions on either
// side of BDataType_ (presumably the A-scale and B-scale slots -- confirm against the
// target template) and a literal 1 where the group size goes. TransposeC_ defaults to
// false and ComputeDataType_ to BDataType_.
//
// NOTE(review): this listing is incomplete. Listing line 168 is absent; since
// `Scheduler_` is used below, that line presumably declares it. Listing line 171 is
// absent too; it presumably holds the declared name, the name of the template being
// instantiated and its first argument (the visible list starts at AccDataType_).
// The trailing `>;` with no body suggests an alias declaration -- confirm.
160 template <typename ADataType_,
161  typename BDataType_,
162  typename CDataType_,
163  typename AccDataType_,
164  typename BlockGemmShape_,
165  typename Traits_,
166  bool TransposeC_ = false,
167  typename ComputeDataType_ = BDataType_,
169  bool HasHotLoop_ = true,
170  TailNumber TailNum_ = TailNumber::Full>
172  AccDataType_,
173  BDataType_,
174  AccDataType_,
175  CDataType_,
176  BlockGemmShape_,
177  Traits_,
178  1, // no group size applicable
179  TransposeC_,
180  ComputeDataType_,
181  Scheduler_,
182  HasHotLoop_,
183  TailNum_>;
184 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
ck_tile
Definition: cluster_descriptor.hpp:13
TailNumber
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:21
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
GemmPipelineScheduler
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:14
unsigned int uint32_t
Definition: stdint.h:126
ck_tile::GemmPipelineProblemBase
Definition: gemm_pipeline_problem.hpp:22
remove_cvref_t< CDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:27
static constexpr bool DoubleSmemBuffer
Definition: gemm_pipeline_problem.hpp:48
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:44
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:50
remove_cvref_t< ComputeDataType_ > ComputeDataType
Definition: gemm_pipeline_problem.hpp:28
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:23
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:32
remove_cvref_t< typename Traits::ALayout > ALayout
Definition: gemm_pipeline_problem.hpp:34
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:45
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:36
remove_cvref_t< BDataType_ > BDataType
Definition: gemm_pipeline_problem.hpp:26
remove_cvref_t< ADataType_ > ADataType
Definition: gemm_pipeline_problem.hpp:25
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:46
remove_cvref_t< typename Traits::BLayout > BLayout
Definition: gemm_pipeline_problem.hpp:35
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:42
Definition: gemm_quant_pipeline_problem.hpp:33
static constexpr uint32_t kQuantGroupSize
Definition: gemm_quant_pipeline_problem.hpp:70
remove_cvref_t< BQDataType_ > BQDataType
Definition: gemm_quant_pipeline_problem.hpp:48
remove_cvref_t< AQDataType_ > AQDataType
Definition: gemm_quant_pipeline_problem.hpp:47
remove_cvref_t< typename Traits::BQLayout > BQLayout
Definition: gemm_quant_pipeline_problem.hpp:68
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:44
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:50
static constexpr bool TransposeC
Definition: gemm_quant_pipeline_problem.hpp:56
remove_cvref_t< typename Traits::AQLayout > AQLayout
Definition: gemm_quant_pipeline_problem.hpp:67
typename Base::BlockGemmShape BlockGemmShape
Definition: gemm_quant_pipeline_problem.hpp:50
static constexpr auto Scheduler
Definition: gemm_quant_pipeline_problem.hpp:71
static constexpr index_t VectorSizeBQ
Definition: gemm_quant_pipeline_problem.hpp:106
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:45
static constexpr auto TailNum
Definition: gemm_quant_pipeline_problem.hpp:73
static CK_TILE_HOST const std::string GetName()
Definition: gemm_quant_pipeline_problem.hpp:78
typename Base::Traits Traits
Definition: gemm_quant_pipeline_problem.hpp:41
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentAQ()
Definition: gemm_quant_pipeline_problem.hpp:90
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:46
static constexpr index_t VectorSizeAQ
Definition: gemm_quant_pipeline_problem.hpp:96
static constexpr auto HasHotLoop
Definition: gemm_quant_pipeline_problem.hpp:72
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentBQ()
Definition: gemm_quant_pipeline_problem.hpp:101
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:42