/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Source File
tile_fmha_traits.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
// Compile-time trait bundle: each template argument below is re-exported, under the same
// name without the trailing underscore, as a static constexpr member.
12 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
13  bool kPadSeqLenK_ /* padding for seqlen_k */,
14  bool kPadHeadDimQ_ /* padding for hdim_q */,
15  bool kPadHeadDimV_ /* padding for hdim_v */,
16  bool kHasLogitsSoftCap_,
17  BlockAttentionBiasEnum BiasEnum_,
18  bool kHasBiasGrad_,
19  bool kStoreLSE_,
20  bool kHasDropout_,
21  bool kDoFp8StaticQuant_,
22  index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
23  bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
// NOTE(review): listing line 24 is absent from this extract -- presumably the struct
// declaration that names this type; confirm against the original header.
25 {
    // Padding switches for the sequence-length and head-dimension axes.
26  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
27  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
28  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
29  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
    // Feature switches, forwarded unchanged from the template arguments.
30  static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
31  static constexpr auto BiasEnum = BiasEnum_;
32  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
33  static constexpr bool kStoreLSE = kStoreLSE_;
34  static constexpr bool kHasDropout = kHasDropout_;
35  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
    // Defaults to -1; per the parameter comment, any other value overwrites the occupancy.
36  static constexpr index_t kBlockPerCu = kBlockPerCu_;
    // Defaults to false; per the parameter comment, skips min seqlen q while chunked prefill.
37  static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
38 };
39 
// Compile-time trait bundle with a paged-KV switch: each template argument below is
// re-exported, under the same name without the trailing underscore, as a static constexpr
// member.
40 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
41  bool kPadSeqLenK_ /* padding for seqlen_k */,
42  bool kPadHeadDimQ_ /* padding for hdim_q */,
43  bool kPadHeadDimV_ /* padding for hdim_v */,
44  bool kHasLogitsSoftCap_,
45  BlockAttentionBiasEnum BiasEnum_,
46  bool kHasBiasGrad_,
47  bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
48  bool kIsPagedKV_,
49  bool kDoFp8StaticQuant_,
50  index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
51  bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
// NOTE(review): listing line 52 is absent from this extract -- presumably the struct
// declaration that names this type; confirm against the original header.
53 {
    // Padding switches for the sequence-length and head-dimension axes.
54  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
55  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
56  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
57  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
    // Feature switches, forwarded unchanged from the template arguments.
58  static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
59  static constexpr auto BiasEnum = BiasEnum_;
60  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
61  static constexpr bool kStoreLSE = kStoreLSE_;
62  static constexpr bool kIsPagedKV = kIsPagedKV_;
63  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
    // Defaults to -1; per the parameter comment, any other value overwrites the occupancy.
64  static constexpr index_t kBlockPerCu = kBlockPerCu_;
    // Defaults to false; per the parameter comment, skips min seqlen q while chunked prefill.
65  static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
66 };
67 
// Compile-time trait bundle with split-related switches: each template argument below is
// re-exported, under the same name without the trailing underscore, as a static constexpr
// member.
68 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
69  bool kPadSeqLenK_ /* padding for seqlen_k */,
70  bool kPadHeadDimQ_ /* padding for hdim_q */,
71  bool kPadHeadDimV_ /* padding for hdim_v */,
72  bool kHasLogitsSoftCap_,
73  BlockAttentionBiasEnum BiasEnum_,
74  bool kHasBiasGrad_,
75  bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
76  bool kDoFp8StaticQuant_,
77  bool kIsPagedKV_,
78  bool kHasUnevenSplits_,
79  bool kMergeNumHeadGroupsSeqLenQ_ = false,
80  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
// NOTE(review): listing line 81 is absent from this extract -- presumably the struct
// declaration that names this type; confirm against the original header.
82 {
    // Padding switches for the sequence-length and head-dimension axes.
83  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
84  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
85  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
86  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
    // Feature switches, forwarded unchanged from the template arguments.
87  static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
88  static constexpr auto BiasEnum = BiasEnum_;
89  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
90  static constexpr bool kStoreLSE = kStoreLSE_;
91  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
92  static constexpr bool kIsPagedKV = kIsPagedKV_;
93  // determine if some split (length) is not divisible by tile size
94  static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
    // Defaults to false when the argument is omitted.
95  static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
    // Defaults to -1; per the parameter comment, any other value overwrites the occupancy.
96  static constexpr index_t kBlockPerCu = kBlockPerCu_;
97 };
98 
// Compile-time trait bundle that also derives a maximum split count: the boolean arguments
// and kBlockPerCu_ are re-exported as static constexpr members, and kLogMaxSplits_ is
// turned into kMaxSplits.
99 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
100  bool kPadHeadDimV_ /* padding for hdim_v */,
101  bool kStoreLSE_,
102  bool kDoFp8StaticQuant_,
103  index_t kLogMaxSplits_,
104  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
// NOTE(review): listing line 105 is absent from this extract -- presumably the struct
// declaration that names this type; confirm against the original header.
106 {
107  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
108  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
109  static constexpr bool kStoreLSE = kStoreLSE_;
110  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
111 
    // kLogMaxSplits_ is the base-2 logarithm of the split limit: kMaxSplits = 2^kLogMaxSplits_.
112  static constexpr index_t kMaxSplits = (1 << kLogMaxSplits_);
    // Reject instantiations where kMaxSplits is larger than the warp size but not an exact
    // multiple of it.
113  static_assert(kMaxSplits <= get_warp_size() || kMaxSplits % get_warp_size() == 0);
    // Defaults to -1; per the parameter comment, any other value overwrites the occupancy.
114  static constexpr index_t kBlockPerCu = kBlockPerCu_;
115 };
116 
// Compile-time trait bundle holding four padding switches and an occupancy value: each
// template argument is re-exported, under the same name without the trailing underscore,
// as a static constexpr member.
117 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
118  bool kPadSeqLenK_ /* padding for seqlen_k */,
119  bool kPadHeadDimQ_ /* padding for hdim_q */,
120  bool kPadHeadDimV_ /* padding for hdim_v */,
121  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
// NOTE(review): listing line 122 is absent from this extract -- presumably the struct
// declaration that names this type; confirm against the original header.
123 {
124  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
125  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
126  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
127  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
    // Defaults to -1; per the parameter comment, any other value overwrites the occupancy.
128  static constexpr index_t kBlockPerCu = kBlockPerCu_;
129 };
130 
// Compile-time trait bundle holding the seqlen_q and hdim_v padding switches plus an
// occupancy hint: each template argument is re-exported, under the same name without the
// trailing underscore, as a static constexpr member.
131 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
132  bool kPadHeadDimV_ /* padding for hdim_v */,
133  index_t kBlockPerCu_ = 2 /* hint to occupancy */>
// NOTE(review): listing line 134 is absent from this extract -- presumably the struct
// declaration that names this type; confirm against the original header.
135 {
136  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
137  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
    // Occupancy hint; 2 when the argument is omitted.
138  static constexpr index_t kBlockPerCu = kBlockPerCu_;
139 };
140 
// Compile-time trait bundle holding the seqlen_q and hdim_q padding switches plus an
// occupancy hint: each template argument is re-exported, under the same name without the
// trailing underscore, as a static constexpr member.
141 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
142  bool kPadHeadDimQ_ /* padding for hdim_q */,
143  index_t kBlockPerCu_ = 2 /* hint to occupancy */>
// NOTE(review): listing line 144 is absent from this extract -- presumably the struct
// declaration that names this type; confirm against the original header.
145 {
146  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
147  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
    // Occupancy hint; 2 when the argument is omitted.
148  static constexpr index_t kBlockPerCu = kBlockPerCu_;
149 };
150 
// Compile-time trait bundle holding four padding switches, an LSE-store switch and an
// occupancy value: each template argument is re-exported, under the same name without the
// trailing underscore, as a static constexpr member.
151 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
152  bool kPadSeqLenK_ /* padding for seqlen_k */,
153  bool kPadHeadDimQ_ /* padding for hdim_q */,
154  bool kPadHeadDimV_ /* padding for hdim_v */,
155  bool kStoreLSE_,
156  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
// NOTE(review): listing line 157 is absent from this extract -- presumably the struct
// declaration that names this type; confirm against the original header.
158 {
159  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
160  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
161  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
162  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
163  static constexpr bool kStoreLSE = kStoreLSE_;
    // Defaults to -1; per the parameter comment, any other value overwrites the occupancy.
164  static constexpr index_t kBlockPerCu = kBlockPerCu_;
165 };
166 
167 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
BlockAttentionBiasEnum
Definition: block_attention_bias_enum.hpp:12
int32_t index_t
Definition: integer.hpp:9
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
Definition: tile_fmha_traits.hpp:145
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:148
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:147
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:146
Definition: tile_fmha_traits.hpp:135
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:138
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:136
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:137
Definition: tile_fmha_traits.hpp:123
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:126
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:125
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:128
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:124
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:127
Definition: tile_fmha_traits.hpp:53
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:59
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:63
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:56
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:64
static constexpr bool kHasLogitsSoftCap
Definition: tile_fmha_traits.hpp:58
static constexpr bool kSkipMinSeqlenQ
Definition: tile_fmha_traits.hpp:65
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:61
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:54
static constexpr bool kIsPagedKV
Definition: tile_fmha_traits.hpp:62
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:57
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:55
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:60
Definition: tile_fmha_traits.hpp:106
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:107
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:108
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:110
static constexpr index_t kMaxSplits
Definition: tile_fmha_traits.hpp:112
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:109
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:114
Definition: tile_fmha_traits.hpp:82
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:96
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:90
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: tile_fmha_traits.hpp:95
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:85
static constexpr bool kHasLogitsSoftCap
Definition: tile_fmha_traits.hpp:87
static constexpr bool kHasUnevenSplits
Definition: tile_fmha_traits.hpp:94
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:84
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:89
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:91
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:88
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:86
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:83
static constexpr bool kIsPagedKV
Definition: tile_fmha_traits.hpp:92
Definition: tile_fmha_traits.hpp:158
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:163
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:161
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:160
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:162
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:164
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:159
Definition: tile_fmha_traits.hpp:25
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:35
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:32
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:28
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:26
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:29
static constexpr bool kHasDropout
Definition: tile_fmha_traits.hpp:34
static constexpr bool kHasLogitsSoftCap
Definition: tile_fmha_traits.hpp:30
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:36
static constexpr bool kSkipMinSeqlenQ
Definition: tile_fmha_traits.hpp:37
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:27
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:33
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:31