/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/page_block_navigator.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/page_block_navigator.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/page_block_navigator.hpp Source File
page_block_navigator.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 // assume that we have only 1 page-block/tensor view
12 template <typename TensorView>
14 {
15  using DataType = typename TensorView::DataType;
17 
18  CK_TILE_HOST_DEVICE constexpr TrivialPageBlockNavigator(const TensorView& tensor_view_) // stores a copy of the single tensor view served by this navigator
19  : tensor_view(tensor_view_)
20  {
21  }
22 
23  template <typename WindowLengths> // returns (block_index = 0, tile window over the wrapped tensor view at window_origin)
24  CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths,
25  const WindowOrigin& window_origin) const
26  {
27  return make_tuple(/*block_index=*/0,
28  ck_tile::make_tile_window(tensor_view, window_lengths, window_origin));
29  }
30 
31  template <typename WindowLengths, typename TileDistribution>
32  CK_TILE_HOST_DEVICE constexpr auto
33  make_tile_window(const WindowLengths& window_lengths,
34  const WindowOrigin& window_origin,
35  const TileDistribution& tile_distribution) const
36  {
37  return make_tuple(
38  /*block_index=*/0,
40  tensor_view, window_lengths, window_origin, tile_distribution));
41  }
42 
43  template <typename TileWindow>
45  move_tile_window(index_t /*block_index*/,
46  TileWindow& tile_window,
48  {
49  ck_tile::move_tile_window(tile_window, step);
50 
51  return /*block_index=*/0;
52  }
53 
54  template <typename TileWindow>
56  move_tile_window(index_t /*block_index*/,
57  TileWindow& tile_window,
59  index_t /*id*/) const
60  {
61 
62  ck_tile::move_tile_window(tile_window, step);
63  return 0;
64  }
65 
66  template <typename TileWindow>
68  prefetch_table_id(index_t /*block_index*/,
69  TileWindow /*tile_window*/,
70  const typename remove_cvref_t<TileWindow>::BottomTensorIndex& /*step*/) const
71  {
72  return -1;
73  }
74 
75  CK_TILE_HOST_DEVICE static constexpr WindowOrigin // identity mapping: with a single page-block, local and global origins are the same
76  to_local_window_origin(const WindowOrigin& global_window_origin)
77  {
78  return global_window_origin;
79  }
80 
81  CK_TILE_HOST_DEVICE static constexpr WindowOrigin // identity mapping; block_index is ignored because there is only one page-block
82  to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin)
83  {
84  return local_window_origin;
85  }
86 
87  private:
88  TensorView tensor_view;
89 };
90 
91 // default page-block navigator, assume that tensor view size is same as page-block size or smaller
92 // if tile window on last page-block
93 template <typename DataType_, index_t VirtualDim, typename TensorView>
95 {
96  using DataType = DataType_;
97  static_assert(std::is_same_v<DataType, typename TensorView::DataType>);
98  static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window");
100 
102  long_index_t block_stride_,
103  long_index_t fixed_offset_,
104  const int32_t* physical_block_indices_,
105  index_t num_blocks_,
106  index_t page_block_size_,
107  const TensorView& complete_view_,
108  const TensorView& last_view_)
109  : physical_blocks(reinterpret_cast<DataType*>(physical_blocks_)),
110  block_stride(block_stride_),
111  fixed_offset(fixed_offset_),
112  physical_block_indices(physical_block_indices_),
113  num_blocks(num_blocks_),
114  page_block_size(page_block_size_),
115  complete_view(complete_view_),
116  last_view(last_view_)
117  {
118  }
119 
120  template <typename WindowLengths> // returns (logical block_index, tile window positioned inside that page-block)
121  CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
122  const WindowOrigin& window_origin) const
123  {
124  const index_t block_index = get_block_index(window_origin); // page-block containing the origin along VirtualDim
125  const WindowOrigin local_window_origin = to_local_window_origin(window_origin); // origin relative to the start of that page-block
126 
127  auto new_tile_window =
128  ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view, // last block uses last_view (may be smaller); others use complete_view
129  window_lengths,
130  local_window_origin);
131  new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index)); // rebind to the block's physical memory (get_block_ptr yields nullptr past num_blocks)
132 
133  return make_tuple(block_index, new_tile_window);
134  }
135 
136  template <typename WindowLengths, typename TileDistribution>
137  CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
138  const WindowOrigin& window_origin,
139  const TileDistribution& tile_distribution) const
140  {
141  const index_t block_index = get_block_index(window_origin);
142  const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
143 
144  auto new_tile_window =
145  ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
146  window_lengths,
147  local_window_origin,
149  new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
150 
151  return make_tuple(block_index, new_tile_window);
152  }
153 
154  template <typename TileWindow>
157  TileWindow& tile_window,
158  const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
159  {
160 
161  ck_tile::move_tile_window(tile_window, step);
162 
163  const WindowOrigin global_window_origin =
164  to_global_window_origin(block_index, tile_window.get_window_origin());
165  const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
166 
167  const index_t new_block_index = get_block_index(global_window_origin);
169  tile_window.bottom_tensor_view_.desc_ =
170  (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
171  tile_window.set_window_origin(local_window_origin);
172  tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
173 
174  return new_block_index;
175  }
176 
177  template <typename TileWindow>
180  TileWindow& tile_window,
182  index_t id) const
183  {
184  ck_tile::move_tile_window(tile_window, step);
185 
186  const WindowOrigin global_window_origin =
187  to_global_window_origin(block_index, tile_window.get_window_origin());
188  const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
189 
190  const index_t new_block_index = get_block_index(global_window_origin);
192  tile_window.bottom_tensor_view_.desc_ =
193  (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
194  tile_window.set_window_origin(local_window_origin);
195  if(id >= 0)
196  tile_window.set_bottom_tensor_view_data_ptr(physical_blocks + id * block_stride +
197  fixed_offset);
198  else
199  tile_window.set_bottom_tensor_view_data_ptr(nullptr);
200 
201  return new_block_index;
202  }
203 
204  template <typename TileWindow>
207  TileWindow& tile_window,
208  const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
209  {
210  auto local_tile_window = tile_window; // not affect origin window
211  ck_tile::move_tile_window(local_tile_window, step);
212 
213  const WindowOrigin global_window_origin =
214  to_global_window_origin(block_index, local_tile_window.get_window_origin());
215  const index_t new_block_index = get_block_index(global_window_origin);
216 
217  if(new_block_index < num_blocks)
218  {
219  return physical_block_indices[new_block_index];
220  }
221  else
222  {
223  return -1;
224  }
225  }
226 
227  CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const // true only for the final logical page-block (index num_blocks - 1)
228  {
229  return block_index == num_blocks - 1;
230  }
231 
232  template <typename TileWindow>
234  const TileWindow& tile_window) const
235  {
236  const index_t origin = tile_window.get_window_origin().at(number<VirtualDim>{});
237  const index_t length = tile_window.get_window_lengths().at(number<VirtualDim>{});
238  return (block_index < num_blocks - 1) && (page_block_size < origin + length);
239  }
240 
241  template <typename TileWindow>
243  move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const
244  {
245  const multi_index<2> step = [&]() {
246  const index_t origin_diff = (block_index - new_block_index) * page_block_size;
247  if constexpr(VirtualDim == 0)
248  {
249  return make_multi_index(origin_diff, 0);
250  }
251  else
252  {
253  return make_multi_index(0, origin_diff);
254  }
255  }();
256 
258  tile_window.bottom_tensor_view_.desc_ =
259  (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
260  tile_window.set_window_origin(tile_window.get_window_origin() + step);
261  tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
262  }
263 
265  to_local_window_origin(const WindowOrigin& global_window_origin) const
266  {
267  if constexpr(VirtualDim == 0)
268  {
269  const index_t length = global_window_origin.at(number<0>{});
270  const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
271  return make_multi_index(length - page_block_size * num_complete_blocks,
272  global_window_origin.at(number<1>{}));
273  }
274  else
275  {
276  const index_t length = global_window_origin.at(number<1>{});
277  const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
278  return make_multi_index(global_window_origin.at(number<0>{}),
279  length - page_block_size * num_complete_blocks);
280  }
281  }
282 
284  to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const
285  {
286  if constexpr(VirtualDim == 0)
287  {
288  return make_multi_index(block_index * page_block_size +
289  local_window_origin.at(number<0>{}),
290  local_window_origin.at(number<1>{}));
291  }
292  else
293  {
294  return make_multi_index(local_window_origin.at(number<0>{}),
295  block_index * page_block_size +
296  local_window_origin.at(number<1>{}));
297  }
298  }
299 
300  private:
302  DataType* get_block_ptr(index_t block_index) const
303  {
304  if(block_index < num_blocks)
305  {
306  return physical_blocks + physical_block_indices[block_index] * block_stride +
307  fixed_offset;
308  }
309  else
310  {
311  return nullptr;
312  }
313  }
314 
315  CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const // logical page-block index: global origin along VirtualDim divided by page_block_size
316  {
317  return integer_divide_floor(global_window_origin.at(number<VirtualDim>{}), page_block_size);
318  }
319 
320  DataType* physical_blocks;
321  long_index_t block_stride;
322  long_index_t fixed_offset;
323 
324  const int32_t* physical_block_indices;
325  index_t num_blocks;
326  index_t page_block_size;
327 
328  TensorView complete_view;
329  TensorView last_view;
330 };
331 
332 template <typename TensorView>
334 {
336 }
337 
338 template <typename DataType, index_t VirtualDim, typename TensorView>
340  long_index_t block_stride,
341  long_index_t fixed_offset,
342  const int32_t* physical_block_indices,
343  index_t num_blocks,
344  index_t page_block_size,
345  const TensorView& complete_view,
346  const TensorView& last_view)
347 {
349  block_stride,
350  fixed_offset,
351  physical_block_indices,
352  num_blocks,
353  page_block_size,
354  complete_view,
355  last_view);
356 }
357 
358 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition: page_block_navigator.hpp:333
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename copy_const< From, To >::type copy_const_t
Definition: type_traits.hpp:41
constexpr CK_TILE_HOST_DEVICE auto integer_divide_floor(X x, Y y)
Definition: math.hpp:143
constexpr CK_TILE_HOST_DEVICE auto make_multi_index(Xs &&... xs)
Definition: multi_index.hpp:20
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
signed int int32_t
Definition: stdint.h:123
Definition: page_block_navigator.hpp:95
CK_TILE_HOST_DEVICE void move_to_block(index_t block_index, TileWindow &tile_window, index_t new_block_index) const
Definition: page_block_navigator.hpp:243
CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const
Definition: page_block_navigator.hpp:227
multi_index< 2 > WindowOrigin
Definition: page_block_navigator.hpp:99
CK_TILE_HOST_DEVICE index_t move_tile_window(index_t block_index, TileWindow &tile_window, const typename remove_cvref_t< TileWindow >::BottomTensorIndex &step, index_t id) const
Definition: page_block_navigator.hpp:179
constexpr CK_TILE_HOST_DEVICE PageBlockNavigator(copy_const_t< DataType, void > *physical_blocks_, long_index_t block_stride_, long_index_t fixed_offset_, const int32_t *physical_block_indices_, index_t num_blocks_, index_t page_block_size_, const TensorView &complete_view_, const TensorView &last_view_)
Definition: page_block_navigator.hpp:101
CK_TILE_HOST_DEVICE WindowOrigin to_local_window_origin(const WindowOrigin &global_window_origin) const
Definition: page_block_navigator.hpp:265
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths &window_lengths, const WindowOrigin &window_origin) const
Definition: page_block_navigator.hpp:121
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths &window_lengths, const WindowOrigin &window_origin, const TileDistribution &tile_distribution) const
Definition: page_block_navigator.hpp:137
DataType_ DataType
Definition: page_block_navigator.hpp:96
CK_TILE_HOST_DEVICE index_t move_tile_window(index_t block_index, TileWindow &tile_window, const typename remove_cvref_t< TileWindow >::BottomTensorIndex &step) const
Definition: page_block_navigator.hpp:156
CK_TILE_HOST_DEVICE index_t prefetch_table_id(index_t block_index, TileWindow &tile_window, const typename remove_cvref_t< TileWindow >::BottomTensorIndex &step) const
Definition: page_block_navigator.hpp:206
CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index, const TileWindow &tile_window) const
Definition: page_block_navigator.hpp:233
CK_TILE_HOST_DEVICE WindowOrigin to_global_window_origin(index_t block_index, const WindowOrigin &local_window_origin) const
Definition: page_block_navigator.hpp:284
Definition: page_block_navigator.hpp:14
typename TensorView::DataType DataType
Definition: page_block_navigator.hpp:15
CK_TILE_HOST_DEVICE index_t move_tile_window(index_t, TileWindow &tile_window, const typename remove_cvref_t< TileWindow >::BottomTensorIndex &step, index_t) const
Definition: page_block_navigator.hpp:56
CK_TILE_HOST_DEVICE index_t prefetch_table_id(index_t, TileWindow, const typename remove_cvref_t< TileWindow >::BottomTensorIndex &) const
Definition: page_block_navigator.hpp:68
constexpr CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths &window_lengths, const WindowOrigin &window_origin, const TileDistribution &tile_distribution) const
Definition: page_block_navigator.hpp:33
constexpr CK_TILE_HOST_DEVICE TrivialPageBlockNavigator(const TensorView &tensor_view_)
Definition: page_block_navigator.hpp:18
multi_index< 2 > WindowOrigin
Definition: page_block_navigator.hpp:16
constexpr CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths &window_lengths, const WindowOrigin &window_origin) const
Definition: page_block_navigator.hpp:24
static constexpr CK_TILE_HOST_DEVICE WindowOrigin to_local_window_origin(const WindowOrigin &global_window_origin)
Definition: page_block_navigator.hpp:76
static constexpr CK_TILE_HOST_DEVICE WindowOrigin to_global_window_origin(index_t, const WindowOrigin &local_window_origin)
Definition: page_block_navigator.hpp:82
static CK_TILE_HOST_DEVICE index_t move_tile_window(index_t, TileWindow &tile_window, const typename remove_cvref_t< TileWindow >::BottomTensorIndex &step)
Definition: page_block_navigator.hpp:45
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
constexpr CK_TILE_HOST_DEVICE auto & at(index_t i)
Definition: array.hpp:110
Definition: integral_constant.hpp:13
Definition: tensor_view.hpp:41
Definition: tile_distribution.hpp:72