/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File
gemm_pipeline_problem.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
10 
11 namespace ck_tile {
12 
13 template <typename AsDataType_,
14  typename BsDataType_,
15  typename EDataType_,
16  typename BlockGemmShape_,
17  typename Traits_,
18  typename ComputeDataType_ = AsDataType_,
19  typename AElementWise_ = ck_tile::element_wise::PassThrough,
20  typename BElementWise_ = ck_tile::element_wise::PassThrough,
21  bool FixedVectorSize_ = false,
22  index_t VectorSizeA_ = 1,
23  index_t VectorSizeB_ = 1>
25 {
27 
30  using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
31 
32  static constexpr bool FixedVectorSize = FixedVectorSize_;
33 
35 
38 
42 
46 
49 
53  using AsLayoutTuple = std::
54  conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
55  using BsLayoutTuple = std::
56  conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
57 
61 
65 
71 
72  static constexpr bool TransposeC = Traits::TransposeC;
73  static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
74  static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
75 
76  static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
77 
78  static constexpr bool kPadM = Traits::kPadM;
79  static constexpr bool kPadN = Traits::kPadN;
80  static constexpr bool kPadK = Traits::kPadK;
81 
82  static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
83  static constexpr auto Scheduler = GemmPipelineScheduler::Default;
84  static constexpr index_t VectorLoadSize = Traits::_VectorSize;
85 
86  // In the base situation, the Preshuffle setting should be false.
87  static constexpr bool Preshuffle = false;
88 
89  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
90  {
91  // clang-format off
92  return concat('_', "gemm_problem",
94  concat('x', kPadM, kPadN, kPadK),
95  Scheduler);
96  // clang-format on
97  }
98 
99  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
100  {
101  constexpr index_t PackedSize =
103  if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
104  {
105  constexpr index_t pixels_per_thread =
106  BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
107  return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
108  ? pixels_per_thread
109  : PackedSize * VectorLoadSize / sizeof(ADataType);
110  }
111  else
112  {
113  return VectorLoadSize / sizeof(ADataType);
114  }
115  }
116 
117  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
118  {
119  constexpr index_t PackedSize =
121  if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
122  {
123  constexpr index_t pixels_per_thread =
124  BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
125  return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
126  ? pixels_per_thread
127  : PackedSize * VectorLoadSize / sizeof(BDataType);
128  }
129  else
130  {
131  return PackedSize * VectorLoadSize / sizeof(BDataType);
132  }
133  }
134 
135  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
136  {
137  if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
138  {
139  constexpr index_t N1 = kBlockSize / get_warp_size();
140  constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
141  constexpr index_t M0 = get_warp_size() / N2;
142  constexpr index_t M1 = BlockGemmShape::kM / M0;
143 
144  return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
145  }
146  else
147  {
148  constexpr index_t M1 = kBlockSize / get_warp_size();
149  constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
150  constexpr index_t N0 = get_warp_size() / M2;
151  constexpr index_t N1 = BlockGemmShape::kN / N0;
152 
153  return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
154  }
155  }
156 
157  static constexpr index_t VectorSizeA = []() {
158  if constexpr(FixedVectorSize)
159  {
160  return VectorSizeA_;
161  }
162  else if constexpr(std::is_same_v<AsLayout, tensor_layout::gemm::RowMajor>)
163  {
164  return kPadK ? 1 : GetAlignmentA();
165  }
166  else
167  {
168  return kPadM ? 1 : GetAlignmentA();
169  }
170  }();
171 
172  static constexpr index_t VectorSizeB = []() {
173  if constexpr(FixedVectorSize)
174  {
175  return VectorSizeB_;
176  }
177  else if constexpr(std::is_same_v<BsLayout, tensor_layout::gemm::ColumnMajor>)
178  {
179  return kPadN ? 1 : GetAlignmentB();
180  }
181  else
182  {
183  return kPadK ? 1 : GetAlignmentB();
184  }
185  }();
186  static constexpr index_t VectorSizeC = []() {
187  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
188  {
189  return kPadN ? 1 : GetAlignmentC();
190  }
191  else
192  {
193  return kPadM ? 1 : GetAlignmentC();
194  }
195  }();
196 };
197 
198 template <typename AsDataType_,
199  typename BsDataType_,
200  typename EDataType_,
201  typename BlockGemmShape_,
202  typename Traits_,
203  typename AElementWise_ = ck_tile::element_wise::PassThrough,
204  typename BElementWise_ = ck_tile::element_wise::PassThrough,
205  typename ComputeDataType_ = AsDataType_,
206  bool FixedVectorSize_ = false,
207  index_t VectorSizeA_ = 1,
208  index_t VectorSizeB_ = 1>
210  BsDataType_,
211  EDataType_,
212  BlockGemmShape_,
213  Traits_,
214  ComputeDataType_,
215  AElementWise_,
216  BElementWise_,
217  FixedVectorSize_,
218  VectorSizeA_,
219  VectorSizeB_>;
220 
221 template <typename AsDataType_,
222  typename BsDataType_,
223  typename EDataType_,
224  typename BlockGemmShape_,
225  typename Traits_,
227  typename AElementWise_ = ck_tile::element_wise::PassThrough,
228  typename BElementWise_ = ck_tile::element_wise::PassThrough,
229  typename ComputeDataType_ = AsDataType_,
230  bool FixedVectorSize_ = false,
231  index_t VectorSizeA_ = 1,
232  index_t VectorSizeB_ = 1>
234 {
236 
239  using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
242 
243  static constexpr bool FixedVectorSize = FixedVectorSize_;
244 
246 
250 
254 
257 
261  using AsLayoutTuple = std::
262  conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
263  using BsLayoutTuple = std::
264  conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
265 
269 
273 
279 
280  static constexpr bool TransposeC = Traits::TransposeC;
281  static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
282  static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
283 
284  static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
285 
286  static constexpr bool kPadM = Traits::kPadM;
287  static constexpr bool kPadN = Traits::kPadN;
288  static constexpr bool kPadK = Traits::kPadK;
289 
290  static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
291  static constexpr auto Scheduler = Scheduler_;
292  static constexpr bool Preshuffle = Traits::Preshuffle;
293 
294  static constexpr index_t VectorSizeA = VectorSizeA_;
295  static constexpr index_t VectorSizeB = VectorSizeB_;
296 
297  static constexpr index_t VectorLoadSize = Traits::_VectorSize;
298  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
299  {
300  // clang-format off
301  return concat('_', "gemm_problem",
302  concat('x', kBlockSize),
303  concat('x', kPadM, kPadN, kPadK),
304  Scheduler,
305  "NumWaveGroups",
307  "DoubleSmemBuffer",
309  );
310  // clang-format on
311  }
312 };
313 
314 template <typename ADataType_,
315  typename BDataType_,
316  typename CDataType_,
317  typename BlockGemmShape_,
318  typename Traits_,
320  bool HasHotLoop_ = true,
321  TailNumber TailNum_ = TailNumber::Full,
323  bool BPreShufflePermute_ = false,
324  typename ComputeDataType_ = ADataType_>
326 {
328 
333 
335 
339 
340  static constexpr bool TransposeC = Traits::TransposeC;
341  static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
342  static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
343 
344  static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
345 
346  static constexpr bool kPadM = Traits::kPadM;
347  static constexpr bool kPadN = Traits::kPadN;
348  static constexpr bool kPadK = Traits::kPadK;
349 
350  static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
351 
352  static constexpr auto Scheduler = GemmPipelineScheduler::Default;
353  static constexpr index_t VectorLoadSize = Traits::_VectorSize;
354 
355  static constexpr auto HasHotLoop = HasHotLoop_;
356  static constexpr auto TailNum = TailNum_;
357 
358  static constexpr auto BMemNTType = BMemNTType_;
359  static constexpr bool BPreShufflePermute = BPreShufflePermute_;
360 
361  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
362  {
363  // clang-format off
364  return concat('_', "gemm_problem",
366  concat('x', kPadM, kPadN, kPadK),
367  Scheduler);
368  // clang-format on
369  }
370 
371  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
372  {
373  constexpr index_t PackedSize =
375  if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
376  {
377  constexpr index_t pixels_per_thread =
378  BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
379  return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
380  ? pixels_per_thread
381  : PackedSize * VectorLoadSize / sizeof(ADataType);
382  }
383  else
384  {
385  return VectorLoadSize / sizeof(ADataType);
386  }
387  }
388 
389  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
390  {
391  constexpr index_t PackedSize =
393  if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
394  {
395  constexpr index_t pixels_per_thread =
396  BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
397  return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
398  ? pixels_per_thread
399  : PackedSize * VectorLoadSize / sizeof(BDataType);
400  }
401  else
402  {
403  return PackedSize * VectorLoadSize / sizeof(BDataType);
404  }
405  }
406 
407  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
408  {
409  if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
410  {
411  constexpr index_t N1 = kBlockSize / get_warp_size();
412  constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
413  constexpr index_t M0 = get_warp_size() / N2;
414  constexpr index_t M1 = BlockGemmShape::kM / M0;
415 
416  return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
417  }
418  else
419  {
420  constexpr index_t M1 = kBlockSize / get_warp_size();
421  constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
422  constexpr index_t N0 = get_warp_size() / M2;
423  constexpr index_t N1 = BlockGemmShape::kN / N0;
424 
425  return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
426  }
427  }
428 
429  static constexpr index_t VectorSizeA = []() {
430  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
431  {
432  return kPadK ? 1 : GetAlignmentA();
433  }
434  else
435  {
436  return kPadM ? 1 : GetAlignmentA();
437  }
438  }();
439 
440  static constexpr index_t VectorSizeB = []() {
441  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
442  {
443  return kPadN ? 1 : GetAlignmentB();
444  }
445  else
446  {
447  return kPadK ? 1 : GetAlignmentB();
448  }
449  }();
450  static constexpr index_t VectorSizeC = []() {
451  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
452  {
453  return kPadN ? 1 : GetAlignmentC();
454  }
455  else
456  {
457  return kPadM ? 1 : GetAlignmentC();
458  }
459  }();
460 };
461 
462 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
TailNumber
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:21
amd_buffer_coherence_enum
Definition: amd_buffer_coherence.hpp:14
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
ck_tile::element_wise::PassThrough PassThrough
Definition: grouped_convolution_utils.hpp:54
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
GemmPipelineScheduler
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:14
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: gemm_pipeline_problem.hpp:326
static constexpr bool UseStructuredSparsity
Definition: gemm_pipeline_problem.hpp:342
static constexpr index_t VectorSizeC
Definition: gemm_pipeline_problem.hpp:450
remove_cvref_t< typename Traits::BsLayout > BLayout
Definition: gemm_pipeline_problem.hpp:337
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:334
static constexpr bool BPreShufflePermute
Definition: gemm_pipeline_problem.hpp:359
remove_cvref_t< typename Traits::AsLayout > ALayout
Definition: gemm_pipeline_problem.hpp:336
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:353
remove_cvref_t< BDataType_ > BDataType
Definition: gemm_pipeline_problem.hpp:330
static constexpr bool TransposeC
Definition: gemm_pipeline_problem.hpp:340
static CK_TILE_HOST const std::string GetName()
Definition: gemm_pipeline_problem.hpp:361
static constexpr bool DoubleSmemBuffer
Definition: gemm_pipeline_problem.hpp:350
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentA()
Definition: gemm_pipeline_problem.hpp:371
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentB()
Definition: gemm_pipeline_problem.hpp:389
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:344
static constexpr auto BMemNTType
Definition: gemm_pipeline_problem.hpp:358
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:338
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:346
remove_cvref_t< ADataType_ > ADataType
Definition: gemm_pipeline_problem.hpp:329
remove_cvref_t< CDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:331
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:347
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:348
static constexpr index_t NumWaveGroups
Definition: gemm_pipeline_problem.hpp:341
static constexpr index_t VectorSizeB
Definition: gemm_pipeline_problem.hpp:440
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:327
static constexpr auto HasHotLoop
Definition: gemm_pipeline_problem.hpp:355
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentC()
Definition: gemm_pipeline_problem.hpp:407
remove_cvref_t< ComputeDataType_ > ComputeDataType
Definition: gemm_pipeline_problem.hpp:332
static constexpr index_t VectorSizeA
Definition: gemm_pipeline_problem.hpp:429
static constexpr auto TailNum
Definition: gemm_pipeline_problem.hpp:356
static constexpr auto Scheduler
Definition: gemm_pipeline_problem.hpp:352
Definition: gemm_pipeline_problem.hpp:25
remove_cvref_t< BElementWise_ > BElementWise
Definition: gemm_pipeline_problem.hpp:37
static constexpr index_t NumWaveGroups
Definition: gemm_pipeline_problem.hpp:73
static constexpr bool UseStructuredSparsity
Definition: gemm_pipeline_problem.hpp:74
static constexpr bool FixedVectorSize
Definition: gemm_pipeline_problem.hpp:32
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:34
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentB()
Definition: gemm_pipeline_problem.hpp:117
static constexpr bool ComputeDataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:43
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:80
static constexpr index_t VectorSizeC
Definition: gemm_pipeline_problem.hpp:186
remove_cvref_t< BsDataType_ > BsDataType
Definition: gemm_pipeline_problem.hpp:29
static constexpr bool Preshuffle
Definition: gemm_pipeline_problem.hpp:87
static constexpr auto Scheduler
Definition: gemm_pipeline_problem.hpp:83
std::conditional_t< BLayoutIsTuple, remove_cvref_t< BsLayout >, remove_cvref_t< tuple< BsLayout > >> BsLayoutTuple
Definition: gemm_pipeline_problem.hpp:56
remove_cvref_t< std::tuple_element_t< number< 0 >{}, ComputeDataTypeTuple > > ComputeDataType
Definition: gemm_pipeline_problem.hpp:66
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsLayoutTuple > > ALayout
Definition: gemm_pipeline_problem.hpp:68
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition: gemm_pipeline_problem.hpp:69
remove_cvref_t< AElementWise_ > AElementWise
Definition: gemm_pipeline_problem.hpp:36
static constexpr bool ALayoutIsTuple
Definition: gemm_pipeline_problem.hpp:47
std::conditional_t< ALayoutIsTuple, remove_cvref_t< AsLayout >, remove_cvref_t< tuple< AsLayout > >> AsLayoutTuple
Definition: gemm_pipeline_problem.hpp:54
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:76
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:78
static constexpr index_t VectorSizeB
Definition: gemm_pipeline_problem.hpp:172
remove_cvref_t< typename Traits::AsLayout > AsLayout
Definition: gemm_pipeline_problem.hpp:39
static constexpr bool TransposeC
Definition: gemm_pipeline_problem.hpp:72
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition: gemm_pipeline_problem.hpp:67
static constexpr bool DoubleSmemBuffer
Definition: gemm_pipeline_problem.hpp:82
static constexpr bool BDataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:45
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< AsDataType >, remove_cvref_t< tuple< AsDataType > >> AsDataTypeTuple
Definition: gemm_pipeline_problem.hpp:60
static constexpr bool BLayoutIsTuple
Definition: gemm_pipeline_problem.hpp:48
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< BsDataType >, remove_cvref_t< tuple< BsDataType > >> BsDataTypeTuple
Definition: gemm_pipeline_problem.hpp:64
remove_cvref_t< AsDataType_ > AsDataType
Definition: gemm_pipeline_problem.hpp:28
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:41
std::conditional_t< ComputeDataTypeIsTuple, remove_cvref_t< ComputeDataType_ >, remove_cvref_t< tuple< ComputeDataType_ > >> ComputeDataTypeTuple
Definition: gemm_pipeline_problem.hpp:52
static CK_TILE_HOST const std::string GetName()
Definition: gemm_pipeline_problem.hpp:89
remove_cvref_t< typename Traits::BsLayout > BsLayout
Definition: gemm_pipeline_problem.hpp:40
remove_cvref_t< EDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:30
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentA()
Definition: gemm_pipeline_problem.hpp:99
static constexpr index_t VectorSizeA
Definition: gemm_pipeline_problem.hpp:157
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:26
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:79
static constexpr bool ADataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:44
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:84
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsLayoutTuple > > BLayout
Definition: gemm_pipeline_problem.hpp:70
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentC()
Definition: gemm_pipeline_problem.hpp:135
Definition: gemm_pipeline_problem.hpp:234
static constexpr bool FixedVectorSize
Definition: gemm_pipeline_problem.hpp:243
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:288
static CK_TILE_HOST const std::string GetName()
Definition: gemm_pipeline_problem.hpp:298
remove_cvref_t< AsDataType_ > AsDataType
Definition: gemm_pipeline_problem.hpp:237
static constexpr bool DoubleSmemBuffer
Definition: gemm_pipeline_problem.hpp:290
std::conditional_t< ComputeDataTypeIsTuple, remove_cvref_t< ComputeDataType_ >, remove_cvref_t< tuple< ComputeDataType_ > >> ComputeDataTypeTuple
Definition: gemm_pipeline_problem.hpp:260
remove_cvref_t< typename Traits::AsLayout > AsLayout
Definition: gemm_pipeline_problem.hpp:247
static constexpr bool TransposeC
Definition: gemm_pipeline_problem.hpp:280
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:284
std::conditional_t< BLayoutIsTuple, remove_cvref_t< BsLayout >, remove_cvref_t< tuple< BsLayout > >> BsLayoutTuple
Definition: gemm_pipeline_problem.hpp:264
remove_cvref_t< AElementWise_ > AElementWise
Definition: gemm_pipeline_problem.hpp:240
static constexpr bool ADataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:252
static constexpr index_t VectorSizeA
Definition: gemm_pipeline_problem.hpp:294
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:235
std::conditional_t< ALayoutIsTuple, remove_cvref_t< AsLayout >, remove_cvref_t< tuple< AsLayout > >> AsLayoutTuple
Definition: gemm_pipeline_problem.hpp:262
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< BsDataType >, remove_cvref_t< tuple< BsDataType > >> BsDataTypeTuple
Definition: gemm_pipeline_problem.hpp:272
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:286
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:249
static constexpr bool BLayoutIsTuple
Definition: gemm_pipeline_problem.hpp:256
static constexpr auto Scheduler
Definition: gemm_pipeline_problem.hpp:291
remove_cvref_t< typename Traits::BsLayout > BsLayout
Definition: gemm_pipeline_problem.hpp:248
static constexpr bool Preshuffle
Definition: gemm_pipeline_problem.hpp:292
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< AsDataType >, remove_cvref_t< tuple< AsDataType > >> AsDataTypeTuple
Definition: gemm_pipeline_problem.hpp:268
static constexpr index_t NumWaveGroups
Definition: gemm_pipeline_problem.hpp:281
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:287
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:245
static constexpr bool UseStructuredSparsity
Definition: gemm_pipeline_problem.hpp:282
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:297
remove_cvref_t< EDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:239
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsLayoutTuple > > BLayout
Definition: gemm_pipeline_problem.hpp:278
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsLayoutTuple > > ALayout
Definition: gemm_pipeline_problem.hpp:276
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition: gemm_pipeline_problem.hpp:277
static constexpr bool ComputeDataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:251
remove_cvref_t< BElementWise_ > BElementWise
Definition: gemm_pipeline_problem.hpp:241
static constexpr bool ALayoutIsTuple
Definition: gemm_pipeline_problem.hpp:255
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition: gemm_pipeline_problem.hpp:275
static constexpr index_t VectorSizeB
Definition: gemm_pipeline_problem.hpp:295
remove_cvref_t< BsDataType_ > BsDataType
Definition: gemm_pipeline_problem.hpp:238
static constexpr bool BDataTypeIsTuple
Definition: gemm_pipeline_problem.hpp:253
remove_cvref_t< std::tuple_element_t< number< 0 >{}, ComputeDataTypeTuple > > ComputeDataType
Definition: gemm_pipeline_problem.hpp:274
Definition: numeric.hpp:81