/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp Source File
gemm_quant_pipeline_problem.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 #include <string>
11 
12 namespace ck_tile {
13 
14 template <typename ADataType_,
15  typename AQDataType_,
16  typename BDataType_,
17  typename BQDataType_,
18  typename CDataType_,
19  typename BlockGemmShape_,
20  typename Traits_,
21  typename AQuantGroupSize_,
22  typename BQuantGroupSize_,
23  bool TransposeC_,
24  typename ComputeDataType_ = BDataType_,
26  bool HasHotLoop_ = true,
27  TailNumber TailNum_ = TailNumber::Full>
29  BDataType_,
30  CDataType_,
31  BlockGemmShape_,
32  Traits_,
33  ComputeDataType_>
34 {
35  using Base = GemmPipelineProblemBase<ADataType_,
36  BDataType_,
37  CDataType_,
38  BlockGemmShape_,
39  Traits_,
40  ComputeDataType_>;
41 
42  using Traits = typename Base::Traits;
43 
44  using typename Base::ADataType;
45  using typename Base::BDataType;
46  using typename Base::CDataType;
47  using typename Base::ComputeDataType;
50 
53  std::conditional_t<!std::is_void_v<AQuantGroupSize_>, AQuantGroupSize_, BQuantGroupSize_>;
55  std::conditional_t<!std::is_void_v<BQuantGroupSize_>, BQuantGroupSize_, AQuantGroupSize_>;
56  // Unified alias for 1D quantization usage, to avoid forcing users to pick one.
58 
59  using typename Base::ALayout;
60  using typename Base::BLayout;
61  using typename Base::CLayout;
62 
63  static constexpr bool TransposeC = TransposeC_;
64  static constexpr bool PreshuffleB = Traits::PreshuffleB;
65  static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
66  using Base::kBlockSize;
67 
68  using Base::kPadK;
69  using Base::kPadM;
70  using Base::kPadN;
71 
73 
76 
77  static constexpr auto Scheduler = Scheduler_;
78  static constexpr auto HasHotLoop = HasHotLoop_;
79  static constexpr auto TailNum = TailNum_;
80 
81  static_assert(BlockGemmShape::kM % AQuantGroupSize::kM == 0);
82  static_assert(BlockGemmShape::kN % AQuantGroupSize::kN == 0);
83  static_assert(BlockGemmShape::kK % AQuantGroupSize::kK == 0);
84  static_assert(BlockGemmShape::kM % BQuantGroupSize::kM == 0);
85  static_assert(BlockGemmShape::kN % BQuantGroupSize::kN == 0);
86  static_assert(BlockGemmShape::kK % BQuantGroupSize::kK == 0);
87 
88  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
89  {
90  // clang-format off
91  return concat('_', "gemm_quant_problem",
93  concat('x', kPadM, kPadN, kPadK),
94  Scheduler,
95  AQuantGroupSize::GetName(),
96  BQuantGroupSize::GetName());
97  // clang-format on
98  }
99 
100  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentAQ()
101  {
     // Number of AQ elements covered by one vector load: VectorLoadSize (presumably in bytes,
     // since it is divided by sizeof) divided by the AQ element size.
     // Only a row-major AQLayout is accepted; this is enforced at compile time below.
     // NOTE(review): the BQuant-only problem alias later in this file passes `void` as
     // AQDataType_, where sizeof(AQDataType) would be ill-formed - call this only for
     // problems that actually carry an AQ tensor.
102  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
103  return VectorLoadSize / sizeof(AQDataType);
104  }
105 
106  static constexpr index_t VectorSizeAQ = []() {
     // Vector size used for AQ accesses: 1 when K padding (kPadK) is enabled, otherwise the
     // full alignment reported by GetAlignmentAQ().
     // NOTE(review): this asserts on ALayout, whereas GetAlignmentAQ() asserts on AQLayout;
     // confirm that also requiring a row-major A tensor here is intended.
107  static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
108  return kPadK ? 1 : GetAlignmentAQ();
109  }();
110 
111  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBQ()
112  {
     // Number of BQ elements covered by one vector load: VectorLoadSize divided by the BQ
     // element size.
     // NOTE(review): unlike GetAlignmentAQ(), no layout constraint is checked for BQLayout -
     // confirm this alignment is valid for every BQ layout.
     // NOTE(review): the AQuant-only problem alias later in this file passes `void` as
     // BQDataType_, where sizeof(BQDataType) would be ill-formed - call this only for
     // problems that actually carry a BQ tensor.
113  return VectorLoadSize / sizeof(BQDataType);
114  }
115 
116  static constexpr index_t VectorSizeBQ = []() { return kPadK ? 1 : GetAlignmentBQ(); }(); // 1 when kPadK is set, else the full BQ alignment
117 };
118 
119 template <typename ADataType_,
120  typename AQDataType_,
121  typename BDataType_,
122  typename CDataType_,
123  typename BlockGemmShape_,
124  typename Traits_,
125  typename QuantGroupSize_,
126  bool TransposeC_,
127  typename ComputeDataType_ = BDataType_,
129  bool HasHotLoop_ = true,
130  TailNumber TailNum_ = TailNumber::Full>
132  AQDataType_,
133  BDataType_,
134  void, // no BQDataType for AQuant
135  CDataType_,
136  BlockGemmShape_,
137  Traits_,
138  QuantGroupSize_,
139  void,
140  TransposeC_,
141  ComputeDataType_,
142  Scheduler_,
143  HasHotLoop_,
144  TailNum_>;
145 
146 template <typename ADataType_,
147  typename BDataType_,
148  typename BQDataType_,
149  typename CDataType_,
150  typename BlockGemmShape_,
151  typename Traits_,
152  typename QuantGroupSize_,
153  typename ComputeDataType_ = ADataType_,
155  bool HasHotLoop_ = true,
156  TailNumber TailNum_ = TailNumber::Full>
158  void, // no AQDataType for BQuant
159  BDataType_,
160  BQDataType_,
161  CDataType_,
162  BlockGemmShape_,
163  Traits_,
164  void,
165  QuantGroupSize_,
166  false, // no TransposeC
167  ComputeDataType_,
168  Scheduler_,
169  HasHotLoop_,
170  TailNum_>;
171 
172 template <typename ADataType_,
173  typename AQDataType_,
174  typename BDataType_,
175  typename BQDataType_,
176  typename CDataType_,
177  typename BlockGemmShape_,
178  typename Traits_,
179  typename AQuantGroupSize_,
180  typename BQuantGroupSize_,
181  bool TransposeC_,
182  typename ComputeDataType_ = ADataType_,
184  bool HasHotLoop_ = true,
185  TailNumber TailNum_ = TailNumber::Full>
187  AQDataType_,
188  BDataType_,
189  BQDataType_,
190  CDataType_,
191  BlockGemmShape_,
192  Traits_,
193  AQuantGroupSize_,
194  BQuantGroupSize_,
195  TransposeC_,
196  ComputeDataType_,
197  Scheduler_,
198  HasHotLoop_,
199  TailNum_>;
200 
201 template <typename ADataType_,
202  typename BDataType_,
203  typename CDataType_,
204  typename AccDataType_,
205  typename BlockGemmShape_,
206  typename Traits_,
207  bool TransposeC_ = false,
208  typename ComputeDataType_ = BDataType_,
210  bool HasHotLoop_ = true,
211  TailNumber TailNum_ = TailNumber::Full>
213  GemmQuantPipelineProblemBase<ADataType_,
214  AccDataType_,
215  BDataType_,
216  AccDataType_,
217  CDataType_,
218  BlockGemmShape_,
219  Traits_,
220  void,
221  QuantGroupShape<sequence<1, 1, 1>>, // no group size applicable
222  TransposeC_,
223  ComputeDataType_,
224  Scheduler_,
225  HasHotLoop_,
226  TailNum_>;
227 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
TailNumber
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:21
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
GemmPipelineScheduler
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:14
Definition: gemm_pipeline_problem.hpp:25
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:34
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:80
remove_cvref_t< std::tuple_element_t< number< 0 >{}, ComputeDataTypeTuple > > ComputeDataType
Definition: gemm_pipeline_problem.hpp:66
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsLayoutTuple > > ALayout
Definition: gemm_pipeline_problem.hpp:68
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition: gemm_pipeline_problem.hpp:69
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:76
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:78
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition: gemm_pipeline_problem.hpp:67
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:41
remove_cvref_t< EDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:30
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:26
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:79
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:84
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsLayoutTuple > > BLayout
Definition: gemm_pipeline_problem.hpp:70
Definition: gemm_quant_pipeline_problem.hpp:34
static constexpr auto HasHotLoop
Definition: gemm_quant_pipeline_problem.hpp:78
typename Base::BlockGemmShape BlockGemmShape
Definition: gemm_quant_pipeline_problem.hpp:51
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:80
remove_cvref_t< BQDataType_ > BQDataType
Definition: gemm_quant_pipeline_problem.hpp:49
static CK_TILE_HOST const std::string GetName()
Definition: gemm_quant_pipeline_problem.hpp:88
static constexpr bool DoubleSmemBuffer
Definition: gemm_quant_pipeline_problem.hpp:65
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentBQ()
Definition: gemm_quant_pipeline_problem.hpp:111
remove_cvref_t< typename Traits::BQLayout > BQLayout
Definition: gemm_quant_pipeline_problem.hpp:75
remove_cvref_t< typename Traits::AQLayout > AQLayout
Definition: gemm_quant_pipeline_problem.hpp:74
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:76
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:78
static constexpr index_t VectorSizeBQ
Definition: gemm_quant_pipeline_problem.hpp:116
static constexpr auto Scheduler
Definition: gemm_quant_pipeline_problem.hpp:77
BQuantGroupSize QuantGroupSize
Definition: gemm_quant_pipeline_problem.hpp:57
static constexpr index_t VectorSizeAQ
Definition: gemm_quant_pipeline_problem.hpp:106
static constexpr bool PreshuffleB
Definition: gemm_quant_pipeline_problem.hpp:64
static constexpr bool TransposeC
Definition: gemm_quant_pipeline_problem.hpp:63
remove_cvref_t< AQDataType_ > AQDataType
Definition: gemm_quant_pipeline_problem.hpp:48
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentAQ()
Definition: gemm_quant_pipeline_problem.hpp:100
std::conditional_t<!std::is_void_v< AQuantGroupSize_ >, AQuantGroupSize_, BQuantGroupSize_ > AQuantGroupSize
Definition: gemm_quant_pipeline_problem.hpp:53
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:79
typename Base::Traits Traits
Definition: gemm_quant_pipeline_problem.hpp:42
static constexpr auto TailNum
Definition: gemm_quant_pipeline_problem.hpp:79
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:84
std::conditional_t<!std::is_void_v< BQuantGroupSize_ >, BQuantGroupSize_, AQuantGroupSize_ > BQuantGroupSize
Definition: gemm_quant_pipeline_problem.hpp:55
Definition: gemm_group_quant_utils.hpp:448