/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File
gemm_pipeline_problem.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
10 
11 namespace ck_tile {
12 
13 template <typename AsDataType_,
14  typename BsDataType_,
15  typename EDataType_,
16  typename BlockGemmShape_,
17  typename Traits_,
18  typename ComputeDataType_ = AsDataType_,
19  typename AElementWise_ = ck_tile::element_wise::PassThrough,
20  typename BElementWise_ = ck_tile::element_wise::PassThrough,
21  bool FixedVectorSize_ = false,
22  index_t VectorSizeA_ = 1,
23  index_t VectorSizeB_ = 1>
25 {
27 
30  using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
31 
32  static constexpr bool FixedVectorSize = FixedVectorSize_;
33 
35 
38 
42 
46 
49 
53  using AsLayoutTuple = std::
54  conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
55  using BsLayoutTuple = std::
56  conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
57 
61 
65 
71 
// Compile-time configuration of this problem struct. Most values are
// forwarded unchanged from the Traits policy type.
72  static constexpr bool TransposeC = Traits::TransposeC;
73  static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
74  static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
75 
// Warp count of the block shape multiplied by the warp size.
76  static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
77 
// Padding flags. The VectorSizeA/B/C members of this struct fall back to a
// vector size of 1 when the flag for the relevant dimension is set.
78  static constexpr bool kPadM = Traits::kPadM;
79  static constexpr bool kPadN = Traits::kPadN;
80  static constexpr bool kPadK = Traits::kPadK;
81 
82  static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
// The scheduler is pinned to Default here rather than taken from a parameter.
83  static constexpr auto Scheduler = GemmPipelineScheduler::Default;
// The GetAlignment* helpers divide this by sizeof(element type), i.e. they
// treat it as a size in bytes.
84  static constexpr index_t VectorLoadSize = Traits::_VectorSize;
85 
86  // In the base situation, the Preshuffle setting should be false.
87  static constexpr bool Preshuffle = false;
88 
// Builds a host-side name string for this problem configuration from the
// literal "gemm_problem", the three padding flags and the scheduler.
// concat() presumably joins its remaining arguments using the leading
// character as separator — TODO confirm against concat.hpp.
// NOTE(review): the return type is a by-value `const std::string`; the
// top-level const prevents callers from moving out of the result.
89  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
90  {
91  // clang-format off
92  return concat('_', "gemm_problem",
// NOTE(review): this listing skips source line 93, so one concat argument
// is not visible here.
94  concat('x', kPadM, kPadN, kPadK),
95  Scheduler);
96  // clang-format on
97  }
98 
// Computes the element count used as the alignment / vector width for the
// A operand, depending on ALayout.
// Because sizeof yields std::size_t, the arithmetic below and the deduced
// return type are unsigned.
99  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
100  {
// NOTE(review): the initializer of PackedSize is on source line 102, which
// is missing from this listing.
101  constexpr index_t PackedSize =
103  if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
104  {
// Elements of the kM x kK tile per thread of the block.
105  constexpr index_t pixels_per_thread =
106  BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
// Returns the smaller of pixels_per_thread and
// PackedSize * VectorLoadSize / sizeof(ADataType).
107  return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
108  ? pixels_per_thread
109  : PackedSize * VectorLoadSize / sizeof(ADataType);
110  }
111  else
112  {
// NOTE(review): unlike the corresponding branch of GetAlignmentB() in this
// struct, this expression does not multiply by PackedSize — confirm that
// the asymmetry is intentional.
113  return VectorLoadSize / sizeof(ADataType);
114  }
115  }
116 
// Computes the element count used as the alignment / vector width for the
// B operand, depending on BLayout.
// Because sizeof yields std::size_t, the arithmetic below and the deduced
// return type are unsigned.
117  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
118  {
// NOTE(review): the initializer of PackedSize is on source line 120, which
// is missing from this listing.
119  constexpr index_t PackedSize =
121  if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
122  {
// Elements of the kN x kK tile per thread of the block.
123  constexpr index_t pixels_per_thread =
124  BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
// Returns the smaller of pixels_per_thread and
// PackedSize * VectorLoadSize / sizeof(BDataType).
125  return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
126  ? pixels_per_thread
127  : PackedSize * VectorLoadSize / sizeof(BDataType);
128  }
129  else
130  {
// Any other BLayout: the full packed vector width, with no per-thread cap.
131  return PackedSize * VectorLoadSize / sizeof(BDataType);
132  }
133  }
134 
// Computes the element count used as the alignment / vector width for the
// C operand. The result is an index_t: the smaller of a per-tile quotient
// and VectorLoadSize / sizeof(CDataType). The two branches are mirror
// images of each other with the roles of M and N swapped.
135  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
136  {
137  if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
138  {
// kBlockSize is defined in this struct as NumWarps * get_warp_size(), so
// N1 is the warp count.
139  constexpr index_t N1 = kBlockSize / get_warp_size();
// kN split across N1, capped at the warp size.
140  constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
141  constexpr index_t M0 = get_warp_size() / N2;
142  constexpr index_t M1 = BlockGemmShape::kM / M0;
143 
// The cast brings the size_t quotient back to index_t so that both
// std::min arguments have the same type.
144  return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
145  }
146  else
147  {
// Same computation with M and N exchanged; M1 is the warp count here.
148  constexpr index_t M1 = kBlockSize / get_warp_size();
149  constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
150  constexpr index_t N0 = get_warp_size() / M2;
151  constexpr index_t N1 = BlockGemmShape::kN / N0;
152 
153  return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
154  }
155  }
156 
// Vector size for the A operand, computed once at compile time by an
// immediately-invoked lambda.
//  - FixedVectorSize set: the VectorSizeA_ template argument is used as is.
//  - otherwise: 1 if the selected padding flag is set, else GetAlignmentA().
// NOTE(review): the layout test uses AsLayout, which the AsLayoutTuple alias
// in this struct allows to be a tuple, whereas GetAlignmentA() tests ALayout.
// If AsLayout is a tuple the RowMajor comparison cannot match and the kPadM
// branch is taken — confirm this is intended.
157  static constexpr index_t VectorSizeA = []() {
158  if constexpr(FixedVectorSize)
159  {
160  return VectorSizeA_;
161  }
162  else if constexpr(std::is_same_v<AsLayout, tensor_layout::gemm::RowMajor>)
163  {
// RowMajor A: the kPadK flag decides.
164  return kPadK ? 1 : GetAlignmentA();
165  }
166  else
167  {
// Any other layout: the kPadM flag decides.
168  return kPadM ? 1 : GetAlignmentA();
169  }
170  }();
171 
// Vector size for the B operand, computed once at compile time by an
// immediately-invoked lambda.
//  - FixedVectorSize set: the VectorSizeB_ template argument is used as is.
//  - otherwise: 1 if the selected padding flag is set, else GetAlignmentB().
// NOTE(review): the layout test uses BsLayout, which the BsLayoutTuple alias
// in this struct allows to be a tuple, whereas GetAlignmentB() tests BLayout.
// If BsLayout is a tuple the ColumnMajor comparison cannot match and the
// kPadK branch is taken — confirm this is intended.
172  static constexpr index_t VectorSizeB = []() {
173  if constexpr(FixedVectorSize)
174  {
175  return VectorSizeB_;
176  }
177  else if constexpr(std::is_same_v<BsLayout, tensor_layout::gemm::ColumnMajor>)
178  {
// ColumnMajor B: the kPadN flag decides.
179  return kPadN ? 1 : GetAlignmentB();
180  }
181  else
182  {
// Any other layout: the kPadK flag decides.
183  return kPadK ? 1 : GetAlignmentB();
184  }
185  }();
// Vector size for the C operand, computed once at compile time by an
// immediately-invoked lambda: 1 if the selected padding flag is set,
// otherwise GetAlignmentC(). Unlike VectorSizeA/B in this struct there is
// no FixedVectorSize override.
186  static constexpr index_t VectorSizeC = []() {
187  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
188  {
// RowMajor C: the kPadN flag decides.
189  return kPadN ? 1 : GetAlignmentC();
190  }
191  else
192  {
// Any other layout: the kPadM flag decides.
193  return kPadM ? 1 : GetAlignmentC();
194  }
195  }();
196 };
197 
198 template <typename AsDataType_,
199  typename BsDataType_,
200  typename EDataType_,
201  typename BlockGemmShape_,
202  typename Traits_,
203  typename AElementWise_ = ck_tile::element_wise::PassThrough,
204  typename BElementWise_ = ck_tile::element_wise::PassThrough,
205  typename ComputeDataType_ = AsDataType_,
206  bool FixedVectorSize_ = false,
207  index_t VectorSizeA_ = 1,
208  index_t VectorSizeB_ = 1>
210  BsDataType_,
211  EDataType_,
212  BlockGemmShape_,
213  Traits_,
214  ComputeDataType_,
215  AElementWise_,
216  BElementWise_,
217  FixedVectorSize_,
218  VectorSizeA_,
219  VectorSizeB_>;
220 
221 template <typename AsDataType_,
222  typename BsDataType_,
223  typename EDataType_,
224  typename BlockGemmShape_,
225  typename Traits_,
227  typename AElementWise_ = ck_tile::element_wise::PassThrough,
228  typename BElementWise_ = ck_tile::element_wise::PassThrough,
229  typename ComputeDataType_ = AsDataType_,
230  bool FixedVectorSize_ = false,
231  index_t VectorSizeA_ = 1,
232  index_t VectorSizeB_ = 1>
234 {
236 
239  using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
242 
243  static constexpr bool FixedVectorSize = FixedVectorSize_;
244 
246 
250 
254 
257 
261  using AsLayoutTuple = std::
262  conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
263  using BsLayoutTuple = std::
264  conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
265 
269 
273 
279 
// Compile-time configuration of this problem struct. Most values are
// forwarded unchanged from the Traits policy type.
280  static constexpr bool TransposeC = Traits::TransposeC;
281  static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
282  static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
283 
// Warp count of the block shape multiplied by the warp size.
284  static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
285 
286  static constexpr bool kPadM = Traits::kPadM;
287  static constexpr bool kPadN = Traits::kPadN;
288  static constexpr bool kPadK = Traits::kPadK;
289 
290  static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
// NOTE(review): Scheduler_ does not appear in the template parameter list
// visible in this listing (source line 226 is missing); presumably it is
// declared there.
291  static constexpr auto Scheduler = Scheduler_;
292  static constexpr bool Preshuffle = Traits::Preshuffle;
293 
// The vector sizes are taken verbatim from the template arguments; no
// alignment-based computation is done in these declarations.
294  static constexpr index_t VectorSizeA = VectorSizeA_;
295  static constexpr index_t VectorSizeB = VectorSizeB_;
296 
297  static constexpr index_t VectorLoadSize = Traits::_VectorSize;
// Builds a host-side name string for this problem configuration from the
// literal "gemm_problem", the block size, the three padding flags, the
// scheduler and two labelled fields.
// concat() presumably joins its remaining arguments using the leading
// character as separator — TODO confirm against concat.hpp.
// NOTE(review): the return type is a by-value `const std::string`; the
// top-level const prevents callers from moving out of the result.
298  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
299  {
300  // clang-format off
301  return concat('_', "gemm_problem",
302  concat('x', kBlockSize),
303  concat('x', kPadM, kPadN, kPadK),
304  Scheduler,
// NOTE(review): source lines 306 and 308 are missing from this listing;
// presumably each carries the value that follows the label before it.
305  "NumWaveGroups",
307  "DoubleSmemBuffer",
309  );
310  // clang-format on
311  }
312 };
313 
314 template <typename ADataType_,
315  typename BDataType_,
316  typename CDataType_,
317  typename BlockGemmShape_,
318  typename Traits_,
320  bool HasHotLoop_ = true,
321  TailNumber TailNum_ = TailNumber::Full,
322  typename ComputeDataType_ = ADataType_>
324 {
326 
331 
333 
337 
// Compile-time configuration of this problem struct. Most values are
// forwarded unchanged from the Traits policy type.
338  static constexpr bool TransposeC = Traits::TransposeC;
339  static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
340  static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
341 
// Warp count of the block shape multiplied by the warp size.
342  static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
343 
// Padding flags. The VectorSizeA/B/C members of this struct fall back to a
// vector size of 1 when the flag for the relevant dimension is set.
344  static constexpr bool kPadM = Traits::kPadM;
345  static constexpr bool kPadN = Traits::kPadN;
346  static constexpr bool kPadK = Traits::kPadK;
347 
348  static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
349 
// The scheduler is pinned to Default here rather than taken from a parameter.
350  static constexpr auto Scheduler = GemmPipelineScheduler::Default;
// The GetAlignment* helpers divide this by sizeof(element type), i.e. they
// treat it as a size in bytes.
351  static constexpr index_t VectorLoadSize = Traits::_VectorSize;
352 
// Re-exported template arguments (defaults: true and TailNumber::Full).
353  static constexpr auto HasHotLoop = HasHotLoop_;
354  static constexpr auto TailNum = TailNum_;
355 
// Builds a host-side name string for this problem configuration from the
// literal "gemm_problem", the three padding flags and the scheduler.
// concat() presumably joins its remaining arguments using the leading
// character as separator — TODO confirm against concat.hpp.
// NOTE(review): the return type is a by-value `const std::string`; the
// top-level const prevents callers from moving out of the result.
356  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
357  {
358  // clang-format off
359  return concat('_', "gemm_problem",
// NOTE(review): this listing skips source line 360, so one concat argument
// is not visible here.
361  concat('x', kPadM, kPadN, kPadK),
362  Scheduler);
363  // clang-format on
364  }
365 
// Computes the element count used as the alignment / vector width for the
// A operand, depending on ALayout.
// Because sizeof yields std::size_t, the arithmetic below and the deduced
// return type are unsigned.
366  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
367  {
// NOTE(review): the initializer of PackedSize is on source line 369, which
// is missing from this listing.
368  constexpr index_t PackedSize =
370  if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
371  {
// Elements of the kM x kK tile per thread of the block.
372  constexpr index_t pixels_per_thread =
373  BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
// Returns the smaller of pixels_per_thread and
// PackedSize * VectorLoadSize / sizeof(ADataType).
374  return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
375  ? pixels_per_thread
376  : PackedSize * VectorLoadSize / sizeof(ADataType);
377  }
378  else
379  {
// NOTE(review): unlike the corresponding branch of GetAlignmentB() in this
// struct, this expression does not multiply by PackedSize — confirm that
// the asymmetry is intentional.
380  return VectorLoadSize / sizeof(ADataType);
381  }
382  }
383 
// Computes the element count used as the alignment / vector width for the
// B operand, depending on BLayout.
// Because sizeof yields std::size_t, the arithmetic below and the deduced
// return type are unsigned.
384  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
385  {
// NOTE(review): the initializer of PackedSize is on source line 387, which
// is missing from this listing.
386  constexpr index_t PackedSize =
388  if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
389  {
// Elements of the kN x kK tile per thread of the block.
390  constexpr index_t pixels_per_thread =
391  BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
// Returns the smaller of pixels_per_thread and
// PackedSize * VectorLoadSize / sizeof(BDataType).
392  return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
393  ? pixels_per_thread
394  : PackedSize * VectorLoadSize / sizeof(BDataType);
395  }
396  else
397  {
// Any other BLayout: the full packed vector width, with no per-thread cap.
398  return PackedSize * VectorLoadSize / sizeof(BDataType);
399  }
400  }
401 
// Computes the element count used as the alignment / vector width for the
// C operand. The result is an index_t: the smaller of a per-tile quotient
// and VectorLoadSize / sizeof(CDataType). The two branches are mirror
// images of each other with the roles of M and N swapped.
402  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
403  {
404  if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
405  {
// kBlockSize is defined in this struct as NumWarps * get_warp_size(), so
// N1 is the warp count.
406  constexpr index_t N1 = kBlockSize / get_warp_size();
// kN split across N1, capped at the warp size.
407  constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
408  constexpr index_t M0 = get_warp_size() / N2;
409  constexpr index_t M1 = BlockGemmShape::kM / M0;
410 
// The cast brings the size_t quotient back to index_t so that both
// std::min arguments have the same type.
411  return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
412  }
413  else
414  {
// Same computation with M and N exchanged; M1 is the warp count here.
415  constexpr index_t M1 = kBlockSize / get_warp_size();
416  constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
417  constexpr index_t N0 = get_warp_size() / M2;
418  constexpr index_t N1 = BlockGemmShape::kN / N0;
419 
420  return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
421  }
422  }
423 
// Vector size for the A operand, computed once at compile time by an
// immediately-invoked lambda: 1 if the selected padding flag is set,
// otherwise GetAlignmentA().
424  static constexpr index_t VectorSizeA = []() {
425  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
426  {
// RowMajor A: the kPadK flag decides.
427  return kPadK ? 1 : GetAlignmentA();
428  }
429  else
430  {
// Any other layout: the kPadM flag decides.
431  return kPadM ? 1 : GetAlignmentA();
432  }
433  }();
434 
// Vector size for the B operand, computed once at compile time by an
// immediately-invoked lambda: 1 if the selected padding flag is set,
// otherwise GetAlignmentB().
435  static constexpr index_t VectorSizeB = []() {
436  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
437  {
// ColumnMajor B: the kPadN flag decides.
438  return kPadN ? 1 : GetAlignmentB();
439  }
440  else
441  {
// Any other layout: the kPadK flag decides.
442  return kPadK ? 1 : GetAlignmentB();
443  }
444  }();
// Vector size for the C operand, computed once at compile time by an
// immediately-invoked lambda: 1 if the selected padding flag is set,
// otherwise GetAlignmentC().
445  static constexpr index_t VectorSizeC = []() {
446  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
447  {
// RowMajor C: the kPadN flag decides.
448  return kPadN ? 1 : GetAlignmentC();
449  }
450  else
451  {
// Any other layout: the kPadM flag decides.
452  return kPadM ? 1 : GetAlignmentC();
453  }
454  }();
455 };
456 
457 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
TailNumber
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:21
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
ck_tile::element_wise::PassThrough PassThrough
Definition: grouped_convolution_utils.hpp:54
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
GemmPipelineScheduler
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:14
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: gemm_pipeline_problem.hpp:324
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:345
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentC()
Definition: gemm_pipeline_problem.hpp:402
static constexpr index_t VectorSizeA
Definition: gemm_pipeline_problem.hpp:424
static constexpr bool UseStructuredSparsity
Definition: gemm_pipeline_problem.hpp:340
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentA()
Definition: gemm_pipeline_problem.hpp:366
static constexpr auto Scheduler
Definition: gemm_pipeline_problem.hpp:350
static CK_TILE_HOST const std::string GetName()
Definition: gemm_pipeline_problem.hpp:356
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:325
static constexpr bool DoubleSmemBuffer
Definition: gemm_pipeline_problem.hpp:348
static constexpr auto HasHotLoop
Definition: gemm_pipeline_problem.hpp:353
static constexpr index_t NumWaveGroups
Definition: gemm_pipeline_problem.hpp:339
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:332
remove_cvref_t< typename Traits::AsLayout > ALayout
Definition: gemm_pipeline_problem.hpp:334
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:351
remove_cvref_t< ComputeDataType_ > ComputeDataType
Definition: gemm_pipeline_problem.hpp:330
static constexpr index_t VectorSizeB
Definition: gemm_pipeline_problem.hpp:435
static constexpr auto TailNum
Definition: gemm_pipeline_problem.hpp:354
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:336
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:344
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentB()
Definition: gemm_pipeline_problem.hpp:384
static constexpr index_t VectorSizeC
Definition: gemm_pipeline_problem.hpp:445
static constexpr bool TransposeC
Definition: gemm_pipeline_problem.hpp:338
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:346
remove_cvref_t< CDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:329
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:342
remove_cvref_t< BDataType_ > BDataType
Definition: gemm_pipeline_problem.hpp:328
remove_cvref_t< ADataType_ > ADataType
Definition: gemm_pipeline_problem.hpp:327
remove_cvref_t< typename Traits::BsLayout > BLayout
Definition: gemm_pipeline_problem.hpp:335
Definition: gemm_pipeline_problem.hpp:25
remove_cvref_t< BElementWise_ > BElementWise
Definition: gemm_pipeline_problem.hpp:37
static constexpr index_t NumWaveGroups
Definition: gemm_pipeline_problem.hpp:73
static constexpr bool UseStructuredSparsity
Definition: gemm_pipeline_problem.hpp:74
static constexpr bool FixedVectorSize
Definition: gemm_pipeline_problem.hpp:32
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:34
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentB()
Definition: gemm_pipeline_problem.hpp:117
static constexpr bool ComputeDataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:43
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:80
static constexpr index_t VectorSizeC
Definition: gemm_pipeline_problem.hpp:186
remove_cvref_t< BsDataType_ > BsDataType
Definition: gemm_pipeline_problem.hpp:29
static constexpr bool Preshuffle
Definition: gemm_pipeline_problem.hpp:87
static constexpr auto Scheduler
Definition: gemm_pipeline_problem.hpp:83
std::conditional_t< BLayoutIsTuple, remove_cvref_t< BsLayout >, remove_cvref_t< tuple< BsLayout > >> BsLayoutTuple
Definition: gemm_pipeline_problem.hpp:56
remove_cvref_t< std::tuple_element_t< number< 0 >{}, ComputeDataTypeTuple > > ComputeDataType
Definition: gemm_pipeline_problem.hpp:66
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsLayoutTuple > > ALayout
Definition: gemm_pipeline_problem.hpp:68
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition: gemm_pipeline_problem.hpp:69
remove_cvref_t< AElementWise_ > AElementWise
Definition: gemm_pipeline_problem.hpp:36
static constexpr bool ALayoutIsTuple
Definition: gemm_pipeline_problem.hpp:47
std::conditional_t< ALayoutIsTuple, remove_cvref_t< AsLayout >, remove_cvref_t< tuple< AsLayout > >> AsLayoutTuple
Definition: gemm_pipeline_problem.hpp:54
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:76
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:78
static constexpr index_t VectorSizeB
Definition: gemm_pipeline_problem.hpp:172
remove_cvref_t< typename Traits::AsLayout > AsLayout
Definition: gemm_pipeline_problem.hpp:39
static constexpr bool TransposeC
Definition: gemm_pipeline_problem.hpp:72
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition: gemm_pipeline_problem.hpp:67
static constexpr bool DoubleSmemBuffer
Definition: gemm_pipeline_problem.hpp:82
static constexpr bool BDataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:45
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< AsDataType >, remove_cvref_t< tuple< AsDataType > >> AsDataTypeTuple
Definition: gemm_pipeline_problem.hpp:60
static constexpr bool BLayoutIsTuple
Definition: gemm_pipeline_problem.hpp:48
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< BsDataType >, remove_cvref_t< tuple< BsDataType > >> BsDataTypeTuple
Definition: gemm_pipeline_problem.hpp:64
remove_cvref_t< AsDataType_ > AsDataType
Definition: gemm_pipeline_problem.hpp:28
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:41
std::conditional_t< ComputeDataTypeIsTuple, remove_cvref_t< ComputeDataType_ >, remove_cvref_t< tuple< ComputeDataType_ > >> ComputeDataTypeTuple
Definition: gemm_pipeline_problem.hpp:52
static CK_TILE_HOST const std::string GetName()
Definition: gemm_pipeline_problem.hpp:89
remove_cvref_t< typename Traits::BsLayout > BsLayout
Definition: gemm_pipeline_problem.hpp:40
remove_cvref_t< EDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:30
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentA()
Definition: gemm_pipeline_problem.hpp:99
static constexpr index_t VectorSizeA
Definition: gemm_pipeline_problem.hpp:157
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:26
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:79
static constexpr bool ADataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:44
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:84
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsLayoutTuple > > BLayout
Definition: gemm_pipeline_problem.hpp:70
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentC()
Definition: gemm_pipeline_problem.hpp:135
Definition: gemm_pipeline_problem.hpp:234
static constexpr bool FixedVectorSize
Definition: gemm_pipeline_problem.hpp:243
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:288
static CK_TILE_HOST const std::string GetName()
Definition: gemm_pipeline_problem.hpp:298
remove_cvref_t< AsDataType_ > AsDataType
Definition: gemm_pipeline_problem.hpp:237
static constexpr bool DoubleSmemBuffer
Definition: gemm_pipeline_problem.hpp:290
std::conditional_t< ComputeDataTypeIsTuple, remove_cvref_t< ComputeDataType_ >, remove_cvref_t< tuple< ComputeDataType_ > >> ComputeDataTypeTuple
Definition: gemm_pipeline_problem.hpp:260
remove_cvref_t< typename Traits::AsLayout > AsLayout
Definition: gemm_pipeline_problem.hpp:247
static constexpr bool TransposeC
Definition: gemm_pipeline_problem.hpp:280
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:284
std::conditional_t< BLayoutIsTuple, remove_cvref_t< BsLayout >, remove_cvref_t< tuple< BsLayout > >> BsLayoutTuple
Definition: gemm_pipeline_problem.hpp:264
remove_cvref_t< AElementWise_ > AElementWise
Definition: gemm_pipeline_problem.hpp:240
static constexpr bool ADataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:252
static constexpr index_t VectorSizeA
Definition: gemm_pipeline_problem.hpp:294
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:235
std::conditional_t< ALayoutIsTuple, remove_cvref_t< AsLayout >, remove_cvref_t< tuple< AsLayout > >> AsLayoutTuple
Definition: gemm_pipeline_problem.hpp:262
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< BsDataType >, remove_cvref_t< tuple< BsDataType > >> BsDataTypeTuple
Definition: gemm_pipeline_problem.hpp:272
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:286
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:249
static constexpr bool BLayoutIsTuple
Definition: gemm_pipeline_problem.hpp:256
static constexpr auto Scheduler
Definition: gemm_pipeline_problem.hpp:291
remove_cvref_t< typename Traits::BsLayout > BsLayout
Definition: gemm_pipeline_problem.hpp:248
static constexpr bool Preshuffle
Definition: gemm_pipeline_problem.hpp:292
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< AsDataType >, remove_cvref_t< tuple< AsDataType > >> AsDataTypeTuple
Definition: gemm_pipeline_problem.hpp:268
static constexpr index_t NumWaveGroups
Definition: gemm_pipeline_problem.hpp:281
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:287
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:245
static constexpr bool UseStructuredSparsity
Definition: gemm_pipeline_problem.hpp:282
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:297
remove_cvref_t< EDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:239
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsLayoutTuple > > BLayout
Definition: gemm_pipeline_problem.hpp:278
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsLayoutTuple > > ALayout
Definition: gemm_pipeline_problem.hpp:276
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition: gemm_pipeline_problem.hpp:277
static constexpr bool ComputeDataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:251
remove_cvref_t< BElementWise_ > BElementWise
Definition: gemm_pipeline_problem.hpp:241
static constexpr bool ALayoutIsTuple
Definition: gemm_pipeline_problem.hpp:255
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition: gemm_pipeline_problem.hpp:275
static constexpr index_t VectorSizeB
Definition: gemm_pipeline_problem.hpp:295
remove_cvref_t< BsDataType_ > BsDataType
Definition: gemm_pipeline_problem.hpp:238
static constexpr bool BDataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:253
remove_cvref_t< std::tuple_element_t< number< 0 >{}, ComputeDataTypeTuple > > ComputeDataType
Definition: gemm_pipeline_problem.hpp:274
Definition: numeric.hpp:81