/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File
block_fmha_pipeline_problem.hpp
Go to the documentation of this file.
Definition: cluster_descriptor.hpp:13
@ VECTORIZED_LAYOUT
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_fmha_pipeline_problem.hpp:105
static constexpr index_t kVectorSize
Definition: block_fmha_pipeline_problem.hpp:121
static constexpr auto kKVLookupTable
Definition: block_fmha_pipeline_problem.hpp:123
static constexpr auto kKVMemoryLayout
Definition: block_fmha_pipeline_problem.hpp:122
static constexpr index_t kLog2PageSize
Definition: block_fmha_pipeline_problem.hpp:110
static constexpr bool kIsVectorizedLayout
Definition: block_fmha_pipeline_problem.hpp:124
static constexpr index_t kPageBlockSize
Definition: block_fmha_pipeline_problem.hpp:106
Definition: block_fmha_pipeline_problem.hpp:308
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:309
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:330
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:331
std::conditional_t< kIsVLayoutRowMajor_, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition: block_fmha_pipeline_problem.hpp:323
static constexpr auto RotaryEnum
Definition: block_fmha_pipeline_problem.hpp:325
static constexpr index_t kK0
Definition: block_fmha_pipeline_problem.hpp:318
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:329
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:312
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:326
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:311
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:316
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:319
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:333
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:314
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:332
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:310
static constexpr index_t kN0
Definition: block_fmha_pipeline_problem.hpp:317
Definition: block_fmha_pipeline_problem.hpp:150
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:155
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:175
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:181
static constexpr bool kHasSink
Definition: block_fmha_pipeline_problem.hpp:184
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:164
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:183
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:161
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:176
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:153
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:162
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:163
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:158
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:154
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:152
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:179
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:151
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:156
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:159
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:177
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:173
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:168
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:174
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:160
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:166
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:170
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:157
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:182
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:180
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:178
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:167
Definition: block_fmha_pipeline_problem.hpp:203
static constexpr bool kHasUnevenSplits
Definition: block_fmha_pipeline_problem.hpp:235
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:206
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:230
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:216
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:228
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:233
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:219
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:204
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:212
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:210
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:223
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:236
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:220
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:207
static constexpr bool kHasSink
Definition: block_fmha_pipeline_problem.hpp:238
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:234
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:208
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:214
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:205
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:226
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:221
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:237
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:211
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:213
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:231
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:215
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:227
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:232
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:209
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:229
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:217
Definition: block_fmha_pipeline_problem.hpp:30
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:56
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:43
static constexpr bool kHasDropout
Definition: block_fmha_pipeline_problem.hpp:63
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:62
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:59
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:61
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:42
static constexpr auto QScaleEnum
Definition: block_fmha_pipeline_problem.hpp:64
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:40
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:60
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:47
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:45
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:57
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:34
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:38
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:51
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:32
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:65
remove_cvref_t< RandValOutputDataType_ > RandValOutputDataType
Definition: block_fmha_pipeline_problem.hpp:37
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:41
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:49
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:58
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:39
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:33
static constexpr bool kUseTrLoad
Definition: block_fmha_pipeline_problem.hpp:52
static constexpr bool kHasSink
Definition: block_fmha_pipeline_problem.hpp:66
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:35
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:55
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:36
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:44
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:48
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:31
Definition: block_fmha_pipeline_problem.hpp:261
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:266
static constexpr index_t kNumWarps
Definition: block_fmha_pipeline_problem.hpp:289
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:267
static constexpr index_t kHeadDimV
Definition: block_fmha_pipeline_problem.hpp:271
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:290
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:249
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:285
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:272
static constexpr index_t kMaxSplits
Definition: block_fmha_pipeline_problem.hpp:286
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:282
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:283
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:284
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:264
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:281
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:247
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:265
Definition: block_fmha_pipeline_problem.hpp:244
static constexpr index_t NThreads
Definition: block_fmha_pipeline_problem.hpp:248
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:249
static constexpr index_t MaxVectorSize
Definition: block_fmha_pipeline_problem.hpp:245
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:247
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17