/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp Source File
block_fmha_fwd_appendkv_pipeline.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
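6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
8 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
9 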
10 namespace ck_tile {
11 
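12 template <typename Problem_, typename Policy_ = BlockFmhaFwdAppendKVPipelineDefaultPolicy>
13 struct BlockFmhaFwdAppendKVPipeline
14 {
15  using Problem = remove_cvref_t<Problem_>;
16  using Policy = remove_cvref_t<Policy_>;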
17  using QDataType = typename Problem::QDataType;
18  using KDataType = typename Problem::KDataType;
19  using VDataType = typename Problem::VDataType;
20 
21  using VLayout = typename Problem::VLayout;
22 
23  static constexpr index_t kBlockSize = Problem::kBlockSize;
24 
25  static constexpr index_t kM0 = Problem::kM0;
26  static constexpr index_t kN0 = Problem::kN0;
27  static constexpr index_t kK0 = Problem::kK0;
28  static constexpr index_t kN1 = Problem::kN1;
29 
30  static constexpr auto RotaryEnum = Problem::RotaryEnum;
31  static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
32 
33  static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
34  static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
35  static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
36  static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
37 
38  // last dimension vector length used to create tensor view(and decide buffer_load vector length)
39  // ... together with tensor distribution. tensor dist should be able to override this
40  static constexpr index_t kAlignmentQ =
41  kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
42  static constexpr index_t kAlignmentK =
43  kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
44  static constexpr index_t kAlignmentV = []() {
45  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
46  return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
47  else
48  return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
49  }();
50 
51  static constexpr index_t kBlockPerCu = []() {
52  if constexpr(Problem::kBlockPerCu != -1)
53  return Problem::kBlockPerCu;
54  else
55  {
56  if constexpr(kK0 <= 32)
57  {
58  return 2;
59  }
60  else if constexpr(kK0 <= 64)
61  {
62  return 3;
63  }
64  else if constexpr(kK0 <= 128)
65  {
66  return 2;
67  }
68  else if constexpr(kK0 <= 256)
69  {
70  return 1;
71  }
72  }
73  }();
74 
75  template <typename QDramBlockWindow,
76  typename KDramBlockWindow,
77  typename KPageBlockNavigator,
78  typename KnewDramBlockWindow,
79  typename VDramBlockWindow,
80  typename VPageBlockNavigator,
81  typename VnewDramBlockWindow,
82  typename QElementFunction,
83  typename KnewElementFunction,
84  typename VnewElementFunction,
85  typename QRotaryCosDramBlockWindow,
86  typename QRotarySinDramBlockWindow,
87  typename KnewRotaryCosDramBlockWindow,
88  typename KnewRotarySinDramBlockWindow>
89  CK_TILE_HOST_DEVICE auto
90  operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile
91  const QElementFunction& q_element_func,
92  KDramBlockWindow& k_dram_block_window, // N0*K0 tile
93  index_t i_page_block_k,
94  const KPageBlockNavigator& k_page_block_navigator,
95  const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile
96  const KnewElementFunction& knew_element_func,
97  VDramBlockWindow& v_dram_block_window, // N1*N0 tile
98  index_t i_page_block_v,
99  const VPageBlockNavigator& v_page_block_navigator,
100  const VnewDramBlockWindow& vnew_dram_block_window, // N1*N0 tile
101  const VnewElementFunction& vnew_element_func,
102  const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window,
103  const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window,
104  const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window,
105  const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window,
106  index_t rotary_dim,
107  bool skip_rotate_q,
108  bool skip_rotate_append_kv) const
109  {
110  if(!skip_rotate_append_kv)
111  {
112  // append Knew to K
113  auto knew_window = make_tile_window(
114  knew_dram_block_window, Policy::template MakeKnewDramTileDistribution<Problem>());
115 
116  auto knew_tile = [&]() {
117  auto knew = load_tile(knew_window);
118  return tile_elementwise_in(knew_element_func, knew);
119  }();
120 
121  // optionally apply rotary embedding to Knew
122  if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
123  {
124  auto rotary_cos_window =
125  make_tile_window(knew_rotary_cos_dram_block_window,
126  Policy::template MakeRotaryCosSinTileDistribution<
127  Problem,
128  /*IsRotaryCosSinForQ=*/false>());
129 
130  auto rotary_sin_window =
131  make_tile_window(knew_rotary_sin_dram_block_window,
132  Policy::template MakeRotaryCosSinTileDistribution<
133  Problem,
134  /*IsRotaryCosSinForQ=*/false>());
135 
136  // We assume that each thread owns contiguous elements on head dimention. And we
137  // will use the distribution to enable/disable threads in order to override partial
138  // knew_tile content
139  auto [thread_start, thread_end] =
140  Policy::template GetKnewThreadRangeAlongK<Problem>();
141  ignore = thread_start;
142 
143  BlockRotaryEmbedding<RotaryEnum>::apply(knew_tile,
144  knew_window,
145  rotary_cos_window,
146  rotary_sin_window,
147  rotary_dim,
148  thread_end);
149  }
150 
151  store_tile(k_dram_block_window, knew_tile);
152 
153  // write tile to another block if necessary
154  if constexpr(kIsPagedKV)
155  {
156  if(k_page_block_navigator.is_cross_block(i_page_block_k, k_dram_block_window))
157  {
158  k_page_block_navigator.move_to_block(
159  i_page_block_k, k_dram_block_window, i_page_block_k + 1);
160  store_tile(k_dram_block_window, knew_tile);
161  }
162  }
163 
164  // append Vnew to V
165  auto vnew_window = make_tile_window(
166  vnew_dram_block_window, Policy::template MakeVnewDramTileDistribution<Problem>());
167 
168  auto vnew_tile = [&]() {
169  auto vnew = load_tile(vnew_window);
170  return tile_elementwise_in(vnew_element_func, vnew);
171  }();
172 
173  store_tile(v_dram_block_window, vnew_tile);
174 
175  // write tile to another block if necessary
176  if constexpr(kIsPagedKV)
177  {
178  if(v_page_block_navigator.is_cross_block(i_page_block_v, v_dram_block_window))
179  {
180  v_page_block_navigator.move_to_block(
181  i_page_block_v, v_dram_block_window, i_page_block_v + 1);
182  store_tile(v_dram_block_window, vnew_tile);
183  }
184  }
185  }
186 
187  if(!skip_rotate_q)
188  {
189  // optionally apply rotary embedding to Q
190  if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
191  {
192  auto q_window = make_tile_window(
193  q_dram_block_window, Policy::template MakeQDramTileDistribution<Problem>());
194 
195  auto q_tile = [&]() {
196  auto q = load_tile(q_window);
197  return tile_elementwise_in(q_element_func, q);
198  }();
199 
200  auto rotary_cos_window =
201  make_tile_window(q_rotary_cos_dram_block_window,
202  Policy::template MakeRotaryCosSinTileDistribution<
203  Problem,
204  /*IsRotaryCosSinForQ=*/true>());
205 
206  auto rotary_sin_window =
207  make_tile_window(q_rotary_sin_dram_block_window,
208  Policy::template MakeRotaryCosSinTileDistribution<
209  Problem,
210  /*IsRotaryCosSinForQ=*/true>());
211 
212  // We assume that each thread owns contiguous elements on head dimension. And we
213  // will use the distribution to enable/disable threads in order to override partial
214  // q_tile content
215  auto [thread_start, thread_end] = Policy::template GetQThreadRangeAlongK<Problem>();
216  ignore = thread_start;
217 
218  BlockRotaryEmbedding<RotaryEnum>::apply(
219  q_tile, q_window, rotary_cos_window, rotary_sin_window, rotary_dim, thread_end);
220 
221  store_tile(q_dram_block_window, q_tile);
222  }
223  }
224  }
225 
226  template <typename QDramBlockWindow,
227  typename KDramBlockWindow,
228  typename KPageBlockNavigator,
229  typename KnewDramBlockWindow,
230  typename VDramBlockWindow,
231  typename VPageBlockNavigator,
232  typename VnewDramBlockWindow,
233  typename QRotaryCosDramBlockWindow,
234  typename QRotarySinDramBlockWindow,
235  typename KnewRotaryCosDramBlockWindow,
236  typename KnewRotarySinDramBlockWindow>
237  CK_TILE_HOST_DEVICE auto
238  operator()(QDramBlockWindow& q_dram_block_window,
239  KDramBlockWindow& k_dram_block_window,
240  index_t i_page_block_k,
241  const KPageBlockNavigator& k_page_block_navigator,
242  const KnewDramBlockWindow& knew_dram_block_window,
243  VDramBlockWindow& v_dram_block_window,
244  index_t i_page_block_v,
245  const VPageBlockNavigator& v_page_block_navigator,
246  const VnewDramBlockWindow& vnew_dram_block_window,
247  const QRotaryCosDramBlockWindow& q_rotary_cos_dram_block_window,
248  const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window,
249  const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window,
250  const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window,
251  index_t rotary_dim,
252  bool skip_rotate_q,
253  bool skip_rotate_append_kv) const
254  {
255  return operator()(q_dram_block_window,
256  identity{},
257  k_dram_block_window,
258  i_page_block_k,
259  k_page_block_navigator,
260  knew_dram_block_window,
261  identity{},
262  v_dram_block_window,
263  i_page_block_v,
264  v_page_block_navigator,
265  vnew_dram_block_window,
266  identity{},
267  q_rotary_cos_dram_block_window,
268  q_rotary_sin_dram_block_window,
269  knew_rotary_cos_dram_block_window,
270  knew_rotary_sin_dram_block_window,
271  rotary_dim,
272  skip_rotate_q,
273  skip_rotate_append_kv);
274  }
275 };
276 
277 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition: tile_elementwise.hpp:40
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
Definition: block_fmha_fwd_appendkv_pipeline.hpp:14
static constexpr bool kIsPagedKV
Definition: block_fmha_fwd_appendkv_pipeline.hpp:31
typename Problem::QDataType QDataType
Definition: block_fmha_fwd_appendkv_pipeline.hpp:17
typename Problem::KDataType KDataType
Definition: block_fmha_fwd_appendkv_pipeline.hpp:18
static constexpr index_t kAlignmentQ
Definition: block_fmha_fwd_appendkv_pipeline.hpp:40
static constexpr index_t kN1
Definition: block_fmha_fwd_appendkv_pipeline.hpp:28
static constexpr index_t kM0
Definition: block_fmha_fwd_appendkv_pipeline.hpp:25
typename Problem::VLayout VLayout
Definition: block_fmha_fwd_appendkv_pipeline.hpp:21
static constexpr bool kPadSeqLenK
Definition: block_fmha_fwd_appendkv_pipeline.hpp:34
static constexpr index_t kN0
Definition: block_fmha_fwd_appendkv_pipeline.hpp:26
static constexpr index_t kK0
Definition: block_fmha_fwd_appendkv_pipeline.hpp:27
static constexpr bool kPadHeadDimV
Definition: block_fmha_fwd_appendkv_pipeline.hpp:36
static constexpr index_t kBlockSize
Definition: block_fmha_fwd_appendkv_pipeline.hpp:23
remove_cvref_t< Policy_ > Policy
Definition: block_fmha_fwd_appendkv_pipeline.hpp:16
CK_TILE_HOST_DEVICE auto operator()(QDramBlockWindow &q_dram_block_window, const QElementFunction &q_element_func, KDramBlockWindow &k_dram_block_window, index_t i_page_block_k, const KPageBlockNavigator &k_page_block_navigator, const KnewDramBlockWindow &knew_dram_block_window, const KnewElementFunction &knew_element_func, VDramBlockWindow &v_dram_block_window, index_t i_page_block_v, const VPageBlockNavigator &v_page_block_navigator, const VnewDramBlockWindow &vnew_dram_block_window, const VnewElementFunction &vnew_element_func, const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window, const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window, const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window, const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window, index_t rotary_dim, bool skip_rotate_q, bool skip_rotate_append_kv) const
Definition: block_fmha_fwd_appendkv_pipeline.hpp:90
static constexpr auto RotaryEnum
Definition: block_fmha_fwd_appendkv_pipeline.hpp:30
static constexpr bool kPadHeadDimQ
Definition: block_fmha_fwd_appendkv_pipeline.hpp:35
CK_TILE_HOST_DEVICE auto operator()(QDramBlockWindow &q_dram_block_window, KDramBlockWindow &k_dram_block_window, index_t i_page_block_k, const KPageBlockNavigator &k_page_block_navigator, const KnewDramBlockWindow &knew_dram_block_window, VDramBlockWindow &v_dram_block_window, index_t i_page_block_v, const VPageBlockNavigator &v_page_block_navigator, const VnewDramBlockWindow &vnew_dram_block_window, const QRotaryCosDramBlockWindow &q_rotary_cos_dram_block_window, const QRotarySinDramBlockWindow &q_rotary_sin_dram_block_window, const KnewRotaryCosDramBlockWindow &knew_rotary_cos_dram_block_window, const KnewRotarySinDramBlockWindow &knew_rotary_sin_dram_block_window, index_t rotary_dim, bool skip_rotate_q, bool skip_rotate_append_kv) const
Definition: block_fmha_fwd_appendkv_pipeline.hpp:238
static constexpr bool kPadSeqLenQ
Definition: block_fmha_fwd_appendkv_pipeline.hpp:33
static constexpr index_t kAlignmentV
Definition: block_fmha_fwd_appendkv_pipeline.hpp:44
remove_cvref_t< Problem_ > Problem
Definition: block_fmha_fwd_appendkv_pipeline.hpp:15
typename Problem::VDataType VDataType
Definition: block_fmha_fwd_appendkv_pipeline.hpp:19
static constexpr index_t kBlockPerCu
Definition: block_fmha_fwd_appendkv_pipeline.hpp:51
static constexpr index_t kAlignmentK
Definition: block_fmha_fwd_appendkv_pipeline.hpp:42
static CK_TILE_HOST_DEVICE void apply(DistributedTensor &tile, OtherDramBlockWindow other_window, RotaryCosDramBlockWindow rotary_cos_window, RotarySinDramBlockWindow rotary_sin_window, index_t rotary_dim, index_t thread_end)
Definition: block_rotary_embedding.hpp:44
Definition: functional.hpp:86