/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File
block_fmha_pipeline_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
// FMHA pipeline problem descriptor: a compile-time bundle of element types, the
// block tile shape, mode flags and attributes re-exported from Traits_.
// Every member visible here is a static constexpr value.
// NOTE(review): in this rendered listing the struct-name line (source line 28)
// and the type aliases for the template parameters (source lines 30-44) are elided.
11 template <typename QDataType_,
12  typename KDataType_,
13  typename VDataType_,
14  typename SaccDataType_,
15  typename SMPLComputeDataType_,
16  typename BiasDataType_,
17  typename RandValOutputDataType_,
18  typename LSEDataType_,
19  typename PDataType_,
20  typename OaccDataType_,
21  typename ODataType_,
22  typename BlockFmhaShape_,
23  bool kIsGroupMode_,
24  typename AttentionVariant_,
25  typename FmhaMask_,
26  bool kUseTrLoad_,
27  typename Traits_>
29 {
45 
// Warp counts for the two GEMM stages, read directly from the block tile shape.
46  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
47  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
// Threads per workgroup: total warps of the tile shape times the warp size.
48  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
49 
50  static constexpr bool kIsGroupMode = kIsGroupMode_;
// Forwarded from the kUseTrLoad_ template argument; its effect is decided by the
// consuming pipeline (not visible here).
51  static constexpr bool kUseTrLoad = kUseTrLoad_;
52 
53  // attributes from traits
// The values below are copied unchanged from Traits so pipelines can query them
// on the problem type itself.
54  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
55  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
56  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
57  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
58  static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
59  static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
60  static constexpr auto BiasEnum = Traits::BiasEnum;
61  static constexpr bool kStoreLSE = Traits::kStoreLSE;
62  static constexpr bool kHasDropout = Traits::kHasDropout;
63  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
64  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
65 };
66 
// Problem descriptor variant that additionally exposes Traits::kIsPagedKV.
// Compared with the descriptor above, it takes no RandValOutputDataType_ or
// kUseTrLoad_ template argument and does not expose kHasDropout.
// NOTE(review): the struct-name line (source line 82) and the type aliases
// (source lines 84-97) are elided in this rendered listing.
67 template <typename QDataType_,
68  typename KDataType_,
69  typename VDataType_,
70  typename SaccDataType_,
71  typename SMPLComputeDataType_,
72  typename BiasDataType_,
73  typename LSEDataType_,
74  typename PDataType_,
75  typename OaccDataType_,
76  typename ODataType_,
77  typename BlockFmhaShape_,
78  bool kIsGroupMode_,
79  typename AttentionVariant_,
80  typename FmhaMask_,
81  typename Traits_>
83 {
98 
// Warp counts for the two GEMM stages, read from the block tile shape.
99  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
100  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
// Threads per workgroup: total warps of the tile shape times the warp size.
101  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
102 
103  static constexpr bool kIsGroupMode = kIsGroupMode_;
104 
105  // attributes from traits
106  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
107  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
108  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
109  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
110  static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
111  static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
112  static constexpr auto BiasEnum = Traits::BiasEnum;
113  static constexpr bool kStoreLSE = Traits::kStoreLSE;
114  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
// Paged K/V cache flag, taken from Traits (not a template argument here).
115  static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
116  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
117 };
118 
// Problem descriptor variant for the split-KV pipeline. It exposes the same
// padding, bias, LSE and paged-KV attributes as the paged variant, plus
// kHasUnevenSplits and kMergeNumHeadGroupsSeqLenQ. It does not expose kSkipMinSeqlenQ.
// NOTE(review): the struct-name line (source line 134) and the type aliases
// (source lines 136-149) are elided in this rendered listing.
119 template <typename QDataType_,
120  typename KDataType_,
121  typename VDataType_,
122  typename SaccDataType_,
123  typename SMPLComputeDataType_,
124  typename BiasDataType_,
125  typename LSEDataType_,
126  typename PDataType_,
127  typename OaccDataType_,
128  typename ODataType_,
129  typename BlockFmhaShape_,
130  bool kIsGroupMode_,
131  typename AttentionVariant_,
132  typename FmhaMask_,
133  typename Traits_>
135 {
150 
// Warp counts for the two GEMM stages, read from the block tile shape.
151  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
152  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
// Threads per workgroup: total warps of the tile shape times the warp size.
153  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
154 
155  static constexpr bool kIsGroupMode = kIsGroupMode_;
156 
157  // attributes from traits
158  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
159  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
160  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
161  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
162  static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
163  static constexpr auto BiasEnum = Traits::BiasEnum;
164  static constexpr bool kStoreLSE = Traits::kStoreLSE;
165  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
166  static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
// Group mode always turns uneven-split handling on, whatever Traits says;
// in batch mode the trait value is used as-is.
167  static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
168  static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
169  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
170 };
171 
172 // extract tile size attributes to remove dependency on traits
// Tile-size attributes for the split-KV combine step. They depend only on the
// accumulator type and kN1_, so the combine problem below can inherit them
// without depending on Traits.
// NOTE(review): the struct-name line (source line 174) is elided in this listing;
// the name BlockFmhaSplitKVCombinePipelineTileSizes comes from the base-class
// reference in the combine problem.
173 template <typename OaccDataType_, ck_tile::index_t kN1_>
175 {
// Number of OaccDataType_ elements that fit in 16 bytes (e.g. 4 for a 4-byte type).
176  static constexpr index_t MaxVectorSize = 16 / sizeof(OaccDataType_);
177 
178  static constexpr index_t kN1 = kN1_;
// Threads needed to cover kN1 elements when each handles MaxVectorSize of them.
// Integer division: any remainder of kN1 / MaxVectorSize is dropped.
179  static constexpr index_t NThreads = kN1 / MaxVectorSize;
// Rows handled by one warp: the warp size divided among NThreads threads per row.
// NOTE(review): nothing here asserts kN1 % MaxVectorSize == 0 or
// get_warp_size() % NThreads == 0. If kN1 < MaxVectorSize, NThreads is 0 and
// this division is a compile-time division by zero.
180  static constexpr index_t kM0 = get_warp_size() / NThreads; // MThreadPerWarp
181 };
182 
// Problem descriptor for the split-KV combine pipeline. It inherits kM0/kN1 from
// BlockFmhaSplitKVCombinePipelineTileSizes and enforces its layout preconditions
// with static_asserts.
// NOTE(review): the struct-name line (source line 190) and the alias lines
// (source lines 193-198, including BaseType, LSEDataType, OaccDataType,
// ODataType, Traits) are elided in this rendered listing.
183 template <typename LSEDataType_,
184  typename OaccDataType_,
185  typename ODataType_,
186  index_t HeadDimV_,
187  bool kIsGroupMode_,
188  ck_tile::index_t kN1_,
189  typename Traits_>
191  : BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>
192 {
194 
199 
// LSE values and output accumulators must use the same type.
200  static_assert(std::is_same_v<LSEDataType, OaccDataType>);
201 
202  static constexpr index_t kHeadDimV = HeadDimV_;
203  static constexpr bool kIsGroupMode = kIsGroupMode_;
204 
// Re-export the tile sizes from the base (BaseType alias is on an elided line;
// presumably it names the TileSizes base above).
205  using BaseType::kM0;
206  using BaseType::kN1;
207 
// kN1 must be no larger than kHeadDimV and divide it evenly.
208  static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0);
209 
210  // attributes from traits
211  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
212  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
213  static constexpr bool kStoreLSE = Traits::kStoreLSE;
214  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
215  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
216  static constexpr index_t kMaxSplits = Traits::kMaxSplits;
// Lower bound on the configured maximum split count.
217  static_assert(8 <= kMaxSplits);
218 
219  static constexpr index_t kNumWarps = 4; // always use 4 warps for each workgroup
220  static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
221 
// kM0 * kMaxSplits must cover at least one full warp and be a whole number of warps.
222  static_assert(get_warp_size() <= (kM0 * kMaxSplits) &&
223  (kM0 * kMaxSplits) % get_warp_size() == 0);
224 };
225 
// Problem descriptor that takes its tile extents (kM0_, kN0_, kK0_, kN1_), the V
// layout choice, a rotary-embedding mode and a paged-KV flag directly as template
// arguments, instead of reading them from a BlockFmhaShape type.
// NOTE(review): the struct-name line (source line 237) and the type aliases
// (source lines 239-242) are elided in this rendered listing.
226 template <typename QDataType_,
227  typename KDataType_,
228  typename VDataType_,
229  index_t kM0_,
230  index_t kN0_,
231  index_t kK0_,
232  index_t kN1_,
233  bool kIsVLayoutRowMajor_,
234  RotaryEmbeddingEnum RotaryEnum_,
235  bool kIsPagedKV_,
236  typename Traits_>
238 {
243 
// Fixed workgroup size; unlike the other descriptors, not derived from a warp count.
244  static constexpr index_t kBlockSize = 256;
245 
246  static constexpr index_t kM0 = kM0_;
247  static constexpr index_t kN0 = kN0_;
248  static constexpr index_t kK0 = kK0_;
249  static constexpr index_t kN1 = kN1_;
250 
// V tensor layout: tensor_layout::gemm::RowMajor when kIsVLayoutRowMajor_ is true,
// otherwise ColumnMajor (the continuation, source lines 252-253, is elided here).
251  using VLayout = std::conditional_t<kIsVLayoutRowMajor_,
254 
255  static constexpr auto RotaryEnum = RotaryEnum_;
256  static constexpr bool kIsPagedKV = kIsPagedKV_;
257 
258  // attributes from traits
259  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
260  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
261  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
262  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
263  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
264 };
265 
// Reduced problem descriptor. It has no bias type and no attention-variant
// parameter, and from Traits it exposes only the four padding flags, kStoreLSE
// and kBlockPerCu.
// NOTE(review): the struct-name line (source line 279) and the type aliases
// (source lines 281-292) are elided in this rendered listing.
266 template <typename QDataType_,
267  typename KDataType_,
268  typename VDataType_,
269  typename SaccDataType_,
270  typename SMPLComputeDataType_,
271  typename LSEDataType_,
272  typename PDataType_,
273  typename OaccDataType_,
274  typename ODataType_,
275  typename BlockFmhaShape_,
276  bool kIsGroupMode_,
277  typename FmhaMask_,
278  typename Traits_>
280 {
293 
// Warp counts for the two GEMM stages, read from the block tile shape.
294  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
295  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
// Threads per workgroup: total warps of the tile shape times the warp size.
296  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
297 
298  static constexpr bool kIsGroupMode = kIsGroupMode_;
299 
300  // attributes from traits
301  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
302  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
303  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
304  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
305  static constexpr bool kStoreLSE = Traits::kStoreLSE;
306  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
307 };
308 
309 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
RotaryEmbeddingEnum
Definition: block_rotary_embedding.hpp:12
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_fmha_pipeline_problem.hpp:238
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:239
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:260
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:261
std::conditional_t< kIsVLayoutRowMajor_, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition: block_fmha_pipeline_problem.hpp:253
static constexpr auto RotaryEnum
Definition: block_fmha_pipeline_problem.hpp:255
static constexpr index_t kK0
Definition: block_fmha_pipeline_problem.hpp:248
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:259
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:242
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:256
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:241
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:246
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:249
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:263
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:244
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:262
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:240
static constexpr index_t kN0
Definition: block_fmha_pipeline_problem.hpp:247
Definition: block_fmha_pipeline_problem.hpp:83
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:88
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:108
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:114
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:97
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:116
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:94
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:109
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:86
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:95
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:96
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:91
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:87
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:85
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:112
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:84
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:89
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:92
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:110
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:106
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:101
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:107
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:93
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:99
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:103
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:90
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:115
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:113
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:111
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:100
Definition: block_fmha_pipeline_problem.hpp:135
static constexpr bool kHasUnevenSplits
Definition: block_fmha_pipeline_problem.hpp:167
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:138
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:162
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:148
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:160
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:165
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:151
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:136
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:144
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:142
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:155
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:168
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:152
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:139
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:166
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:140
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:146
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:137
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:158
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:153
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:169
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:143
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:145
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:163
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:147
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:159
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:164
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:141
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:161
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:149
Definition: block_fmha_pipeline_problem.hpp:280
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:289
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:286
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:284
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:294
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:304
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:298
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:288
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:296
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:301
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:290
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:282
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:281
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:283
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:302
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:291
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:295
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:306
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:303
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:287
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:292
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:305
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:285
Definition: block_fmha_pipeline_problem.hpp:29
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:55
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:42
static constexpr bool kHasDropout
Definition: block_fmha_pipeline_problem.hpp:62
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:61
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:58
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:60
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:41
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:39
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:59
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:46
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:44
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:56
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:33
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:37
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:50
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:31
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:64
remove_cvref_t< RandValOutputDataType_ > RandValOutputDataType
Definition: block_fmha_pipeline_problem.hpp:36
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:40
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:48
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:57
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:38
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:32
static constexpr bool kUseTrLoad
Definition: block_fmha_pipeline_problem.hpp:51
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:34
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:54
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:35
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:43
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:47
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:63
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:30
Definition: block_fmha_pipeline_problem.hpp:192
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:197
static constexpr index_t kNumWarps
Definition: block_fmha_pipeline_problem.hpp:219
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:198
static constexpr index_t kHeadDimV
Definition: block_fmha_pipeline_problem.hpp:202
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:220
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:180
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:215
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:203
static constexpr index_t kMaxSplits
Definition: block_fmha_pipeline_problem.hpp:216
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:212
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:213
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:214
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:195
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:211
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:178
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:196
Definition: block_fmha_pipeline_problem.hpp:175
static constexpr index_t NThreads
Definition: block_fmha_pipeline_problem.hpp:179
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:180
static constexpr index_t MaxVectorSize
Definition: block_fmha_pipeline_problem.hpp:176
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:178
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17