include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File

include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File

Composable Kernel: include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File
block_fmha_pipeline_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
// Compile-time configuration bundle ("problem" description) for a block-level FMHA
// pipeline. Every member visible below is `static constexpr`: values are either
// forwarded from a template argument or read from the shape / traits types.
//
// NOTE(review): this is an extracted Doxygen listing. Source lines 25 and 27-40 (the
// `struct` declaration and its type aliases) are absent here, so the struct name is
// not visible. The cross-reference index at the end of the page lists the missing
// aliases as `remove_cvref_t<X_> X` for QDataType, KDataType, VDataType,
// SaccDataType, SMPLComputeDataType, BiasDataType, RandValOutputDataType,
// LSEDataType, PDataType, OaccDataType, ODataType, BlockFmhaShape, FmhaMask and
// Traits -- confirm against the real header before relying on this.
10 template <typename QDataType_,
11  typename KDataType_,
12  typename VDataType_,
13  typename SaccDataType_,
14  typename SMPLComputeDataType_,
15  typename BiasDataType_,
16  typename RandValOutputDataType_,
17  typename LSEDataType_,
18  typename PDataType_,
19  typename OaccDataType_,
20  typename ODataType_,
21  typename BlockFmhaShape_,
22  bool kIsGroupMode_,
23  typename FmhaMask_,
24  typename Traits_>
26 {
41 
 // Warp counts come straight from the shape type; the block size is the shape's total
 // warp count multiplied by get_warp_size().
42  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
43  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
44  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
45 
 // Forwarded unchanged from the template argument.
46  static constexpr bool kIsGroupMode = kIsGroupMode_;
47 
 // The members below re-export the identically named constants of Traits so users of
 // this type can read them without naming Traits directly.
48  // attributes from traits
49  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
50  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
51  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
52  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
 // Declared `auto`: the type of BiasEnum is whatever Traits defines it as.
53  static constexpr auto BiasEnum = Traits::BiasEnum;
54  static constexpr bool kStoreLSE = Traits::kStoreLSE;
55  static constexpr bool kHasDropout = Traits::kHasDropout;
56  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
57  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
58 };
59 
// Second compile-time pipeline-problem bundle in this header. Compared with the
// template above it has no RandValOutputDataType_ parameter and no kHasDropout flag,
// and it additionally exposes kIsPagedKV, kHasUnevenSplits and
// kMergeNumHeadGroupsSeqLenQ.
//
// NOTE(review): extracted Doxygen listing. Source lines 74 and 76-88 (the `struct`
// declaration and its type aliases) are absent here, so the struct name is not
// visible. The cross-reference index at the end of the page lists the missing aliases
// as `remove_cvref_t<X_> X` for QDataType, KDataType, VDataType, SaccDataType,
// SMPLComputeDataType, BiasDataType, LSEDataType, PDataType, OaccDataType, ODataType,
// BlockFmhaShape, FmhaMask and Traits -- confirm against the real header.
60 template <typename QDataType_,
61  typename KDataType_,
62  typename VDataType_,
63  typename SaccDataType_,
64  typename SMPLComputeDataType_,
65  typename BiasDataType_,
66  typename LSEDataType_,
67  typename PDataType_,
68  typename OaccDataType_,
69  typename ODataType_,
70  typename BlockFmhaShape_,
71  bool kIsGroupMode_,
72  typename FmhaMask_,
73  typename Traits_>
75 {
89 
 // Warp counts come straight from the shape type; the block size is the shape's total
 // warp count multiplied by get_warp_size().
90  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
91  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
92  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
93 
 // Forwarded unchanged from the template argument.
94  static constexpr bool kIsGroupMode = kIsGroupMode_;
95 
96  // attributes from traits
97  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
98  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
99  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
100  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
 // Declared `auto`: the type of BiasEnum is whatever Traits defines it as.
101  static constexpr auto BiasEnum = Traits::BiasEnum;
102  static constexpr bool kStoreLSE = Traits::kStoreLSE;
103  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
104  static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
 // Not a plain pass-through: group mode forces this flag to true regardless of the
 // value stored in Traits.
105  static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
106  static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
107  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
108 };
109 
110 // extract tile size attributes to remove dependency on traits
// Derives tile-size constants from the accumulator element type and the N1 tile
// length only, without consulting any traits type.
//
// NOTE(review): extracted Doxygen listing. Source line 112 (the `struct` declaration)
// is absent here. The name is presumably BlockFmhaSplitKVCombinePipelineTileSizes,
// since a later template in this file derives from a type of that name with the same
// <OaccDataType_, kN1_> arguments -- confirm against the real header.
111 template <typename OaccDataType_, ck_tile::index_t kN1_>
113 {
 // Number of OaccDataType_ elements that fit in 16 bytes.
114  static constexpr index_t MaxVectorSize = 16 / sizeof(OaccDataType_);
115 
116  static constexpr index_t kN1 = kN1_;
 // Integer division: any remainder of kN1 / MaxVectorSize is discarded. No check for
 // divisibility is visible in this listing. If kN1 < MaxVectorSize, NThreads becomes 0
 // and the kM0 initializer below divides by zero in a constant expression.
117  static constexpr index_t NThreads = kN1 / MaxVectorSize;
118  static constexpr index_t kM0 = get_warp_size() / NThreads; // MThreadPerWarp
119 };
120 
// Compile-time problem bundle that inherits its kM0 / kN1 tile sizes from
// BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_> and adds head
// dimension, group-mode flag, traits-derived flags and launch-size constants. The
// static_asserts below reject inconsistent parameter combinations at compile time.
//
// NOTE(review): extracted Doxygen listing. Source lines 128, 131 and 133-136 are
// absent here: the `struct` declaration, presumably a `BaseType` alias for the base
// class (it is used by the using-declarations below), and the type aliases. The
// cross-reference index at the end of the page lists LSEDataType, OaccDataType,
// ODataType and Traits as `remove_cvref_t<X_> X` at lines 133-136 -- confirm against
// the real header.
121 template <typename LSEDataType_,
122  typename OaccDataType_,
123  typename ODataType_,
124  index_t HeadDimV_,
125  bool kIsGroupMode_,
126  ck_tile::index_t kN1_,
127  typename Traits_>
129  : BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>
130 {
132 
137 
 // The LSE element type and the output-accumulator element type must be the same type.
138  static_assert(std::is_same_v<LSEDataType, OaccDataType>);
139 
140  static constexpr index_t kHeadDimV = HeadDimV_;
141  static constexpr bool kIsGroupMode = kIsGroupMode_;
142 
 // Re-export the base-class tile sizes so they can be named unqualified below.
143  using BaseType::kM0;
144  using BaseType::kN1;
145 
 // kN1 must not exceed the V head dimension and must divide it exactly.
146  static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0);
147 
148  // attributes from traits
149  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
150  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
151  static constexpr bool kStoreLSE = Traits::kStoreLSE;
152  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
153  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
154  static constexpr index_t kMaxSplits = Traits::kMaxSplits;
 // Traits must provide at least 8 as kMaxSplits.
