/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File#
block_fmha_pipeline_problem.hpp
Go to the documentation of this file.
Definition: cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_fmha_pipeline_problem.hpp:238
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:239
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:260
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:261
std::conditional_t< kIsVLayoutRowMajor_, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition: block_fmha_pipeline_problem.hpp:253
static constexpr auto RotaryEnum
Definition: block_fmha_pipeline_problem.hpp:255
static constexpr index_t kK0
Definition: block_fmha_pipeline_problem.hpp:248
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:259
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:242
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:256
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:241
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:246
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:249
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:263
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:244
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:262
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:240
static constexpr index_t kN0
Definition: block_fmha_pipeline_problem.hpp:247
Definition: block_fmha_pipeline_problem.hpp:83
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:88
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:108
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:114
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:97
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:116
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:94
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:109
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:86
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:95
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:96
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:91
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:87
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:85
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:112
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:84
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:89
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:92
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:110
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:106
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:101
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:107
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:93
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:99
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:103
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:90
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:115
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:113
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:111
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:100
Definition: block_fmha_pipeline_problem.hpp:135
static constexpr bool kHasUnevenSplits
Definition: block_fmha_pipeline_problem.hpp:167
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:138
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:162
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:148
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:160
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:165
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:151
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:136
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:144
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:142
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:155
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:168
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:152
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:139
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:166
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:140
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:146
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:137
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:158
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:153
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:169
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:143
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:145
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:163
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:147
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:159
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:164
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:141
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:161
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:149
Definition: block_fmha_pipeline_problem.hpp:280
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:289
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:286
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:284
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:294
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:304
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:298
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:288
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:296
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:301
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:290
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:282
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:281
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:283
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:302
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:291
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:295
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:306
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:303
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:287
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:292
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:305
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:285
Definition: block_fmha_pipeline_problem.hpp:29
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:55
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:42
static constexpr bool kHasDropout
Definition: block_fmha_pipeline_problem.hpp:62
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:61
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:58
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:60
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:41
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:39
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:59
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:46
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:44
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:56
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:33
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:37
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:50
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:31
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:64
remove_cvref_t< RandValOutputDataType_ > RandValOutputDataType
Definition: block_fmha_pipeline_problem.hpp:36
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:40
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:48
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:57
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:38
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:32
static constexpr bool kUseTrLoad
Definition: block_fmha_pipeline_problem.hpp:51
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:34
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:54
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:35
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:43
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:47
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:63
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:30
Definition: block_fmha_pipeline_problem.hpp:192
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:197
static constexpr index_t kNumWarps
Definition: block_fmha_pipeline_problem.hpp:219
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:198
static constexpr index_t kHeadDimV
Definition: block_fmha_pipeline_problem.hpp:202
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:220
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:180
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:215
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:203
static constexpr index_t kMaxSplits
Definition: block_fmha_pipeline_problem.hpp:216
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:212
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:213
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:214
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:195
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:211
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:178
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:196
Definition: block_fmha_pipeline_problem.hpp:175
static constexpr index_t NThreads
Definition: block_fmha_pipeline_problem.hpp:179
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:180
static constexpr index_t MaxVectorSize
Definition: block_fmha_pipeline_problem.hpp:176
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:178
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17