include/ck_tile/ops/fmha/block/page_block_navigator.hpp Source File

include/ck_tile/ops/fmha/block/page_block_navigator.hpp Source File

Composable Kernel: include/ck_tile/ops/fmha/block/page_block_navigator.hpp Source File
page_block_navigator.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 // Trivial navigator: assumes the whole tensor is covered by a single page-block / tensor view,
// so every operation reports block index 0.
12 template <typename TensorView>
14 {
15  using DataType = typename TensorView::DataType;
17 
// Constructs the navigator around one tensor view. The view is copied into the
// `tensor_view` member, so the navigator does not refer back to the argument.
18  CK_TILE_HOST_DEVICE constexpr TrivialPageBlockNavigator(const TensorView& tensor_view_)
19  : tensor_view(tensor_view_)
20  {
21  }
22 
// Creates a tile window of size `window_lengths` positioned at `window_origin` over the
// stored tensor view.
// Returns a tuple (block_index, tile_window); block_index is always 0 for this navigator.
23  template <typename WindowLengths>
24  CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths,
25  const WindowOrigin& window_origin) const
26  {
27  return make_tuple(/*block_index=*/0,
28  ck_tile::make_tile_window(tensor_view, window_lengths, window_origin));
29  }
30 
// Overload of make_tile_window that additionally forwards a tile distribution.
// Returns a tuple whose first element (the block index) is always 0.
// NOTE(review): the listing line naming the callee (between 38 and 40) is not shown here;
// presumably it is ck_tile::make_tile_window, as in the overload without a distribution —
// confirm against the original header.
31  template <typename WindowLengths, typename TileDistribution>
32  CK_TILE_HOST_DEVICE constexpr auto
33  make_tile_window(const WindowLengths& window_lengths,
34  const WindowOrigin& window_origin,
35  const TileDistribution& tile_distribution) const
36  {
37  return make_tuple(
38  /*block_index=*/0,
40  tensor_view, window_lengths, window_origin, tile_distribution));
41  }
42 
// Shifts `tile_window` by `step` using ck_tile::move_tile_window.
// The incoming block index is ignored (the parameter is unnamed) and the returned block
// index is always 0, since this navigator has only one block.
// NOTE(review): the declaration line (return type) and the `step` parameter line are
// missing from this listing (lines 44 and 47).
43  template <typename TileWindow>
45  move_tile_window(index_t /*block_index*/,
46  TileWindow& tile_window,
48  {
49  ck_tile::move_tile_window(tile_window, step);
50 
51  return /*block_index=*/0;
52  }
53 
// Identity mapping: with a single block, the block-local origin equals the global origin.
// Returns `global_window_origin` unchanged.
54  CK_TILE_HOST_DEVICE static constexpr WindowOrigin
55  to_local_window_origin(const WindowOrigin& global_window_origin)
56  {
57  return global_window_origin;
58  }
59 
// Identity mapping in the other direction: the block index is ignored (unnamed parameter)
// and `local_window_origin` is returned unchanged.
60  CK_TILE_HOST_DEVICE static constexpr WindowOrigin
61  to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin)
62  {
63  return local_window_origin;
64  }
65 
66  private:
// The single tensor view every tile window is created over (held by value).
67  TensorView tensor_view;
68 };
69 
70 // Default page-block navigator. Assumes each tensor view has the same size as a page-block,
71 // except for tile windows on the last page-block, whose tensor view may be smaller.
72 template <typename DataType_, index_t VirtualDim, typename TensorView>
74 {
// Element type of the paged storage; it must be exactly the tensor view's element type.
75  using DataType = DataType_;
76  static_assert(std::is_same_v<DataType, typename TensorView::DataType>);
// VirtualDim names the window dimension (0 or 1) whose coordinate is split into
// (block index, offset inside block) by the methods below.
77  static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window");
79 
// Constructor: stores every argument in the member of the same name (minus the trailing
// underscore). `physical_blocks_` is reinterpret_cast to DataType* before being stored;
// both tensor views are copied.
// NOTE(review): the first line of the signature (constructor name and the
// `physical_blocks_` parameter, listing line 80) is not shown in this listing.
81  long_index_t block_stride_,
82  long_index_t fixed_offset_,
83  const int32_t* physical_block_indices_,
84  index_t num_blocks_,
85  index_t page_block_size_,
86  const TensorView& complete_view_,
87  const TensorView& last_view_)
88  : physical_blocks(reinterpret_cast<DataType*>(physical_blocks_)),
89  block_stride(block_stride_),
90  fixed_offset(fixed_offset_),
91  physical_block_indices(physical_block_indices_),
92  num_blocks(num_blocks_),
93  page_block_size(page_block_size_),
94  complete_view(complete_view_),
95  last_view(last_view_)
96  {
97  }
98 
// Creates a tile window for a global `window_origin`.
// Returns a tuple (block_index, tile_window) where block_index is the page-block that
// contains the origin along VirtualDim.
99  template <typename WindowLengths>
100  CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
101  const WindowOrigin& window_origin) const
102  {
// Split the global origin into (which block, where inside that block).
103  const index_t block_index = get_block_index(window_origin);
104  const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
105 
// The last block uses `last_view`; every other block uses `complete_view`.
106  auto new_tile_window =
107  ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
108  window_lengths,
109  local_window_origin);
// Re-point the window at the memory of the selected block.
110  new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
111 
112  return make_tuple(block_index, new_tile_window);
113  }
114 
// Overload of make_tile_window that also takes a tile distribution.
// Returns a tuple (block_index, tile_window), with the window's data pointer set to the
// block that contains `window_origin` along VirtualDim.
// NOTE(review): the final argument line of the ck_tile::make_tile_window call (listing
// line 127) is missing here; presumably it passes `tile_distribution` — confirm.
115  template <typename WindowLengths, typename TileDistribution>
116  CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
117  const WindowOrigin& window_origin,
118  const TileDistribution& tile_distribution) const
119  {
// Split the global origin into (which block, where inside that block).
120  const index_t block_index = get_block_index(window_origin);
121  const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
122 
// The last block uses `last_view`; every other block uses `complete_view`.
123  auto new_tile_window =
124  ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
125  window_lengths,
126  local_window_origin,
128  new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
129 
130  return make_tuple(block_index, new_tile_window);
131  }
132 
// Moves `tile_window` by `step` and re-targets it at whichever page-block its new origin
// falls in. `block_index` is the block the window currently refers to.
// Returns the block index the window refers to after the move.
// NOTE(review): the declaration lines carrying the return type, function name and the
// `block_index` parameter (listing lines 134-135) are missing from this listing.
133  template <typename TileWindow>
136  TileWindow& tile_window,
137  const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
138  {
139 
// First apply the step in the window's current (block-local) coordinates.
140  ck_tile::move_tile_window(tile_window, step);
141 
// Convert the moved origin to global coordinates using the OLD block index, then split it
// again into a block index and a block-local origin.
142  const WindowOrigin global_window_origin =
143  to_global_window_origin(block_index, tile_window.get_window_origin());
144  const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
145 
146  const index_t new_block_index = get_block_index(global_window_origin);
// Overwrite descriptor, origin and data pointer so the window describes the new block.
// The descriptor is written through the window's `bottom_tensor_view_` member directly.
148  tile_window.bottom_tensor_view_.desc_ =
149  (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
150  tile_window.set_window_origin(local_window_origin);
151  tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
152 
153  return new_block_index;
154  }
155 
// Returns true when `block_index` is the final logical block, i.e. equals num_blocks - 1.
156  CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const
157  {
158  return block_index == num_blocks - 1;
159  }
160 
// Returns true when the window extends past the end of its current page-block along
// VirtualDim (origin + length > page_block_size) and `block_index` is not the last block.
// NOTE(review): the declaration line with the return type, name and `block_index`
// parameter (listing line 162) is missing from this listing.
161  template <typename TileWindow>
163  const TileWindow& tile_window) const
164  {
165  const index_t origin = tile_window.get_window_origin().at(number<VirtualDim>{});
166  const index_t length = tile_window.get_window_lengths().at(number<VirtualDim>{});
167  return (block_index < num_blocks - 1) && (page_block_size < origin + length);
168  }
169 
// Re-targets `tile_window` from block `block_index` to block `new_block_index`.
// The window origin is shifted along VirtualDim by
// (block_index - new_block_index) * page_block_size, which is the change in block-local
// coordinate needed to express the same global position relative to the new block.
// NOTE(review): the declaration line with the return type (listing line 171) is missing
// from this listing; no value is returned by the visible body.
170  template <typename TileWindow>
172  move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const
173  {
// Build a 2-d step that is non-zero only in the VirtualDim coordinate.
174  const multi_index<2> step = [&]() {
175  const index_t origin_diff = (block_index - new_block_index) * page_block_size;
176  if constexpr(VirtualDim == 0)
177  {
178  return make_multi_index(origin_diff, 0);
179  }
180  else
181  {
182  return make_multi_index(0, origin_diff);
183  }
184  }();
185 
// Overwrite descriptor, origin and data pointer so the window describes the new block.
187  tile_window.bottom_tensor_view_.desc_ =
188  (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
189  tile_window.set_window_origin(tile_window.get_window_origin() + step);
190  tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
191  }
192 
// Converts a global window origin to a block-local one.
// The VirtualDim coordinate is reduced by page_block_size times the number of whole
// page-blocks that precede it (floor division); the other coordinate is passed through.
// NOTE(review): the declaration line with the return type (listing line 193) is missing
// from this listing.
194  to_local_window_origin(const WindowOrigin& global_window_origin) const
195  {
196  if constexpr(VirtualDim == 0)
197  {
198  const index_t length = global_window_origin.at(number<0>{});
199  const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
200  return make_multi_index(length - page_block_size * num_complete_blocks,
201  global_window_origin.at(number<1>{}));
202  }
203  else
204  {
205  const index_t length = global_window_origin.at(number<1>{});
206  const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
207  return make_multi_index(global_window_origin.at(number<0>{}),
208  length - page_block_size * num_complete_blocks);
209  }
210  }
211 
// Converts a block-local window origin back to a global one by adding
// block_index * page_block_size to the VirtualDim coordinate; the other coordinate is
// passed through unchanged.
// NOTE(review): the declaration line with the return type (listing line 212) is missing
// from this listing.
213  to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const
214  {
215  if constexpr(VirtualDim == 0)
216  {
217  return make_multi_index(block_index * page_block_size +
218  local_window_origin.at(number<0>{}),
219  local_window_origin.at(number<1>{}));
220  }
221  else
222  {
223  return make_multi_index(local_window_origin.at(number<0>{}),
224  block_index * page_block_size +
225  local_window_origin.at(number<1>{}));
226  }
227  }
228 
229  private:
// Returns the data pointer for logical block `block_index`, or nullptr when
// block_index >= num_blocks.
// The logical index is translated through `physical_block_indices`; the result is
// physical_blocks + physical_index * block_stride + fixed_offset, with the offsets
// counted in DataType elements (pointer arithmetic on DataType*).
// NOTE(review): only the upper bound is checked; a negative block_index would still index
// `physical_block_indices` — confirm callers never pass one.
231  DataType* get_block_ptr(index_t block_index) const
232  {
233  if(block_index < num_blocks)
234  {
235  return physical_blocks + physical_block_indices[block_index] * block_stride +
236  fixed_offset;
237  }
238  else
239  {
240  return nullptr;
241  }
242  }
243 
// Returns the index of the page-block containing `global_window_origin`: the VirtualDim
// coordinate floor-divided by page_block_size.
244  CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const
245  {
246  return integer_divide_floor(global_window_origin.at(number<VirtualDim>{}), page_block_size);
247  }
248 
// Base pointer to which per-block offsets are added (see get_block_ptr).
249  DataType* physical_blocks;
// Multiplier applied to a physical block index, in DataType elements.
250  long_index_t block_stride;
// Constant offset added to every block pointer, in DataType elements.
251  long_index_t fixed_offset;
252 
// Table translating a logical block index into a physical block index.
253  const int32_t* physical_block_indices;
// Number of logical blocks; get_block_ptr returns nullptr for indices >= num_blocks.
254  index_t num_blocks;
// Extent of one page-block along VirtualDim.
255  index_t page_block_size;
256 
// Tensor view used for every block except the last.
257  TensorView complete_view;
// Tensor view used for the last block (index num_blocks - 1).
258  TensorView last_view;
259 };
260 
// Factory overload taking a single tensor view.
// NOTE(review): the signature line and the return statement (listing lines 262 and 264)
// are missing from this listing; presumably this returns a TrivialPageBlockNavigator
// built from the view — confirm against the original header.
261 template <typename TensorView>
263 {
265 }
266 
// Factory for the paged case. The visible arguments are forwarded, in order, to the
// expression being returned.
// NOTE(review): the line with the function name and first parameter (listing line 268)
// and the first line of the return expression (277) are missing from this listing;
// presumably this is an overload of make_page_block_navigator that constructs a
// PageBlockNavigator<DataType, VirtualDim, TensorView> — confirm against the original
// header.
267 template <typename DataType, index_t VirtualDim, typename TensorView>
269  long_index_t block_stride,
270  long_index_t fixed_offset,
271  const int32_t* physical_block_indices,
272  index_t num_blocks,
273  index_t page_block_size,
274  const TensorView& complete_view,
275  const TensorView& last_view)
276 {
278  block_stride,
279  fixed_offset,
280  physical_block_indices,
281  num_blocks,
282  page_block_size,
283  complete_view,
284  last_view);
285 }
286 
287 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition: page_block_navigator.hpp:262
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
typename copy_const< From, To >::type copy_const_t
Definition: type_traits.hpp:40
constexpr CK_TILE_HOST_DEVICE auto integer_divide_floor(X x, Y y)
Definition: math.hpp:143
constexpr CK_TILE_HOST_DEVICE auto make_multi_index(Xs &&... xs)
Definition: multi_index.hpp:20
int64_t long_index_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:92
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
Definition: page_block_navigator.hpp:74
CK_TILE_HOST_DEVICE void move_to_block(index_t block_index, TileWindow &tile_window, index_t new_block_index) const
Definition: page_block_navigator.hpp:172
CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const
Definition: page_block_navigator.hpp:156
multi_index< 2 > WindowOrigin
Definition: page_block_navigator.hpp:78
constexpr CK_TILE_HOST_DEVICE PageBlockNavigator(copy_const_t< DataType, void > *physical_blocks_, long_index_t block_stride_, long_index_t fixed_offset_, const int32_t *physical_block_indices_, index_t num_blocks_, index_t page_block_size_, const TensorView &complete_view_, const TensorView &last_view_)
Definition: page_block_navigator.hpp:80
CK_TILE_HOST_DEVICE WindowOrigin to_local_window_origin(const WindowOrigin &global_window_origin) const
Definition: page_block_navigator.hpp:194
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths &window_lengths, const WindowOrigin &window_origin) const
Definition: page_block_navigator.hpp:100
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths &window_lengths, const WindowOrigin &window_origin, const TileDistribution &tile_distribution) const
Definition: page_block_navigator.hpp:116
DataType_ DataType
Definition: page_block_navigator.hpp:75
CK_TILE_HOST_DEVICE index_t move_tile_window(index_t block_index, TileWindow &tile_window, const typename remove_cvref_t< TileWindow >::BottomTensorIndex &step) const
Definition: page_block_navigator.hpp:135
CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index, const TileWindow &tile_window) const
Definition: page_block_navigator.hpp:162
CK_TILE_HOST_DEVICE WindowOrigin to_global_window_origin(index_t block_index, const WindowOrigin &local_window_origin) const
Definition: page_block_navigator.hpp:213
Definition: page_block_navigator.hpp:14
typename TensorView::DataType DataType
Definition: page_block_navigator.hpp:15
constexpr CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths &window_lengths, const WindowOrigin &window_origin, const TileDistribution &tile_distribution) const
Definition: page_block_navigator.hpp:33
constexpr CK_TILE_HOST_DEVICE TrivialPageBlockNavigator(const TensorView &tensor_view_)
Definition: page_block_navigator.hpp:18
multi_index< 2 > WindowOrigin
Definition: page_block_navigator.hpp:16
constexpr CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths &window_lengths, const WindowOrigin &window_origin) const
Definition: page_block_navigator.hpp:24
static constexpr CK_TILE_HOST_DEVICE WindowOrigin to_local_window_origin(const WindowOrigin &global_window_origin)
Definition: page_block_navigator.hpp:55
static constexpr CK_TILE_HOST_DEVICE WindowOrigin to_global_window_origin(index_t, const WindowOrigin &local_window_origin)
Definition: page_block_navigator.hpp:61
static CK_TILE_HOST_DEVICE index_t move_tile_window(index_t, TileWindow &tile_window, const typename remove_cvref_t< TileWindow >::BottomTensorIndex &step)
Definition: page_block_navigator.hpp:45
Definition: array.hpp:24
constexpr CK_TILE_HOST_DEVICE auto & at(index_t i)
Definition: array.hpp:91
Definition: integral_constant.hpp:13
Definition: tensor_view.hpp:41
Definition: tile_distribution.hpp:72