155  static_assert(8 <= kMaxSplits);
156 
157  static constexpr index_t kNumWarps = 4; // always use 4 warps for each workgroup
158  static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
159 
 // kM0 * kMaxSplits must be at least one warp size and a whole multiple of it.
160  static_assert(get_warp_size() <= (kM0 * kMaxSplits) &&
161  (kM0 * kMaxSplits) % get_warp_size() == 0);
162 };
163 
// Compile-time problem bundle whose tile sizes (kM0, kN0, kK0, kN1), V layout
// selector, rotary-embedding mode and paged-KV flag are all passed in directly as
// template arguments. Only the padding flags and kBlockPerCu are read from Traits.
//
// NOTE(review): extracted Doxygen listing. Source lines 175, 177-180 and 190-191 are
// absent here: the `struct` declaration (name not visible), the type aliases, and the
// tail of the VLayout alias. The cross-reference index at the end of the page lists
// QDataType, KDataType, VDataType and Traits as `remove_cvref_t<X_> X` at lines
// 177-180 -- confirm against the real header.
164 template <typename QDataType_,
165  typename KDataType_,
166  typename VDataType_,
167  index_t kM0_,
168  index_t kN0_,
169  index_t kK0_,
170  index_t kN1_,
171  bool kIsVLayoutRowMajor_,
172  RotaryEmbeddingEnum RotaryEnum_,
173  bool kIsPagedKV_,
174  typename Traits_>
176 {
181 
 // Fixed block size; unlike the other templates in this file it is not computed from
 // a warp count.
182  static constexpr index_t kBlockSize = 256;
183 
 // Tile sizes forwarded unchanged from the template arguments.
184  static constexpr index_t kM0 = kM0_;
185  static constexpr index_t kN0 = kN0_;
186  static constexpr index_t kK0 = kK0_;
187  static constexpr index_t kN1 = kN1_;
188 
 // NOTE(review): the alias below is cut off in this listing. The cross-reference index
 // shows the complete form selecting ck_tile::tensor_layout::gemm::RowMajor when
 // kIsVLayoutRowMajor_ is true and ck_tile::tensor_layout::gemm::ColumnMajor otherwise.
189  using VLayout = std::conditional_t<kIsVLayoutRowMajor_,
192 
193  static constexpr auto RotaryEnum = RotaryEnum_;
194  static constexpr bool kIsPagedKV = kIsPagedKV_;
195 
196  // attributes from traits
197  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
198  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
199  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
200  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
201  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
202 };
203 
204 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE index_t get_warp_size()
Definition: arch.hpp:51
RotaryEmbeddingEnum
Definition: block_rotary_embedding.hpp:12
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_fmha_pipeline_problem.hpp:176
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:177
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:198
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:199
std::conditional_t< kIsVLayoutRowMajor_, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition: block_fmha_pipeline_problem.hpp:191
static constexpr auto RotaryEnum
Definition: block_fmha_pipeline_problem.hpp:193
static constexpr index_t kK0
Definition: block_fmha_pipeline_problem.hpp:186
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:197
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:180
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:194
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:179
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:184
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:187
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:201
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:182
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:200
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:178
static constexpr index_t kN0
Definition: block_fmha_pipeline_problem.hpp:185
Definition: block_fmha_pipeline_problem.hpp:75
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:107
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:87
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:79
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:85
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:97
static constexpr bool kHasUnevenSplits
Definition: block_fmha_pipeline_problem.hpp:105
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:76
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:92
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:102
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:83
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:90
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:103
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:81
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:98
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:80
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:88
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:94
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:99
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:101
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:77
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:104
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:84
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:86
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:78
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:91
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:82
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:106
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:100
Definition: block_fmha_pipeline_problem.hpp:26
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:54
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:40
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:37
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:28
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:34
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:32
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:39
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:29
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:51
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:56
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:52
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:38
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:53
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:36
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:31
remove_cvref_t< RandValOutputDataType_ > RandValOutputDataType
Definition: block_fmha_pipeline_problem.hpp:33
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:44
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:57
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:27
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:46
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:43
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:35
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:30
static constexpr bool kHasDropout
Definition: block_fmha_pipeline_problem.hpp:55
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:49
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:50
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:42
Definition: block_fmha_pipeline_problem.hpp:130
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:135
static constexpr index_t kNumWarps
Definition: block_fmha_pipeline_problem.hpp:157
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:136
static constexpr index_t kHeadDimV
Definition: block_fmha_pipeline_problem.hpp:140
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:158
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:118
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:153
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:141
static constexpr index_t kMaxSplits
Definition: block_fmha_pipeline_problem.hpp:154
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:150
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:151
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:152
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:133
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:149
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:116
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:134
Definition: block_fmha_pipeline_problem.hpp:113
static constexpr index_t NThreads
Definition: block_fmha_pipeline_problem.hpp:117
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:118
static constexpr index_t MaxVectorSize
Definition: block_fmha_pipeline_problem.hpp:114
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:116
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17