/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File
block_fmha_pipeline_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 template <typename QDataType_,
12  typename KDataType_,
13  typename VDataType_,
14  typename SaccDataType_,
15  typename SMPLComputeDataType_,
16  typename BiasDataType_,
17  typename RandValOutputDataType_,
18  typename LSEDataType_,
19  typename PDataType_,
20  typename OaccDataType_,
21  typename ODataType_,
22  typename BlockFmhaShape_,
23  bool kIsGroupMode_,
24  typename AttentionVariant_,
25  typename FmhaMask_,
26  bool kUseTrLoad_,
27  typename Traits_>
29 { // Compile-time parameter bundle for an FMHA pipeline: every member is a static constexpr value taken from the template arguments, BlockFmhaShape or Traits (no runtime state).
45 
46  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; // GEMM0 warp count, forwarded from the block shape
47  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; // GEMM1 warp count, forwarded from the block shape
48  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); // threads per block: total warps in the shape times the warp size
49 
50  static constexpr bool kIsGroupMode = kIsGroupMode_; // forwarded template switch
51  static constexpr bool kUseTrLoad = kUseTrLoad_; // forwarded template switch
52 
53  // attributes from traits (Traits is remove_cvref_t<Traits_>; each value below is copied verbatim)
54  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
55  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
56  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
57  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
58  static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
59  static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
60  static constexpr auto BiasEnum = Traits::BiasEnum; // type deduced from Traits::BiasEnum
61  static constexpr bool kStoreLSE = Traits::kStoreLSE;
62  static constexpr bool kHasDropout = Traits::kHasDropout;
63  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
64  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; // presumably an occupancy hint (blocks per CU) — confirm against the launch code
65 };
66 
67 template <typename QDataType_,
68  typename KDataType_,
69  typename VDataType_,
70  typename SaccDataType_,
71  typename SMPLComputeDataType_,
72  typename BiasDataType_,
73  typename LSEDataType_,
74  typename PDataType_,
75  typename OaccDataType_,
76  typename ODataType_,
77  typename BlockFmhaShape_,
78  bool kIsGroupMode_,
79  typename AttentionVariant_,
80  typename FmhaMask_,
81  typename Traits_>
83 { // Compile-time parameter bundle: no RandValOutputDataType / kUseTrLoad_ template parameters and no kHasDropout member; adds Traits::kIsPagedKV.
98 
99  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; // GEMM0 warp count, forwarded from the block shape
100  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; // GEMM1 warp count, forwarded from the block shape
101  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); // threads per block: total warps in the shape times the warp size
102 
103  static constexpr bool kIsGroupMode = kIsGroupMode_; // forwarded template switch
104 
105  // attributes from traits (Traits is remove_cvref_t<Traits_>; each value below is copied verbatim)
106  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
107  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
108  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
109  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
110  static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
111  static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
112  static constexpr auto BiasEnum = Traits::BiasEnum; // type deduced from Traits::BiasEnum
113  static constexpr bool kStoreLSE = Traits::kStoreLSE;
114  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
115  static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
116  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; // presumably an occupancy hint (blocks per CU) — confirm against the launch code
117 };
118 
119 template <typename QDataType_,
120  typename KDataType_,
121  typename VDataType_,
122  typename SaccDataType_,
123  typename SMPLComputeDataType_,
124  typename BiasDataType_,
125  typename LSEDataType_,
126  typename PDataType_,
127  typename OaccDataType_,
128  typename ODataType_,
129  typename BlockFmhaShape_,
130  bool kIsGroupMode_,
131  typename AttentionVariant_,
132  typename FmhaMask_,
133  typename Traits_>
135 { // Compile-time parameter bundle whose traits carry split-related switches (kHasUnevenSplits, kMergeNumHeadGroupsSeqLenQ) alongside kIsPagedKV; it does not read kSkipMinSeqlenQ.
150 
151  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; // GEMM0 warp count, forwarded from the block shape
152  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; // GEMM1 warp count, forwarded from the block shape
153  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); // threads per block: total warps in the shape times the warp size
154 
155  static constexpr bool kIsGroupMode = kIsGroupMode_; // forwarded template switch
156 
157  // attributes from traits (Traits is remove_cvref_t<Traits_>; copied verbatim unless noted)
158  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
159  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
160  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
161  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
162  static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
163  static constexpr auto BiasEnum = Traits::BiasEnum; // type deduced from Traits::BiasEnum
164  static constexpr bool kStoreLSE = Traits::kStoreLSE;
165  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
166  static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
167  static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; // group mode forces this on regardless of Traits
168  static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
169  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; // presumably an occupancy hint (blocks per CU) — confirm against the launch code
170 };
171 
172 // extract tile size attributes to remove dependency on traits
173 template <typename OaccDataType_, ck_tile::index_t kN1_>
175 { // Tile-size constants derived only from the accumulator type and kN1_. NOTE(review): nothing checks kN1 >= MaxVectorSize or that NThreads divides the warp size; kN1 < MaxVectorSize makes NThreads 0 and kM0 an ill-formed division by zero, and other values truncate silently.
176  static constexpr index_t MaxVectorSize = 16 / sizeof(OaccDataType_); // number of OaccDataType_ elements that fit in 16 bytes
177 
178  static constexpr index_t kN1 = kN1_;
179  static constexpr index_t NThreads = kN1 / MaxVectorSize; // threads needed to cover kN1 columns at MaxVectorSize elements each (integer division)
180  static constexpr index_t kM0 = get_warp_size() / NThreads; // MThreadPerWarp
181 };
182 
183 template <typename LSEDataType_,
184  typename OaccDataType_,
185  typename ODataType_,
186  index_t HeadDimV_,
187  bool kIsGroupMode_,
188  ck_tile::index_t kN1_,
189  typename Traits_>
191  : BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>
192 { // Compile-time parameters for the split-KV combine step; kM0 / kN1 / NThreads are inherited from the tile-size base, the rest comes from the template arguments and Traits.
194 
199 
200  static_assert(std::is_same_v<LSEDataType, OaccDataType>); // LSE values and the output accumulator must use the same type
201 
202  static constexpr index_t kHeadDimV = HeadDimV_;
203  static constexpr bool kIsGroupMode = kIsGroupMode_;
204 
205  using BaseType::kM0; // re-export the inherited tile sizes
206  using BaseType::kN1;
207  using BaseType::NThreads;
208 
209  static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0); // the V head dimension must be a whole multiple of the kN1 tile
210 
211  // attributes from traits (Traits is remove_cvref_t<Traits_>; each value below is copied verbatim)
212  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
213  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
214  static constexpr bool kStoreLSE = Traits::kStoreLSE;
215  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
216  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; // presumably an occupancy hint (blocks per CU) — confirm against the launch code
217  static constexpr index_t kMaxSplits = Traits::kMaxSplits;
218  static_assert(8 <= kMaxSplits); // Traits must allow at least 8 splits
219 
220  static constexpr index_t kNumWarps = 4; // fixed; not taken from Traits or a shape type
221  static constexpr index_t kBlockSize = kNumWarps * get_warp_size(); // threads per block
222 
223  static_assert(get_warp_size() <= (kM0 * kMaxSplits) && // kM0 * kMaxSplits must span at least one warp
224  (kM0 * kMaxSplits) % get_warp_size() == 0); // ... and be a whole multiple of the warp size
225 };
226 
227 template <typename QDataType_,
228  typename KDataType_,
229  typename VDataType_,
230  index_t kM0_,
231  index_t kN0_,
232  index_t kK0_,
233  index_t kN1_,
234  bool kIsVLayoutRowMajor_,
235  RotaryEmbeddingEnum RotaryEnum_,
236  bool kIsPagedKV_,
237  typename Traits_>
239 { // Compile-time parameters with explicit tile extents (kM0/kN0/kK0/kN1), V layout, rotary-embedding mode and paged-KV switch; no block-shape type is involved.
244 
245  static constexpr index_t kBlockSize = 256; // fixed thread count, not derived from warp size or a shape type
246 
247  static constexpr index_t kM0 = kM0_;
248  static constexpr index_t kN0 = kN0_;
249  static constexpr index_t kK0 = kK0_;
250  static constexpr index_t kN1 = kN1_;
251 
252  using VLayout = std::conditional_t<kIsVLayoutRowMajor_, // ck_tile::tensor_layout::gemm::RowMajor when true, otherwise ColumnMajor
255 
256  static constexpr auto RotaryEnum = RotaryEnum_; // RotaryEmbeddingEnum value, forwarded unchanged
257  static constexpr bool kIsPagedKV = kIsPagedKV_; // taken from the template argument here, not from Traits
258 
259  // attributes from traits (Traits is remove_cvref_t<Traits_>; each value below is copied verbatim)
260  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
261  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
262  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
263  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
264  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; // presumably an occupancy hint (blocks per CU) — confirm against the launch code
265 };
266 
267 template <typename QDataType_,
268  typename KDataType_,
269  typename VDataType_,
270  typename SaccDataType_,
271  typename SMPLComputeDataType_,
272  typename LSEDataType_,
273  typename PDataType_,
274  typename OaccDataType_,
275  typename ODataType_,
276  typename BlockFmhaShape_,
277  bool kIsGroupMode_,
278  typename FmhaMask_,
279  typename Traits_>
281 { // Compile-time parameter bundle without BiasDataType or AttentionVariant; from Traits it reads only padding flags, kStoreLSE and kBlockPerCu.
294 
295  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; // GEMM0 warp count, forwarded from the block shape
296  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; // GEMM1 warp count, forwarded from the block shape
297  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); // threads per block: total warps in the shape times the warp size
298 
299  static constexpr bool kIsGroupMode = kIsGroupMode_; // forwarded template switch
300 
301  // attributes from traits (Traits is remove_cvref_t<Traits_>; each value below is copied verbatim)
302  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
303  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
304  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
305  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
306  static constexpr bool kStoreLSE = Traits::kStoreLSE;
307  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; // presumably an occupancy hint (blocks per CU) — confirm against the launch code
308 };
309 
310 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
RotaryEmbeddingEnum
Definition: block_rotary_embedding.hpp:12
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: block_fmha_pipeline_problem.hpp:239
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:240
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:261
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:262
std::conditional_t< kIsVLayoutRowMajor_, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition: block_fmha_pipeline_problem.hpp:254
static constexpr auto RotaryEnum
Definition: block_fmha_pipeline_problem.hpp:256
static constexpr index_t kK0
Definition: block_fmha_pipeline_problem.hpp:249
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:260
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:243
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:257
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:242
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:247
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:250
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:264
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:245
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:263
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:241
static constexpr index_t kN0
Definition: block_fmha_pipeline_problem.hpp:248
Definition: block_fmha_pipeline_problem.hpp:83
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:88
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:108
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:114
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:97
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:116
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:94
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:109
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:86
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:95
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:96
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:91
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:87
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:85
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:112
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:84
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:89
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:92
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:110
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:106
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:101
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:107
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:93
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:99
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:103
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:90
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:115
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:113
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:111
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:100
Definition: block_fmha_pipeline_problem.hpp:135
static constexpr bool kHasUnevenSplits
Definition: block_fmha_pipeline_problem.hpp:167
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:138
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:162
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:148
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:160
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:165
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:151
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:136
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:144
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:142
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:155
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:168
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:152
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:139
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:166
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:140
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:146
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:137
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:158
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:153
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:169
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:143
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:145
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:163
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:147
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:159
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:164
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:141
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:161
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:149
Definition: block_fmha_pipeline_problem.hpp:281
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:290
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:287
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:285
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:295
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:305
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:299
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:289
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:297
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:302
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:291
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:283
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:282
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:284
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:303
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:292
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:296
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:307
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:304
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:288
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:293
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:306
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:286
Definition: block_fmha_pipeline_problem.hpp:29
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:55
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:42
static constexpr bool kHasDropout
Definition: block_fmha_pipeline_problem.hpp:62
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:61
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:58
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:60
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:41
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:39
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:59
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:46
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:44
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:56
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:33
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:37
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:50
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:31
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:64
remove_cvref_t< RandValOutputDataType_ > RandValOutputDataType
Definition: block_fmha_pipeline_problem.hpp:36
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:40
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:48
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:57
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:38
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:32
static constexpr bool kUseTrLoad
Definition: block_fmha_pipeline_problem.hpp:51
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:34
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:54
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:35
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:43
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:47
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:63
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:30
Definition: block_fmha_pipeline_problem.hpp:192
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:197
static constexpr index_t kNumWarps
Definition: block_fmha_pipeline_problem.hpp:220
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:198
static constexpr index_t kHeadDimV
Definition: block_fmha_pipeline_problem.hpp:202
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:221
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:180
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:216
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:203
static constexpr index_t kMaxSplits
Definition: block_fmha_pipeline_problem.hpp:217
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:213
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:214
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:215
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:195
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:212
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:178
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:196
Definition: block_fmha_pipeline_problem.hpp:175
static constexpr index_t NThreads
Definition: block_fmha_pipeline_problem.hpp:179
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:180
static constexpr index_t MaxVectorSize
Definition: block_fmha_pipeline_problem.hpp:176
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:178
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17