/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File
gemm_pipeline_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 template <typename ADataType_,
13  typename BDataType_,
14  typename CDataType_,
15  typename BlockGemmShape_,
16  typename Traits_,
17  typename ComputeDataType_ = ADataType_,
18  bool FixedVectorSize_ = false,
19  index_t VectorSizeA_ = 1,
20  index_t VectorSizeB_ = 1>
22 {
24 
27  using CDataType = remove_cvref_t<CDataType_>; // actually AccDataType
29 
30  static constexpr bool FixedVectorSize = FixedVectorSize_;
31 
33 
37 
38  static constexpr bool TransposeC = Traits::TransposeC;
39  static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
40  static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
41 
42  static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
43 
44  static constexpr bool kPadM = Traits::kPadM;
45  static constexpr bool kPadN = Traits::kPadN;
46  static constexpr bool kPadK = Traits::kPadK;
47 
48  static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
49  static constexpr auto Scheduler = GemmPipelineScheduler::Default;
50  static constexpr index_t VectorLoadSize = Traits::_VectorSize;
51 
52  // In the base situation, the Preshuffle setting should be false.
53  static constexpr bool Preshuffle = false;
54 
55  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
56  {
57  // clang-format off
58  return concat('_', "gemm_problem",
60  concat('x', kPadM, kPadN, kPadK),
61  Scheduler);
62  // clang-format on
63  }
64 
65  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
66  {
67  constexpr index_t PackedSize =
69  if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
70  {
71  constexpr index_t pixels_per_thread =
72  BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
73  return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
74  ? pixels_per_thread
75  : PackedSize * VectorLoadSize / sizeof(ADataType);
76  }
77  else
78  {
79  return VectorLoadSize / sizeof(ADataType);
80  }
81  }
82 
83  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
84  {
85  constexpr index_t PackedSize =
87  if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
88  {
89  constexpr index_t pixels_per_thread =
90  BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
91  return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
92  ? pixels_per_thread
93  : PackedSize * VectorLoadSize / sizeof(BDataType);
94  }
95  else
96  {
97  return PackedSize * VectorLoadSize / sizeof(BDataType);
98  }
99  }
100 
101  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
102  {
103  if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
104  {
105  constexpr index_t N1 = kBlockSize / get_warp_size();
106  constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
107  constexpr index_t M0 = get_warp_size() / N2;
108  constexpr index_t M1 = BlockGemmShape::kM / M0;
109 
110  return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
111  }
112  else
113  {
114  constexpr index_t M1 = kBlockSize / get_warp_size();
115  constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
116  constexpr index_t N0 = get_warp_size() / M2;
117  constexpr index_t N1 = BlockGemmShape::kN / N0;
118 
119  return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
120  }
121  }
122 
123  static constexpr index_t VectorSizeA = []() {
124  if constexpr(FixedVectorSize)
125  {
126  return VectorSizeA_;
127  }
128  else if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
129  {
130  return kPadK ? 1 : GetAlignmentA();
131  }
132  else
133  {
134  return kPadM ? 1 : GetAlignmentA();
135  }
136  }();
137 
138  static constexpr index_t VectorSizeB = []() {
139  if constexpr(FixedVectorSize)
140  {
141  return VectorSizeB_;
142  }
143  else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
144  {
145  return kPadN ? 1 : GetAlignmentB();
146  }
147  else
148  {
149  return kPadK ? 1 : GetAlignmentB();
150  }
151  }();
152  static constexpr index_t VectorSizeC = []() {
153  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
154  {
155  return kPadN ? 1 : GetAlignmentC();
156  }
157  else
158  {
159  return kPadM ? 1 : GetAlignmentC();
160  }
161  }();
162 };
163 
164 // Alias for GemmPipelineProblem
165 template <typename ADataType_,
166  typename BDataType_,
167  typename CDataType_,
168  typename BlockGemmShape_,
169  typename Traits_,
170  typename ComputeDataType_ = ADataType_,
171  bool FixedVectorSize_ = false,
172  index_t VectorSizeA_ = 1,
173  index_t VectorSizeB_ = 1>
175  BDataType_,
176  CDataType_,
177  BlockGemmShape_,
178  Traits_,
179  ComputeDataType_,
180  FixedVectorSize_,
181  VectorSizeA_,
182  VectorSizeB_>;
183 
184 template <typename ADataType_,
185  typename BDataType_,
186  typename CDataType_,
187  typename BlockGemmShape_,
188  typename Traits_,
190  bool HasHotLoop_ = true,
191  TailNumber TailNum_ = TailNumber::Full,
192  typename ComputeDataType_ = ADataType_,
193  bool FixedVectorSize_ = false,
194  index_t VectorSizeA_ = 1,
195  index_t VectorSizeB_ = 1>
197 {
199 
202  using CDataType = remove_cvref_t<CDataType_>; // actually AccDataType
204 
205  static constexpr bool FixedVectorSize = FixedVectorSize_;
206 
208 
212 
213  static constexpr bool TransposeC = Traits::TransposeC;
214  static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
215  static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
216 
217  static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
218 
219  static constexpr bool kPadM = Traits::kPadM;
220  static constexpr bool kPadN = Traits::kPadN;
221  static constexpr bool kPadK = Traits::kPadK;
222 
223  static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
224  static constexpr auto Scheduler = Scheduler_;
225  static constexpr bool Preshuffle = Traits::Preshuffle;
226 
227  static constexpr index_t VectorSizeA = VectorSizeA_;
228  static constexpr index_t VectorSizeB = VectorSizeB_;
229 
230  static constexpr auto HasHotLoop = HasHotLoop_;
231  static constexpr auto TailNum = TailNum_;
232  static constexpr index_t VectorLoadSize = Traits::_VectorSize;
233  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
234  {
235  // clang-format off
236  return concat('_', "gemm_problem",
237  concat('x', kBlockSize),
238  concat('x', kPadM, kPadN, kPadK),
239  Scheduler);
240  // clang-format on
241  }
242 };
243 
244 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
TailNumber
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:21
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
GemmPipelineScheduler
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:14
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
Definition: gemm_pipeline_problem.hpp:22
static constexpr index_t VectorSizeA
Definition: gemm_pipeline_problem.hpp:123
static constexpr index_t VectorSizeB
Definition: gemm_pipeline_problem.hpp:138
remove_cvref_t< CDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:27
static constexpr bool DoubleSmemBuffer
Definition: gemm_pipeline_problem.hpp:48
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:44
static constexpr bool Preshuffle
Definition: gemm_pipeline_problem.hpp:53
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:50
static constexpr auto Scheduler
Definition: gemm_pipeline_problem.hpp:49
remove_cvref_t< ComputeDataType_ > ComputeDataType
Definition: gemm_pipeline_problem.hpp:28
static constexpr bool TransposeC
Definition: gemm_pipeline_problem.hpp:38
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:23
static constexpr index_t NumWaveGroups
Definition: gemm_pipeline_problem.hpp:39
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:32
remove_cvref_t< typename Traits::ALayout > ALayout
Definition: gemm_pipeline_problem.hpp:34
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentB()
Definition: gemm_pipeline_problem.hpp:83
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentC()
Definition: gemm_pipeline_problem.hpp:101
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentA()
Definition: gemm_pipeline_problem.hpp:65
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:45
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:36
remove_cvref_t< BDataType_ > BDataType
Definition: gemm_pipeline_problem.hpp:26
static CK_TILE_HOST const std::string GetName()
Definition: gemm_pipeline_problem.hpp:55
remove_cvref_t< ADataType_ > ADataType
Definition: gemm_pipeline_problem.hpp:25
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:46
remove_cvref_t< typename Traits::BLayout > BLayout
Definition: gemm_pipeline_problem.hpp:35
static constexpr bool FixedVectorSize
Definition: gemm_pipeline_problem.hpp:30
static constexpr bool UseStructuredSparsity
Definition: gemm_pipeline_problem.hpp:40
static constexpr index_t VectorSizeC
Definition: gemm_pipeline_problem.hpp:152
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:42
Definition: gemm_pipeline_problem.hpp:197
static constexpr bool TransposeC
Definition: gemm_pipeline_problem.hpp:213
static constexpr index_t VectorSizeB
Definition: gemm_pipeline_problem.hpp:228
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:221
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:232
static constexpr index_t VectorSizeA
Definition: gemm_pipeline_problem.hpp:227
remove_cvref_t< typename Traits::ALayout > ALayout
Definition: gemm_pipeline_problem.hpp:209
static constexpr bool DoubleSmemBuffer
Definition: gemm_pipeline_problem.hpp:223
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:198
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:211
static constexpr bool UseStructuredSparsity
Definition: gemm_pipeline_problem.hpp:215
remove_cvref_t< BDataType_ > BDataType
Definition: gemm_pipeline_problem.hpp:201
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:220
remove_cvref_t< ADataType_ > ADataType
Definition: gemm_pipeline_problem.hpp:200
static constexpr bool Preshuffle
Definition: gemm_pipeline_problem.hpp:225
static constexpr auto Scheduler
Definition: gemm_pipeline_problem.hpp:224
remove_cvref_t< typename Traits::BLayout > BLayout
Definition: gemm_pipeline_problem.hpp:210
remove_cvref_t< ComputeDataType_ > ComputeDataType
Definition: gemm_pipeline_problem.hpp:203
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:219
static CK_TILE_HOST const std::string GetName()
Definition: gemm_pipeline_problem.hpp:233
remove_cvref_t< CDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:202
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:207
static constexpr bool FixedVectorSize
Definition: gemm_pipeline_problem.hpp:205
static constexpr index_t NumWaveGroups
Definition: gemm_pipeline_problem.hpp:214
static constexpr auto TailNum
Definition: gemm_pipeline_problem.hpp:231
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:217
static constexpr auto HasHotLoop
Definition: gemm_pipeline_problem.hpp:230
Definition: numeric.hpp:81