/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File
gemm_pipeline_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
10 
11 namespace ck_tile {
12 
// Base problem descriptor for GEMM pipelines: bundles the A/B/E data types, block tile
// shape, traits and A/B element-wise ops, and derives block size, padding flags and
// per-operand vector sizes (VectorSizeA/B/C) from them.
// NOTE(review): this is a Doxygen listing. Lines whose text was hyperlinked (e.g. 24, the
// struct-name line, and most `using` aliases) were dropped by the extraction; the symbol
// index at the end of the page lists those members.
13 template <typename AsDataType_,
14  typename BsDataType_,
15  typename EDataType_,
16  typename BlockGemmShape_,
17  typename Traits_,
18  typename ComputeDataType_ = AsDataType_,
19  typename AElementWise_ = ck_tile::element_wise::PassThrough,
20  typename BElementWise_ = ck_tile::element_wise::PassThrough,
21  bool FixedVectorSize_ = false,
22  index_t VectorSizeA_ = 1,
23  index_t VectorSizeB_ = 1>
25 {
27 
30  using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
31 
// When true, VectorSizeA / VectorSizeB are taken verbatim from VectorSizeA_ / VectorSizeB_
// instead of being derived from layout, padding and alignment.
32  static constexpr bool FixedVectorSize = FixedVectorSize_;
33 
35 
38 
42 
46 
49 
// A and B layouts may be given as a single layout or as a tuple of layouts; a single
// layout is wrapped in a one-element tuple so element 0 (ALayout / BLayout) can be
// extracted uniformly.
53  using AsLayoutTuple = std::
54  conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
55  using BsLayoutTuple = std::
56  conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
57 
61 
65 
71 
72  static constexpr bool TransposeC = Traits::TransposeC;
73  static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
74  static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
75 
// Threads per workgroup: number of warps in the block tile times the warp size.
76  static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
77 
// Per-dimension padding flags. A set flag makes the matching VectorSizeA/B/C below fall
// back to a width of 1.
78  static constexpr bool kPadM = Traits::kPadM;
79  static constexpr bool kPadN = Traits::kPadN;
80  static constexpr bool kPadK = Traits::kPadK;
81 
82  static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
// The base problem always uses the default scheduler.
83  static constexpr auto Scheduler = GemmPipelineScheduler::Default;
// Divided by sizeof(element) in GetAlignment*, so presumably a load width in bytes -- TODO confirm.
84  static constexpr index_t VectorLoadSize = Traits::_VectorSize;
85 
86  // In the base situation, the Preshuffle setting should be false.
87  static constexpr bool Preshuffle = false;
88 
// Host-side identifier for this problem: "gemm_problem", the M/N/K padding flags joined
// with 'x', and the scheduler, all joined with '_'.
// NOTE(review): listing line 93 (an argument of the outer concat call) is missing from
// this extraction.
89  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
90  {
91  // clang-format off
92  return concat('_', "gemm_problem",
94  concat('x', kPadM, kPadN, kPadK),
95  Scheduler);
96  // clang-format on
97  }
98 
// Maximum vector width for loading A, derived from VectorLoadSize / sizeof(ADataType).
// The column-major branch also scales by PackedSize and caps the result at the per-thread
// share of the kM x kK tile. The deduced return type comes from the sizeof arithmetic
// (std::size_t).
99  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
100  {
// NOTE(review): the initializer of PackedSize (listing line 102) is missing from this extraction.
101  constexpr index_t PackedSize =
103  if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
104  {
// kM * kK / kBlockSize: A-tile elements per thread if the tile is spread evenly across
// the workgroup.
105  constexpr index_t pixels_per_thread =
106  BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
107  return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
108  ? pixels_per_thread
109  : PackedSize * VectorLoadSize / sizeof(ADataType);
110  }
111  else
112  {
// NOTE(review): unlike the column-major branch and both branches of GetAlignmentB, this
// branch does not multiply by PackedSize. Packed A types (PackedSize > 1) therefore get
// a narrower width here; confirm whether that is intentional.
113  return VectorLoadSize / sizeof(ADataType);
114  }
115  }
116 
// Maximum vector width for loading B: PackedSize * VectorLoadSize / sizeof(BDataType).
// For row-major B it is capped at the per-thread share of the kN x kK tile. The deduced
// return type comes from the sizeof arithmetic (std::size_t).
117  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
118  {
// NOTE(review): the initializer of PackedSize (listing line 120) is missing from this extraction.
119  constexpr index_t PackedSize =
121  if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
122  {
// kN * kK / kBlockSize: B-tile elements per thread if the tile is spread evenly across
// the workgroup.
123  constexpr index_t pixels_per_thread =
124  BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
125  return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
126  ? pixels_per_thread
127  : PackedSize * VectorLoadSize / sizeof(BDataType);
128  }
129  else
130  {
131  return PackedSize * VectorLoadSize / sizeof(BDataType);
132  }
133  }
134 
135  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
136  {
137  if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
138  {
139  constexpr index_t N1 = kBlockSize / get_warp_size();
140  constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
141  constexpr index_t M0 = get_warp_size() / N2;
142  constexpr index_t M1 = BlockGemmShape::kM / M0;
143 
144  return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
145  }
146  else
147  {
148  constexpr index_t M1 = kBlockSize / get_warp_size();
149  constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
150  constexpr index_t N0 = get_warp_size() / M2;
151  constexpr index_t N1 = BlockGemmShape::kN / N0;
152 
153  return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
154  }
155  }
156 
157  static constexpr index_t VectorSizeA = []() {
158  if constexpr(FixedVectorSize)
159  {
160  return VectorSizeA_;
161  }
162  else if constexpr(std::is_same_v<AsLayout, tensor_layout::gemm::RowMajor>)
163  {
164  return kPadK ? 1 : GetAlignmentA();
165  }
166  else
167  {
168  return kPadM ? 1 : GetAlignmentA();
169  }
170  }();
171 
172  static constexpr index_t VectorSizeB = []() {
173  if constexpr(FixedVectorSize)
174  {
175  return VectorSizeB_;
176  }
177  else if constexpr(std::is_same_v<BsLayout, tensor_layout::gemm::ColumnMajor>)
178  {
179  return kPadN ? 1 : GetAlignmentB();
180  }
181  else
182  {
183  return kPadK ? 1 : GetAlignmentB();
184  }
185  }();
186  static constexpr index_t VectorSizeC = []() {
187  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
188  {
189  return kPadN ? 1 : GetAlignmentC();
190  }
191  else
192  {
193  return kPadM ? 1 : GetAlignmentC();
194  }
195  }();
196 };
197 
// Appears to be a convenience alias over the struct template above; the forwarded
// argument list matches that template's parameter order.
// NOTE(review): this alias declares AElementWise_ / BElementWise_ BEFORE ComputeDataType_.
// The forwarded list below (and the template above) puts ComputeDataType_ first, so
// positional arguments mean different things for the two; check call sites.
198 template <typename AsDataType_,
199  typename BsDataType_,
200  typename EDataType_,
201  typename BlockGemmShape_,
202  typename Traits_,
203  typename AElementWise_ = ck_tile::element_wise::PassThrough,
204  typename BElementWise_ = ck_tile::element_wise::PassThrough,
205  typename ComputeDataType_ = AsDataType_,
206  bool FixedVectorSize_ = false,
207  index_t VectorSizeA_ = 1,
208  index_t VectorSizeB_ = 1>
// NOTE(review): listing line 209 (the `using ... =` head of this alias with its first
// template argument) is missing from this extraction.
210  BsDataType_,
211  EDataType_,
212  BlockGemmShape_,
213  Traits_,
214  ComputeDataType_,
215  AElementWise_,
216  BElementWise_,
217  FixedVectorSize_,
218  VectorSizeA_,
219  VectorSizeB_>;
220 
// Problem descriptor variant that additionally takes the scheduler (Scheduler_), a
// hot-loop flag (HasHotLoop_) and a tail number (TailNum_) as template parameters, and
// takes its vector sizes verbatim from VectorSizeA_ / VectorSizeB_.
// NOTE(review): listing lines 226 (presumably the declaration of Scheduler_, used below)
// and 235 (the struct-name line) are missing from this extraction.
221 template <typename AsDataType_,
222  typename BsDataType_,
223  typename EDataType_,
224  typename BlockGemmShape_,
225  typename Traits_,
227  bool HasHotLoop_ = true,
228  TailNumber TailNum_ = TailNumber::Full,
229  typename AElementWise_ = ck_tile::element_wise::PassThrough,
230  typename BElementWise_ = ck_tile::element_wise::PassThrough,
231  typename ComputeDataType_ = AsDataType_,
232  bool FixedVectorSize_ = false,
233  index_t VectorSizeA_ = 1,
234  index_t VectorSizeB_ = 1>
236 {
238 
241  using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
244 
245  static constexpr bool FixedVectorSize = FixedVectorSize_;
246 
248 
252 
256 
259 
// A single layout is wrapped in a one-element tuple so element 0 can be extracted uniformly.
263  using AsLayoutTuple = std::
264  conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
265  using BsLayoutTuple = std::
266  conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
267 
271 
275 
281 
282  static constexpr bool TransposeC = Traits::TransposeC;
283  static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
284  static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
285 
286  static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
287 
288  static constexpr bool kPadM = Traits::kPadM;
289  static constexpr bool kPadN = Traits::kPadN;
290  static constexpr bool kPadK = Traits::kPadK;
291 
292  static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
293  static constexpr auto Scheduler = Scheduler_;
294  static constexpr bool Preshuffle = Traits::Preshuffle;
295 
// Vector sizes come straight from the template arguments (default 1, i.e. scalar loads).
// Unlike the base problem, they are not derived from layout, padding or alignment, and
// FixedVectorSize is not consulted here.
296  static constexpr index_t VectorSizeA = VectorSizeA_;
297  static constexpr index_t VectorSizeB = VectorSizeB_;
298 
299  static constexpr auto HasHotLoop = HasHotLoop_;
300  static constexpr auto TailNum = TailNum_;
301  static constexpr index_t VectorLoadSize = Traits::_VectorSize;
302  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
303  {
304  // clang-format off
305  return concat('_', "gemm_problem",
306  concat('x', kBlockSize),
307  concat('x', kPadM, kPadN, kPadK),
308  Scheduler);
309  // clang-format on
310  }
311 };
312 
313 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
TailNumber
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:21
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
GemmPipelineScheduler
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:14
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: gemm_pipeline_problem.hpp:25
remove_cvref_t< BElementWise_ > BElementWise
Definition: gemm_pipeline_problem.hpp:37
static constexpr index_t NumWaveGroups
Definition: gemm_pipeline_problem.hpp:73
static constexpr bool UseStructuredSparsity
Definition: gemm_pipeline_problem.hpp:74
static constexpr bool FixedVectorSize
Definition: gemm_pipeline_problem.hpp:32
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:34
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentB()
Definition: gemm_pipeline_problem.hpp:117
static constexpr bool ComputeDataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:43
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:80
static constexpr index_t VectorSizeC
Definition: gemm_pipeline_problem.hpp:186
remove_cvref_t< BsDataType_ > BsDataType
Definition: gemm_pipeline_problem.hpp:29
static constexpr bool Preshuffle
Definition: gemm_pipeline_problem.hpp:87
static constexpr auto Scheduler
Definition: gemm_pipeline_problem.hpp:83
std::conditional_t< BLayoutIsTuple, remove_cvref_t< BsLayout >, remove_cvref_t< tuple< BsLayout > >> BsLayoutTuple
Definition: gemm_pipeline_problem.hpp:56
remove_cvref_t< std::tuple_element_t< number< 0 >{}, ComputeDataTypeTuple > > ComputeDataType
Definition: gemm_pipeline_problem.hpp:66
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsLayoutTuple > > ALayout
Definition: gemm_pipeline_problem.hpp:68
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition: gemm_pipeline_problem.hpp:69
remove_cvref_t< AElementWise_ > AElementWise
Definition: gemm_pipeline_problem.hpp:36
static constexpr bool ALayoutIsTuple
Definition: gemm_pipeline_problem.hpp:47
std::conditional_t< ALayoutIsTuple, remove_cvref_t< AsLayout >, remove_cvref_t< tuple< AsLayout > >> AsLayoutTuple
Definition: gemm_pipeline_problem.hpp:54
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:76
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:78
static constexpr index_t VectorSizeB
Definition: gemm_pipeline_problem.hpp:172
remove_cvref_t< typename Traits::AsLayout > AsLayout
Definition: gemm_pipeline_problem.hpp:39
static constexpr bool TransposeC
Definition: gemm_pipeline_problem.hpp:72
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition: gemm_pipeline_problem.hpp:67
static constexpr bool DoubleSmemBuffer
Definition: gemm_pipeline_problem.hpp:82
static constexpr bool BDataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:45
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< AsDataType >, remove_cvref_t< tuple< AsDataType > >> AsDataTypeTuple
Definition: gemm_pipeline_problem.hpp:60
static constexpr bool BLayoutIsTuple
Definition: gemm_pipeline_problem.hpp:48
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< BsDataType >, remove_cvref_t< tuple< BsDataType > >> BsDataTypeTuple
Definition: gemm_pipeline_problem.hpp:64
remove_cvref_t< AsDataType_ > AsDataType
Definition: gemm_pipeline_problem.hpp:28
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:41
std::conditional_t< ComputeDataTypeIsTuple, remove_cvref_t< ComputeDataType_ >, remove_cvref_t< tuple< ComputeDataType_ > >> ComputeDataTypeTuple
Definition: gemm_pipeline_problem.hpp:52
static CK_TILE_HOST const std::string GetName()
Definition: gemm_pipeline_problem.hpp:89
remove_cvref_t< typename Traits::BsLayout > BsLayout
Definition: gemm_pipeline_problem.hpp:40
remove_cvref_t< EDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:30
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentA()
Definition: gemm_pipeline_problem.hpp:99
static constexpr index_t VectorSizeA
Definition: gemm_pipeline_problem.hpp:157
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:26
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:79
static constexpr bool ADataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:44
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:84
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsLayoutTuple > > BLayout
Definition: gemm_pipeline_problem.hpp:70
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentC()
Definition: gemm_pipeline_problem.hpp:135
Definition: gemm_pipeline_problem.hpp:236
remove_cvref_t< AElementWise_ > AElementWise
Definition: gemm_pipeline_problem.hpp:242
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:247
remove_cvref_t< EDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:241
remove_cvref_t< typename Traits::BsLayout > BsLayout
Definition: gemm_pipeline_problem.hpp:250
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:289
static constexpr bool BLayoutIsTuple
Definition: gemm_pipeline_problem.hpp:258
static constexpr index_t VectorSizeB
Definition: gemm_pipeline_problem.hpp:297
static constexpr auto TailNum
Definition: gemm_pipeline_problem.hpp:300
remove_cvref_t< AsDataType_ > AsDataType
Definition: gemm_pipeline_problem.hpp:239
remove_cvref_t< BElementWise_ > BElementWise
Definition: gemm_pipeline_problem.hpp:243
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:286
static constexpr auto HasHotLoop
Definition: gemm_pipeline_problem.hpp:299
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< BsDataType >, remove_cvref_t< tuple< BsDataType > >> BsDataTypeTuple
Definition: gemm_pipeline_problem.hpp:274
static constexpr bool BDataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:255
std::conditional_t< BLayoutIsTuple, remove_cvref_t< BsLayout >, remove_cvref_t< tuple< BsLayout > >> BsLayoutTuple
Definition: gemm_pipeline_problem.hpp:266
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:301
remove_cvref_t< typename Traits::AsLayout > AsLayout
Definition: gemm_pipeline_problem.hpp:249
static constexpr bool ADataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:254
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:290
static constexpr index_t VectorSizeA
Definition: gemm_pipeline_problem.hpp:296
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:237
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition: gemm_pipeline_problem.hpp:277
static constexpr bool UseStructuredSparsity
Definition: gemm_pipeline_problem.hpp:284
remove_cvref_t< BsDataType_ > BsDataType
Definition: gemm_pipeline_problem.hpp:240
static CK_TILE_HOST const std::string GetName()
Definition: gemm_pipeline_problem.hpp:302
static constexpr bool ALayoutIsTuple
Definition: gemm_pipeline_problem.hpp:257
remove_cvref_t< std::tuple_element_t< number< 0 >{}, ComputeDataTypeTuple > > ComputeDataType
Definition: gemm_pipeline_problem.hpp:276
static constexpr bool TransposeC
Definition: gemm_pipeline_problem.hpp:282
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:288
static constexpr bool ComputeDataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:253
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:251
static constexpr index_t NumWaveGroups
Definition: gemm_pipeline_problem.hpp:283
static constexpr auto Scheduler
Definition: gemm_pipeline_problem.hpp:293
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition: gemm_pipeline_problem.hpp:279
static constexpr bool DoubleSmemBuffer
Definition: gemm_pipeline_problem.hpp:292
std::conditional_t< ComputeDataTypeIsTuple, remove_cvref_t< ComputeDataType_ >, remove_cvref_t< tuple< ComputeDataType_ > >> ComputeDataTypeTuple
Definition: gemm_pipeline_problem.hpp:262
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsLayoutTuple > > ALayout
Definition: gemm_pipeline_problem.hpp:278
std::conditional_t< ALayoutIsTuple, remove_cvref_t< AsLayout >, remove_cvref_t< tuple< AsLayout > >> AsLayoutTuple
Definition: gemm_pipeline_problem.hpp:264
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< AsDataType >, remove_cvref_t< tuple< AsDataType > >> AsDataTypeTuple
Definition: gemm_pipeline_problem.hpp:270
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsLayoutTuple > > BLayout
Definition: gemm_pipeline_problem.hpp:280
static constexpr bool Preshuffle
Definition: gemm_pipeline_problem.hpp:294
static constexpr bool FixedVectorSize
Definition: gemm_pipeline_problem.hpp:245
Definition: unary_element_wise_operation.hpp:431
Definition: numeric.hpp:81