/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File
block_fmha_pipeline_problem.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 template <typename QDataType_,
13  typename KDataType_,
14  typename VDataType_,
15  typename SaccDataType_,
16  typename SMPLComputeDataType_,
17  typename BiasDataType_,
18  typename RandValOutputDataType_,
19  typename LSEDataType_,
20  typename PDataType_,
21  typename OaccDataType_,
22  typename ODataType_,
23  typename BlockFmhaShape_,
24  bool kIsGroupMode_,
25  typename AttentionVariant_,
26  typename FmhaMask_,
27  bool kUseTrLoad_,
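28  typename Traits_>
29 struct BlockFmhaPipelineProblem
30 {
31  using QDataType = remove_cvref_t<QDataType_>;
32  using KDataType = remove_cvref_t<KDataType_>;
33  using VDataType = remove_cvref_t<VDataType_>;
34  using SaccDataType = remove_cvref_t<SaccDataType_>;
35  using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
36  using BiasDataType = remove_cvref_t<BiasDataType_>;
37  using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
38  using LSEDataType = remove_cvref_t<LSEDataType_>;
39  using PDataType = remove_cvref_t<PDataType_>;
40  using OaccDataType = remove_cvref_t<OaccDataType_>;
41  using ODataType = remove_cvref_t<ODataType_>;
42  using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
43  using AttentionVariant = remove_cvref_t<AttentionVariant_>;
44  using FmhaMask = remove_cvref_t<FmhaMask_>;
45  using Traits = remove_cvref_t<Traits_>;
46 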
47  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
48  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
49  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
50 
51  static constexpr bool kIsGroupMode = kIsGroupMode_;
52  static constexpr bool kUseTrLoad = kUseTrLoad_;
53 
54  // attributes from traits
55  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
56  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
57  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
58  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
59  static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
60  static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
61  static constexpr auto BiasEnum = Traits::BiasEnum;
62  static constexpr bool kStoreLSE = Traits::kStoreLSE;
63  static constexpr bool kHasDropout = Traits::kHasDropout;
64  static constexpr auto QScaleEnum = Traits::QScaleEnum;
65  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
66  static constexpr bool kHasSink = Traits::kHasSink;
67 };
68 
69 template <typename QDataType_,
70  typename KDataType_,
71  typename VDataType_,
72  typename SaccDataType_,
73  typename SMPLComputeDataType_,
74  typename BiasDataType_,
75  typename RandValOutputDataType_,
76  typename LSEDataType_,
77  typename PDataType_,
78  typename OaccDataType_,
79  typename ODataType_,
80  typename BlockFmhaShape_,
81  bool kIsGroupMode_,
82  typename AttentionVariant_,
83  typename FmhaMask_,
84  bool kUseTrLoad_,
85  int kPageBlockSize_,
86  typename Traits_>
88  : public BlockFmhaPipelineProblem<QDataType_,
89  KDataType_,
90  VDataType_,
91  SaccDataType_,
92  SMPLComputeDataType_,
93  BiasDataType_,
94  RandValOutputDataType_,
95  LSEDataType_,
96  PDataType_,
97  OaccDataType_,
98  ODataType_,
99  BlockFmhaShape_,
100  kIsGroupMode_,
101  AttentionVariant_,
102  FmhaMask_,
103  kUseTrLoad_,
104  Traits_>
105 {
106  static constexpr index_t kPageBlockSize = kPageBlockSize_;
107  static_assert(kPageBlockSize > 0, "kPageBlockSize must be positive");
108  static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
109  "kPageBlockSize must be power of two");
110  static constexpr index_t kLog2PageSize = []() constexpr {
111  index_t shift = 0;
112  index_t val = kPageBlockSize_;
113  while(val > 1)
114  {
115  val >>= 1;
116  shift++;
117  }
118  return shift;
119  }();
120 
121  static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4
122  static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout;
123  static constexpr auto kKVLookupTable = Traits_::kKVLookupTable;
124  static constexpr bool kIsVectorizedLayout =
126 
127  static_assert(BlockFmhaShape_::kQKHeaddim % kVectorSize == 0,
128  "kQKHeaddim must be divisible by kVectorSize");
129  static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0,
130  "kPageBlockSize must be divisible by kVectorSize for vectorized layout");
131  static_assert(kIsGroupMode_, "Batch prefill requires group mode");
132 };
133 
134 template <typename QDataType_,
135  typename KDataType_,
136  typename VDataType_,
137  typename SaccDataType_,
138  typename SMPLComputeDataType_,
139  typename BiasDataType_,
140  typename LSEDataType_,
141  typename PDataType_,
142  typename OaccDataType_,
143  typename ODataType_,
144  typename BlockFmhaShape_,
145  bool kIsGroupMode_,
146  typename AttentionVariant_,
147  typename FmhaMask_,
148  typename Traits_>
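150 {
151  using QDataType = remove_cvref_t<QDataType_>;
152  using KDataType = remove_cvref_t<KDataType_>;
153  using VDataType = remove_cvref_t<VDataType_>;
154  using SaccDataType = remove_cvref_t<SaccDataType_>;
155  using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
156  using BiasDataType = remove_cvref_t<BiasDataType_>;
157  using LSEDataType = remove_cvref_t<LSEDataType_>;
158  using PDataType = remove_cvref_t<PDataType_>;
159  using OaccDataType = remove_cvref_t<OaccDataType_>;
160  using ODataType = remove_cvref_t<ODataType_>;
161  using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
162  using AttentionVariant = remove_cvref_t<AttentionVariant_>;
163  using FmhaMask = remove_cvref_t<FmhaMask_>;
164  using Traits = remove_cvref_t<Traits_>;
165 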
166  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
167  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
168  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
169 
170  static constexpr bool kIsGroupMode = kIsGroupMode_;
171 
172  // attributes from traits
173  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
174  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
175  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
176  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
177  static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
178  static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
179  static constexpr auto BiasEnum = Traits::BiasEnum;
180  static constexpr bool kStoreLSE = Traits::kStoreLSE;
181  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
182  static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
183  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
184  static constexpr bool kHasSink = Traits::kHasSink;
185 };
186 
187 template <typename QDataType_,
188  typename KDataType_,
189  typename VDataType_,
190  typename SaccDataType_,
191  typename SMPLComputeDataType_,
192  typename BiasDataType_,
193  typename LSEDataType_,
194  typename PDataType_,
195  typename OaccDataType_,
196  typename ODataType_,
197  typename BlockFmhaShape_,
198  bool kIsGroupMode_,
199  typename AttentionVariant_,
200  typename FmhaMask_,
201  typename Traits_>
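203 {
204  using QDataType = remove_cvref_t<QDataType_>;
205  using KDataType = remove_cvref_t<KDataType_>;
206  using VDataType = remove_cvref_t<VDataType_>;
207  using SaccDataType = remove_cvref_t<SaccDataType_>;
208  using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
209  using BiasDataType = remove_cvref_t<BiasDataType_>;
210  using LSEDataType = remove_cvref_t<LSEDataType_>;
211  using PDataType = remove_cvref_t<PDataType_>;
212  using OaccDataType = remove_cvref_t<OaccDataType_>;
213  using ODataType = remove_cvref_t<ODataType_>;
214  using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
215  using AttentionVariant = remove_cvref_t<AttentionVariant_>;
216  using FmhaMask = remove_cvref_t<FmhaMask_>;
217  using Traits = remove_cvref_t<Traits_>;
218 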
219  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
220  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
221  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
222 
223  static constexpr bool kIsGroupMode = kIsGroupMode_;
224 
225  // attributes from traits
226  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
227  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
228  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
229  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
230  static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
231  static constexpr auto BiasEnum = Traits::BiasEnum;
232  static constexpr bool kStoreLSE = Traits::kStoreLSE;
233  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
234  static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
235  static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
236  static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
237  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
238  static constexpr bool kHasSink = Traits::kHasSink;
239 };
240 
241 // extract tile size attributes to remove dependency on traits
242 template <typename OaccDataType_, ck_tile::index_t kN1_>
243 struct BlockFmhaSplitKVCombinePipelineTileSizes
244 {
245  static constexpr index_t MaxVectorSize = 16 / sizeof(OaccDataType_);
246 
247  static constexpr index_t kN1 = kN1_;
248  static constexpr index_t NThreads = kN1 / MaxVectorSize;
249  static constexpr index_t kM0 = get_warp_size() / NThreads; // MThreadPerWarp
250 };
251 
252 template <typename LSEDataType_,
253  typename OaccDataType_,
254  typename ODataType_,
255  index_t HeadDimV_,
256  bool kIsGroupMode_,
257  ck_tile::index_t kN1_,
258  typename Traits_>
260  : BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>
261 {
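263 
264  using LSEDataType = remove_cvref_t<LSEDataType_>;
265  using OaccDataType = remove_cvref_t<OaccDataType_>;
266  using ODataType = remove_cvref_t<ODataType_>;
267  using Traits = remove_cvref_t<Traits_>;
268 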
269  static_assert(std::is_same_v<LSEDataType, OaccDataType>);
270 
271  static constexpr index_t kHeadDimV = HeadDimV_;
272  static constexpr bool kIsGroupMode = kIsGroupMode_;
273 
274  using BaseType::kM0;
275  using BaseType::kN1;
276  using BaseType::NThreads;
277 
278  static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0);
279 
280  // attributes from traits
281  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
282  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
283  static constexpr bool kStoreLSE = Traits::kStoreLSE;
284  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
285  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
286  static constexpr index_t kMaxSplits = Traits::kMaxSplits;
287  static_assert(8 <= kMaxSplits);
288 
289  static constexpr index_t kNumWarps = 4;
290  static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
291 
292  static_assert(get_warp_size() <= (kM0 * kMaxSplits) &&
293  (kM0 * kMaxSplits) % get_warp_size() == 0);
294 };
295 
296 template <typename QDataType_,
297  typename KDataType_,
298  typename VDataType_,
299  index_t kM0_,
300  index_t kN0_,
301  index_t kK0_,
302  index_t kN1_,
303  bool kIsVLayoutRowMajor_,
304  RotaryEmbeddingEnum RotaryEnum_,
305  bool kIsPagedKV_,
306  typename Traits_>
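308 {
309  using QDataType = remove_cvref_t<QDataType_>;
310  using KDataType = remove_cvref_t<KDataType_>;
311  using VDataType = remove_cvref_t<VDataType_>;
312  using Traits = remove_cvref_t<Traits_>;
313 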
314  static constexpr index_t kBlockSize = 256;
315 
316  static constexpr index_t kM0 = kM0_;
317  static constexpr index_t kN0 = kN0_;
318  static constexpr index_t kK0 = kK0_;
319  static constexpr index_t kN1 = kN1_;
320 
321  using VLayout = std::conditional_t<kIsVLayoutRowMajor_,
322  ck_tile::tensor_layout::gemm::RowMajor,
323  ck_tile::tensor_layout::gemm::ColumnMajor>;
324 
325  static constexpr auto RotaryEnum = RotaryEnum_;
326  static constexpr bool kIsPagedKV = kIsPagedKV_;
327 
328  // attributes from traits
329  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
330  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
331  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
332  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
333  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
334 };
335 
336 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
RotaryEmbeddingEnum
Definition: block_rotary_embedding.hpp:12
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: block_fmha_pipeline_problem.hpp:105
static constexpr index_t kVectorSize
Definition: block_fmha_pipeline_problem.hpp:121
static constexpr auto kKVLookupTable
Definition: block_fmha_pipeline_problem.hpp:123
static constexpr auto kKVMemoryLayout
Definition: block_fmha_pipeline_problem.hpp:122
static constexpr index_t kLog2PageSize
Definition: block_fmha_pipeline_problem.hpp:110
static constexpr bool kIsVectorizedLayout
Definition: block_fmha_pipeline_problem.hpp:124
static constexpr index_t kPageBlockSize
Definition: block_fmha_pipeline_problem.hpp:106
Definition: block_fmha_pipeline_problem.hpp:308
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:309
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:330
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:331
std::conditional_t< kIsVLayoutRowMajor_, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition: block_fmha_pipeline_problem.hpp:323
static constexpr auto RotaryEnum
Definition: block_fmha_pipeline_problem.hpp:325
static constexpr index_t kK0
Definition: block_fmha_pipeline_problem.hpp:318
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:329
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:312
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:326
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:311
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:316
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:319
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:333
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:314
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:332
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:310
static constexpr index_t kN0
Definition: block_fmha_pipeline_problem.hpp:317
Definition: block_fmha_pipeline_problem.hpp:150
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:155
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:175
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:181
static constexpr bool kHasSink
Definition: block_fmha_pipeline_problem.hpp:184
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:164
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:183
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:161
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:176
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:153
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:162
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:163
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:158
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:154
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:152
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:179
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:151
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:156
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:159
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:177
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:173
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:168
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:174
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:160
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:166
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:170
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:157
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:182
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:180
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:178
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:167
Definition: block_fmha_pipeline_problem.hpp:203
static constexpr bool kHasUnevenSplits
Definition: block_fmha_pipeline_problem.hpp:235
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:206
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:230
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:216
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:228
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:233
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:219
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:204
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:212
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:210
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:223
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:236
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:220
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:207
static constexpr bool kHasSink
Definition: block_fmha_pipeline_problem.hpp:238
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:234
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:208
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:214
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:205
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:226
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:221
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:237
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:211
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:213
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:231
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:215
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:227
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:232
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:209
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:229
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:217
Definition: block_fmha_pipeline_problem.hpp:30
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:56
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:43
static constexpr bool kHasDropout
Definition: block_fmha_pipeline_problem.hpp:63
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:62
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:59
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:61
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:42
static constexpr auto QScaleEnum
Definition: block_fmha_pipeline_problem.hpp:64
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:40
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:60
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:47
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:45
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:57
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:34
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:38
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:51
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:32
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:65
remove_cvref_t< RandValOutputDataType_ > RandValOutputDataType
Definition: block_fmha_pipeline_problem.hpp:37
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:41
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:49
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:58
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:39
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:33
static constexpr bool kUseTrLoad
Definition: block_fmha_pipeline_problem.hpp:52
static constexpr bool kHasSink
Definition: block_fmha_pipeline_problem.hpp:66
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:35
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:55
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:36
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:44
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:48
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:31
Definition: block_fmha_pipeline_problem.hpp:261
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:266
static constexpr index_t kNumWarps
Definition: block_fmha_pipeline_problem.hpp:289
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:267
static constexpr index_t kHeadDimV
Definition: block_fmha_pipeline_problem.hpp:271
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:290
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:249
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:285
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:272
static constexpr index_t kMaxSplits
Definition: block_fmha_pipeline_problem.hpp:286
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:282
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:283
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:284
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:264
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:281
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:247
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:265
Definition: block_fmha_pipeline_problem.hpp:244
static constexpr index_t NThreads
Definition: block_fmha_pipeline_problem.hpp:248
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:249
static constexpr index_t MaxVectorSize
Definition: block_fmha_pipeline_problem.hpp:245
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:247
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17