/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Source File
tile_fmha_traits.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
11 
12 namespace ck_tile {
13 
14 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
15  bool kPadSeqLenK_ /* padding for seqlen_k */,
16  bool kPadHeadDimQ_ /* padding for hdim_q */,
17  bool kPadHeadDimV_ /* padding for hdim_v */,
18  bool kHasLogitsSoftCap_,
19  BlockAttentionBiasEnum BiasEnum_,
20  bool kHasBiasGrad_,
21  bool kStoreLSE_,
22  bool kHasDropout_,
23  BlockAttentionQuantScaleEnum QScaleEnum_,
24  index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
25  bool kSkipMinSeqlenQ_ = false, /* skip min seqlen_q during chunked prefill */
26  bool kHasSink_ = false>
28 { /* every template parameter is re-exported below as a static constexpr member */
29  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
30  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
31  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
32  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
33  static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
34  static constexpr auto BiasEnum = BiasEnum_;
35  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
36  static constexpr bool kStoreLSE = kStoreLSE_;
37  static constexpr bool kHasDropout = kHasDropout_;
38  static constexpr auto QScaleEnum = QScaleEnum_;
39  static constexpr index_t kBlockPerCu = kBlockPerCu_;
40  static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
41  static constexpr bool kHasSink = kHasSink_;
42 };
43 
44 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
45  bool kPadSeqLenK_ /* padding for seqlen_k */,
46  bool kPadHeadDimQ_ /* padding for hdim_q */,
47  bool kPadHeadDimV_ /* padding for hdim_v */,
48  bool kHasLogitsSoftCap_,
49  BlockAttentionBiasEnum BiasEnum_,
50  bool kHasBiasGrad_,
51  bool kStoreLSE_,
52  bool kHasDropout_,
53  BlockAttentionQuantScaleEnum QScaleEnum_,
54  index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
55  bool kSkipMinSeqlenQ_ = false, /* skip min seqlen_q during chunked prefill */
56  index_t kPageBlockSize_ = 1, /* KV-cache page size; must be a positive power of 2 (static_assert below) */
57  BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
59  BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
61 struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
62  kPadSeqLenK_,
63  kPadHeadDimQ_,
64  kPadHeadDimV_,
65  kHasLogitsSoftCap_,
66  BiasEnum_,
67  kHasBiasGrad_,
68  kStoreLSE_,
69  kHasDropout_,
70  QScaleEnum_,
71  kBlockPerCu_,
72  kSkipMinSeqlenQ_,
73  false /* kHasSink_ of the base is fixed to false here */>
74 { /* adds the paged KV-cache parameters on top of TileFmhaTraits */
75  static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
76  static constexpr auto kKVLookupTable = kKVLookupTable_;
77  static constexpr index_t kPageBlockSize = kPageBlockSize_;
80  "Batch prefill only supports vectorized or linear KV cache layout.");
81  static_assert(kPageBlockSize > 0 && ((kPageBlockSize & (kPageBlockSize - 1)) == 0),
82  "kPageBlockSize should be a power of 2 to support efficient page-based KV cache "
83  "addressing.");
84 };
85 
86 template <index_t kPadHeadDimQ_ /* padding for hdim_q; an index_t restricted to 0, 1 or 8 */,
87  index_t kPadHeadDimV_ /* padding for hdim_v; an index_t restricted to 0, 1 or 8 */,
88  BlockAttentionBiasEnum BiasEnum_,
89  bool kHasBiasGrad_,
90  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
92 {
93  static constexpr index_t kPadHeadDimQ = kPadHeadDimQ_;
94  static constexpr index_t kPadHeadDimV = kPadHeadDimV_;
95  static constexpr auto BiasEnum = BiasEnum_;
96  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
97  static constexpr index_t kBlockPerCu = kBlockPerCu_;
98 
99  static_assert(kPadHeadDimQ == 0 || kPadHeadDimQ == 8 || kPadHeadDimQ == 1);
100  static_assert(kPadHeadDimV == 0 || kPadHeadDimV == 8 || kPadHeadDimV == 1);
101 };
102 
103 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
104  bool kPadSeqLenK_ /* padding for seqlen_k */,
105  bool kPadHeadDimQ_ /* padding for hdim_q */,
106  bool kPadHeadDimV_ /* padding for hdim_v */,
107  bool kHasLogitsSoftCap_,
108  BlockAttentionBiasEnum BiasEnum_,
109  bool kHasBiasGrad_,
110  bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
111  bool kIsPagedKV_,
112  bool kDoFp8StaticQuant_,
113  index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
114  bool kSkipMinSeqlenQ_ = false, /* skip min seqlen_q during chunked prefill */
115  bool kHasSink_ = false>
117 { /* every template parameter is re-exported below as a static constexpr member */
118  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
119  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
120  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
121  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
122  static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
123  static constexpr auto BiasEnum = BiasEnum_;
124  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
125  static constexpr bool kStoreLSE = kStoreLSE_;
126  static constexpr bool kIsPagedKV = kIsPagedKV_;
127  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
128  static constexpr index_t kBlockPerCu = kBlockPerCu_;
129  static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
130  static constexpr bool kHasSink = kHasSink_;
131 };
132 
133 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
134  bool kPadSeqLenK_ /* padding for seqlen_k */,
135  bool kPadHeadDimQ_ /* padding for hdim_q */,
136  bool kPadHeadDimV_ /* padding for hdim_v */,
137  bool kHasLogitsSoftCap_,
138  BlockAttentionBiasEnum BiasEnum_,
139  bool kHasBiasGrad_,
140  bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
141  bool kDoFp8StaticQuant_,
142  bool kIsPagedKV_,
143  bool kHasUnevenSplits_,
144  bool kMergeNumHeadGroupsSeqLenQ_ = false,
145  index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
146  bool kHasSink_ = false>
148 { /* every template parameter is re-exported below as a static constexpr member */
149  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
150  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
151  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
152  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
153  static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
154  static constexpr auto BiasEnum = BiasEnum_;
155  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
156  static constexpr bool kStoreLSE = kStoreLSE_;
157  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
158  static constexpr bool kIsPagedKV = kIsPagedKV_;
159  // determine if some split (length) is not divisible by tile size
160  static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
161  static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
162  static constexpr index_t kBlockPerCu = kBlockPerCu_;
163  static constexpr bool kHasSink = kHasSink_;
164 };
165 
166 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
167  bool kPadHeadDimV_ /* padding for hdim_v */,
168  bool kStoreLSE_,
169  bool kDoFp8StaticQuant_,
170  index_t kLogMaxSplits_, /* log2 of kMaxSplits below */
171  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
173 {
174  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
175  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
176  static constexpr bool kStoreLSE = kStoreLSE_;
177  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
178 
179  static constexpr index_t kMaxSplits = (1 << kLogMaxSplits_); /* 2^kLogMaxSplits_ */
180  static_assert(kMaxSplits <= get_warp_size() || kMaxSplits % get_warp_size() == 0); /* at most one warp, or a whole multiple of the warp size */
181  static constexpr index_t kBlockPerCu = kBlockPerCu_;
182 };
183 
184 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
185  bool kPadSeqLenK_ /* padding for seqlen_k */,
186  bool kPadHeadDimQ_ /* padding for hdim_q */,
187  bool kPadHeadDimV_ /* padding for hdim_v */,
188  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
190 {
191  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
192  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
193  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
194  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
195  static constexpr index_t kBlockPerCu = kBlockPerCu_;
196 };
197 
198 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
199  bool kPadHeadDimV_ /* padding for hdim_v */,
200  index_t kBlockPerCu_ = 2 /* hint to occupancy */>
202 {
203  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
204  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
205  static constexpr index_t kBlockPerCu = kBlockPerCu_;
206 };
207 
208 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
209  bool kPadHeadDimQ_ /* padding for hdim_q */,
210  index_t kBlockPerCu_ = 2 /* hint to occupancy */>
212 {
213  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
214  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
215  static constexpr index_t kBlockPerCu = kBlockPerCu_;
216 };
217 
218 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
BlockAttentionBiasEnum
Definition: block_attention_bias_enum.hpp:12
BlockAttentionKVCacheMemoryLayoutEnum
Definition: block_attention_kvcache_layout_enum.hpp:18
int32_t index_t
Definition: integer.hpp:9
BlockAttentionKVCacheLookupTableEnum
Definition: block_attention_kvcache_layout_enum.hpp:27
BlockAttentionQuantScaleEnum
Definition: block_attention_quant_scale_enum.hpp:12
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: tile_fmha_traits.hpp:74
static constexpr auto kKVMemoryLayout
Definition: tile_fmha_traits.hpp:75
static constexpr index_t kPageBlockSize
Definition: tile_fmha_traits.hpp:77
static constexpr auto kKVLookupTable
Definition: tile_fmha_traits.hpp:76
Definition: tile_fmha_traits.hpp:212
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:215
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:214
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:213
Definition: tile_fmha_traits.hpp:202
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:205
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:203
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:204
Definition: tile_fmha_traits.hpp:92
static constexpr index_t kPadHeadDimQ
Definition: tile_fmha_traits.hpp:93
static constexpr index_t kPadHeadDimV
Definition: tile_fmha_traits.hpp:94
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:96
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:95
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:97
Definition: tile_fmha_traits.hpp:190
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:193
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:192
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:195
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:191
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:194
Definition: tile_fmha_traits.hpp:117
static constexpr bool kHasSink
Definition: tile_fmha_traits.hpp:130
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:123
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:120
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:124
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:118
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:125
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:127
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:121
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:128
static constexpr bool kSkipMinSeqlenQ
Definition: tile_fmha_traits.hpp:129
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:119
static constexpr bool kIsPagedKV
Definition: tile_fmha_traits.hpp:126
static constexpr bool kHasLogitsSoftCap
Definition: tile_fmha_traits.hpp:122
Definition: tile_fmha_traits.hpp:173
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:174
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:175
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:177
static constexpr index_t kMaxSplits
Definition: tile_fmha_traits.hpp:179
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:176
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:181
Definition: tile_fmha_traits.hpp:148
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:162
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:151
static constexpr bool kIsPagedKV
Definition: tile_fmha_traits.hpp:158
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:154
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:157
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:149
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: tile_fmha_traits.hpp:161
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:155
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:150
static constexpr bool kHasSink
Definition: tile_fmha_traits.hpp:163
static constexpr bool kHasUnevenSplits
Definition: tile_fmha_traits.hpp:160
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:152
static constexpr bool kHasLogitsSoftCap
Definition: tile_fmha_traits.hpp:153
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:156
Definition: tile_fmha_traits.hpp:28
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:29
static constexpr bool kSkipMinSeqlenQ
Definition: tile_fmha_traits.hpp:40
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:32
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:34
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:35
static constexpr bool kHasSink
Definition: tile_fmha_traits.hpp:41
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:31
static constexpr bool kHasDropout
Definition: tile_fmha_traits.hpp:37
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:39
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:36
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:30
static constexpr bool kHasLogitsSoftCap
Definition: tile_fmha_traits.hpp:33
static constexpr auto QScaleEnum
Definition: tile_fmha_traits.hpp:38