include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Source File

include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Source File

Composable Kernel: include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Source File
tile_fmha_traits.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
// Compile-time traits bundle (Doxygen definition: tile_fmha_traits.hpp:23). Each template
// argument is re-exported as a same-named static constexpr member (trailing underscore dropped).
// NOTE(review): listing line 22 is absent from this extract (numbering jumps 21 -> 23);
// presumably it held the struct name -- TODO confirm against the real header.
12 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
13  bool kPadSeqLenK_ /* padding for seqlen_k */,
14  bool kPadHeadDimQ_ /* padding for hdim_q */,
15  bool kPadHeadDimV_ /* padding for hdim_v */,
16  BlockAttentionBiasEnum BiasEnum_,
17  bool kHasBiasGrad_,
18  bool kStoreLSE_,
19  bool kHasDropout_,
20  bool kDoFp8StaticQuant_,
21  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
23 {
24  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
25  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
26  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
27  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
28  static constexpr auto BiasEnum = BiasEnum_; // deduced as BlockAttentionBiasEnum
29  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
30  static constexpr bool kStoreLSE = kStoreLSE_;
31  static constexpr bool kHasDropout = kHasDropout_;
32  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
33  static constexpr index_t kBlockPerCu = kBlockPerCu_; // -1 (default) = no occupancy override
34 };
35 
// Compile-time traits bundle (Doxygen definition: tile_fmha_traits.hpp:49). Like the traits
// above, but without a dropout flag and with paged-KV / uneven-split / head-group-merge flags.
// Each template argument is re-exported as a same-named static constexpr member.
// NOTE(review): listing line 48 is absent from this extract (numbering jumps 47 -> 49);
// presumably it held the struct name -- TODO confirm against the real header.
36 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
37  bool kPadSeqLenK_ /* padding for seqlen_k */,
38  bool kPadHeadDimQ_ /* padding for hdim_q */,
39  bool kPadHeadDimV_ /* padding for hdim_v */,
40  BlockAttentionBiasEnum BiasEnum_,
41  bool kHasBiasGrad_,
42  bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
43  bool kDoFp8StaticQuant_,
44  bool kIsPagedKV_,
45  bool kHasUnevenSplits_,
46  bool kMergeNumHeadGroupsSeqLenQ_ = false,
47  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
49 {
50  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
51  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
52  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
53  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
54  static constexpr auto BiasEnum = BiasEnum_; // deduced as BlockAttentionBiasEnum
55  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
56  static constexpr bool kStoreLSE = kStoreLSE_;
57  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
58  static constexpr bool kIsPagedKV = kIsPagedKV_;
59  // determine if some split (length) is not divisible by tile size
60  static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
61  static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_; // default false
62  static constexpr index_t kBlockPerCu = kBlockPerCu_; // -1 (default) = no occupancy override
63 };
64 
// Compile-time traits bundle (Doxygen definition: tile_fmha_traits.hpp:72). The maximum split
// count is supplied as a base-2 logarithm and expanded to kMaxSplits = 2^kLogMaxSplits_.
// NOTE(review): listing line 71 is absent from this extract (numbering jumps 70 -> 72);
// presumably it held the struct name -- TODO confirm against the real header.
65 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
66  bool kPadHeadDimV_ /* padding for hdim_v */,
67  bool kStoreLSE_,
68  bool kDoFp8StaticQuant_,
69  index_t kLogMaxSplits_,
70  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
72 {
73  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
74  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
75  static constexpr bool kStoreLSE = kStoreLSE_;
76  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
77 
78  static constexpr index_t kMaxSplits = (1 << kLogMaxSplits_); // NOTE(review): kLogMaxSplits_ has no explicit range check; an invalid shift count should fail constant evaluation -- confirm
79  static_assert(kMaxSplits <= get_warp_size() || kMaxSplits % get_warp_size() == 0); // splits must fit in one warp or be a whole multiple of the warp size
80  static constexpr index_t kBlockPerCu = kBlockPerCu_; // -1 (default) = no occupancy override
81 };
82 
// Compile-time traits bundle (Doxygen definition: tile_fmha_traits.hpp:89) carrying only the
// four padding flags plus the occupancy override, each re-exported as a static constexpr member.
// NOTE(review): listing line 88 is absent from this extract (numbering jumps 87 -> 89);
// presumably it held the struct name -- TODO confirm against the real header.
83 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
84  bool kPadSeqLenK_ /* padding for seqlen_k */,
85  bool kPadHeadDimQ_ /* padding for hdim_q */,
86  bool kPadHeadDimV_ /* padding for hdim_v */,
87  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
89 {
90  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
91  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
92  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
93  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
94  static constexpr index_t kBlockPerCu = kBlockPerCu_; // -1 (default) = no occupancy override
95 };
96 
// Compile-time traits bundle (Doxygen definition: tile_fmha_traits.hpp:101) with seqlen_q and
// hdim_v padding flags. Unlike the traits above, kBlockPerCu defaults to 2 (an occupancy hint)
// rather than -1.
// NOTE(review): listing line 100 is absent from this extract (numbering jumps 99 -> 101);
// presumably it held the struct name -- TODO confirm against the real header.
97 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
98  bool kPadHeadDimV_ /* padding for hdim_v */,
99  index_t kBlockPerCu_ = 2 /* hint to occupancy */>
101 {
102  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
103  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
104  static constexpr index_t kBlockPerCu = kBlockPerCu_;
105 };
106 
// Compile-time traits bundle (Doxygen definition: tile_fmha_traits.hpp:111) with seqlen_q and
// hdim_q padding flags. kBlockPerCu defaults to 2 (an occupancy hint) rather than -1.
// NOTE(review): listing line 110 is absent from this extract (numbering jumps 109 -> 111);
// presumably it held the struct name -- TODO confirm against the real header.
107 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
108  bool kPadHeadDimQ_ /* padding for hdim_q */,
109  index_t kBlockPerCu_ = 2 /* hint to occupancy */>
111 {
112  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
113  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
114  static constexpr index_t kBlockPerCu = kBlockPerCu_;
115 };
116 
117 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE index_t get_warp_size()
Definition: arch.hpp:51
BlockAttentionBiasEnum
Definition: block_attention_bias_enum.hpp:12
int32_t index_t
Definition: integer.hpp:9
Definition: tile_fmha_traits.hpp:111
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:114
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:113
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:112
Definition: tile_fmha_traits.hpp:101
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:104
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:102
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:103
Definition: tile_fmha_traits.hpp:89
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:92
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:91
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:94
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:90
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:93
Definition: tile_fmha_traits.hpp:72
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:73
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:74
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:76
static constexpr index_t kMaxSplits
Definition: tile_fmha_traits.hpp:78
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:75
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:80
Definition: tile_fmha_traits.hpp:49
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:54
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:51
static constexpr bool kHasUnevenSplits
Definition: tile_fmha_traits.hpp:60
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:53
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:57
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: tile_fmha_traits.hpp:61
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:62
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:52
static constexpr bool kIsPagedKV
Definition: tile_fmha_traits.hpp:58
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:50
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:56
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:55
Definition: tile_fmha_traits.hpp:23
static constexpr bool kHasDropout
Definition: tile_fmha_traits.hpp:31
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:33
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:28
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:24
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:32
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:26
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:29
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:30
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:25
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:27