/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Source File
tile_fmha_traits.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
// Compile-time traits bundle: every template argument below is re-exported as a
// static constexpr member of the same name (without the trailing underscore).
// NOTE(review): listing line 24 (between the parameter list and the opening brace) is
// missing from this extract, so the struct's name is not visible here.
12 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
13  bool kPadSeqLenK_ /* padding for seqlen_k */,
14  bool kPadHeadDimQ_ /* padding for hdim_q */,
15  bool kPadHeadDimV_ /* padding for hdim_v */,
16  bool kHasLogitsSoftCap_,
17  BlockAttentionBiasEnum BiasEnum_,
18  bool kHasBiasGrad_,
19  bool kStoreLSE_,
20  bool kHasDropout_,
21  bool kDoFp8StaticQuant_,
22  index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
23  bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
25 {
26  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
27  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
28  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
29  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
30  static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
31  static constexpr auto BiasEnum = BiasEnum_;
32  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
33  static constexpr bool kStoreLSE = kStoreLSE_;
34  static constexpr bool kHasDropout = kHasDropout_;
35  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
36  static constexpr index_t kBlockPerCu = kBlockPerCu_;
37  static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
38 };
39 
// Compile-time traits bundle whose head-dim padding parameters are index_t values
// rather than bool flags; each template argument is re-exported as a static constexpr member.
// NOTE(review): listing line 45 (between the parameter list and the opening brace) is
// missing from this extract, so the struct's name is not visible here.
40 template <index_t kPadHeadDimQ_ /* padding for hdim_q */,
41  index_t kPadHeadDimV_ /* padding for hdim_v */,
42  BlockAttentionBiasEnum BiasEnum_,
43  bool kHasBiasGrad_,
44  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
46 {
47  static constexpr index_t kPadHeadDimQ = kPadHeadDimQ_;
48  static constexpr index_t kPadHeadDimV = kPadHeadDimV_;
49  static constexpr auto BiasEnum = BiasEnum_;
50  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
51  static constexpr index_t kBlockPerCu = kBlockPerCu_;
52 
// Any instantiation with a head-dim padding value other than 0, 1 or 8 fails to compile.
53  static_assert(kPadHeadDimQ == 0 || kPadHeadDimQ == 8 || kPadHeadDimQ == 1);
54  static_assert(kPadHeadDimV == 0 || kPadHeadDimV == 8 || kPadHeadDimV == 1);
55 };
56 
// Compile-time traits bundle: every template argument below is re-exported as a
// static constexpr member of the same name (without the trailing underscore).
// NOTE(review): listing line 69 (between the parameter list and the opening brace) is
// missing from this extract, so the struct's name is not visible here.
// NOTE(review): kIsPagedKV_ precedes kDoFp8StaticQuant_ in this parameter list, while the
// next template in this file (the one taking kHasUnevenSplits_) lists them in the opposite
// order; both are bool, so swapped positional arguments would compile silently.
57 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
58  bool kPadSeqLenK_ /* padding for seqlen_k */,
59  bool kPadHeadDimQ_ /* padding for hdim_q */,
60  bool kPadHeadDimV_ /* padding for hdim_v */,
61  bool kHasLogitsSoftCap_,
62  BlockAttentionBiasEnum BiasEnum_,
63  bool kHasBiasGrad_,
64  bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
65  bool kIsPagedKV_,
66  bool kDoFp8StaticQuant_,
67  index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
68  bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
70 {
71  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
72  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
73  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
74  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
75  static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
76  static constexpr auto BiasEnum = BiasEnum_;
77  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
78  static constexpr bool kStoreLSE = kStoreLSE_;
79  static constexpr bool kIsPagedKV = kIsPagedKV_;
80  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
81  static constexpr index_t kBlockPerCu = kBlockPerCu_;
82  static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
83 };
84 
// Compile-time traits bundle: every template argument below is re-exported as a
// static constexpr member of the same name (without the trailing underscore).
// NOTE(review): listing line 98 (between the parameter list and the opening brace) is
// missing from this extract, so the struct's name is not visible here.
// NOTE(review): kDoFp8StaticQuant_ precedes kIsPagedKV_ in this parameter list, while the
// preceding template in this file (the one taking kSkipMinSeqlenQ_ and kIsPagedKV_) lists
// them in the opposite order; both are bool, so swapped positional arguments would compile
// silently.
85 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
86  bool kPadSeqLenK_ /* padding for seqlen_k */,
87  bool kPadHeadDimQ_ /* padding for hdim_q */,
88  bool kPadHeadDimV_ /* padding for hdim_v */,
89  bool kHasLogitsSoftCap_,
90  BlockAttentionBiasEnum BiasEnum_,
91  bool kHasBiasGrad_,
92  bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
93  bool kDoFp8StaticQuant_,
94  bool kIsPagedKV_,
95  bool kHasUnevenSplits_,
96  bool kMergeNumHeadGroupsSeqLenQ_ = false,
97  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
99 {
100  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
101  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
102  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
103  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
104  static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
105  static constexpr auto BiasEnum = BiasEnum_;
106  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
107  static constexpr bool kStoreLSE = kStoreLSE_;
108  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
109  static constexpr bool kIsPagedKV = kIsPagedKV_;
110  // determine if some split (length) is not divisible by tile size
111  static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
112  static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
113  static constexpr index_t kBlockPerCu = kBlockPerCu_;
114 };
115 
// Compile-time traits bundle. All template arguments except kLogMaxSplits_ are re-exported
// as same-named static constexpr members; kLogMaxSplits_ is only exposed indirectly through
// the derived constant kMaxSplits.
// NOTE(review): listing line 122 (between the parameter list and the opening brace) is
// missing from this extract, so the struct's name is not visible here.
116 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
117  bool kPadHeadDimV_ /* padding for hdim_v */,
118  bool kStoreLSE_,
119  bool kDoFp8StaticQuant_,
120  index_t kLogMaxSplits_,
121  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
123 {
124  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
125  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
126  static constexpr bool kStoreLSE = kStoreLSE_;
127  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
128 
// kMaxSplits is 2 raised to the power kLogMaxSplits_.
129  static constexpr index_t kMaxSplits = (1 << kLogMaxSplits_);
// kMaxSplits must be at most get_warp_size() or an exact multiple of it.
130  static_assert(kMaxSplits <= get_warp_size() || kMaxSplits % get_warp_size() == 0);
131  static constexpr index_t kBlockPerCu = kBlockPerCu_;
132 };
133 
// Compile-time traits bundle carrying only the four padding flags and the occupancy
// override; each template argument is re-exported as a same-named static constexpr member.
// NOTE(review): listing line 139 (between the parameter list and the opening brace) is
// missing from this extract, so the struct's name is not visible here.
134 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
135  bool kPadSeqLenK_ /* padding for seqlen_k */,
136  bool kPadHeadDimQ_ /* padding for hdim_q */,
137  bool kPadHeadDimV_ /* padding for hdim_v */,
138  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
140 {
141  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
142  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
143  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
144  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
145  static constexpr index_t kBlockPerCu = kBlockPerCu_;
146 };
147 
// Compile-time traits bundle with seqlen_q / hdim_v padding flags; each template argument
// is re-exported as a same-named static constexpr member. kBlockPerCu_ defaults to 2 here.
// NOTE(review): listing line 151 (between the parameter list and the opening brace) is
// missing from this extract, so the struct's name is not visible here.
148 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
149  bool kPadHeadDimV_ /* padding for hdim_v */,
150  index_t kBlockPerCu_ = 2 /* hint to occupancy */>
152 {
153  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
154  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
155  static constexpr index_t kBlockPerCu = kBlockPerCu_;
156 };
157 
// Compile-time traits bundle with seqlen_q / hdim_q padding flags; each template argument
// is re-exported as a same-named static constexpr member. kBlockPerCu_ defaults to 2 here.
// NOTE(review): listing line 161 (between the parameter list and the opening brace) is
// missing from this extract, so the struct's name is not visible here.
158 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
159  bool kPadHeadDimQ_ /* padding for hdim_q */,
160  index_t kBlockPerCu_ = 2 /* hint to occupancy */>
162 {
163  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
164  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
165  static constexpr index_t kBlockPerCu = kBlockPerCu_;
166 };
167 
// Compile-time traits bundle carrying the four padding flags, kStoreLSE and the occupancy
// override; each template argument is re-exported as a same-named static constexpr member.
// NOTE(review): listing line 174 (between the parameter list and the opening brace) is
// missing from this extract, so the struct's name is not visible here.
168 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
169  bool kPadSeqLenK_ /* padding for seqlen_k */,
170  bool kPadHeadDimQ_ /* padding for hdim_q */,
171  bool kPadHeadDimV_ /* padding for hdim_v */,
172  bool kStoreLSE_,
173  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
175 {
176  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
177  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
178  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
179  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
180  static constexpr bool kStoreLSE = kStoreLSE_;
181  static constexpr index_t kBlockPerCu = kBlockPerCu_;
182 };
183 
184 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
BlockAttentionBiasEnum
Definition: block_attention_bias_enum.hpp:12
int32_t index_t
Definition: integer.hpp:9
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: tile_fmha_traits.hpp:162
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:165
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:164
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:163
Definition: tile_fmha_traits.hpp:152
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:155
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:153
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:154
Definition: tile_fmha_traits.hpp:46
static constexpr index_t kPadHeadDimQ
Definition: tile_fmha_traits.hpp:47
static constexpr index_t kPadHeadDimV
Definition: tile_fmha_traits.hpp:48
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:50
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:49
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:51
Definition: tile_fmha_traits.hpp:140
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:143
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:142
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:145
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:141
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:144
Definition: tile_fmha_traits.hpp:70
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:76
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:80
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:73
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:81
static constexpr bool kHasLogitsSoftCap
Definition: tile_fmha_traits.hpp:75
static constexpr bool kSkipMinSeqlenQ
Definition: tile_fmha_traits.hpp:82
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:78
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:71
static constexpr bool kIsPagedKV
Definition: tile_fmha_traits.hpp:79
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:74
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:72
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:77
Definition: tile_fmha_traits.hpp:123
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:124
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:125
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:127
static constexpr index_t kMaxSplits
Definition: tile_fmha_traits.hpp:129
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:126
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:131
Definition: tile_fmha_traits.hpp:99
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:113
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:107
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: tile_fmha_traits.hpp:112
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:102
static constexpr bool kHasLogitsSoftCap
Definition: tile_fmha_traits.hpp:104
static constexpr bool kHasUnevenSplits
Definition: tile_fmha_traits.hpp:111
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:101
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:106
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:108
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:105
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:103
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:100
static constexpr bool kIsPagedKV
Definition: tile_fmha_traits.hpp:109
Definition: tile_fmha_traits.hpp:175
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:180
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:178
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:177
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:179
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:181
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:176
Definition: tile_fmha_traits.hpp:25
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:35
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:32
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:28
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:26
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:29
static constexpr bool kHasDropout
Definition: tile_fmha_traits.hpp:34
static constexpr bool kHasLogitsSoftCap
Definition: tile_fmha_traits.hpp:30
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:36
static constexpr bool kSkipMinSeqlenQ
Definition: tile_fmha_traits.hpp:37
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:27
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:33
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:31