include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File#
block_fmha_pipeline_problem.hpp
Go to the documentation of this file.
Definition: cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_fmha_pipeline_problem.hpp:176
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:177
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:198
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:199
std::conditional_t< kIsVLayoutRowMajor_, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition: block_fmha_pipeline_problem.hpp:191
static constexpr auto RotaryEnum
Definition: block_fmha_pipeline_problem.hpp:193
static constexpr index_t kK0
Definition: block_fmha_pipeline_problem.hpp:186
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:197
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:180
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:194
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:179
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:184
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:187
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:201
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:182
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:200
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:178
static constexpr index_t kN0
Definition: block_fmha_pipeline_problem.hpp:185
Definition: block_fmha_pipeline_problem.hpp:75
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:107
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:87
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:79
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:85
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:97
static constexpr bool kHasUnevenSplits
Definition: block_fmha_pipeline_problem.hpp:105
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:76
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:92
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:102
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:83
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:90
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:103
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:81
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:98
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:80
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:88
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:94
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:99
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:101
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:77
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:104
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:84
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:86
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:78
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:91
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:82
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:106
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:100
Definition: block_fmha_pipeline_problem.hpp:26
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:54
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:40
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:37
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:28
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:34
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:32
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:39
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:29
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:51
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:56
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:52
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:38
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:53
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:36
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:31
remove_cvref_t< RandValOutputDataType_ > RandValOutputDataType
Definition: block_fmha_pipeline_problem.hpp:33
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:44
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:57
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:27
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:46
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:43
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:35
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:30
static constexpr bool kHasDropout
Definition: block_fmha_pipeline_problem.hpp:55
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:49
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:50
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:42
Definition: block_fmha_pipeline_problem.hpp:130
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:135
static constexpr index_t kNumWarps
Definition: block_fmha_pipeline_problem.hpp:157
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:136
static constexpr index_t kHeadDimV
Definition: block_fmha_pipeline_problem.hpp:140
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:158
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:118
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:153
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:141
static constexpr index_t kMaxSplits
Definition: block_fmha_pipeline_problem.hpp:154
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:150
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:151
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:152
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:133
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:149
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:116
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:134
Definition: block_fmha_pipeline_problem.hpp:113
static constexpr index_t NThreads
Definition: block_fmha_pipeline_problem.hpp:117
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:118
static constexpr index_t MaxVectorSize
Definition: block_fmha_pipeline_problem.hpp:114
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:116
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17