/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_window_base.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_window_base.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_window_base.hpp Source File
tile_window_base.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
18 
19 namespace ck_tile {
20 
29 template <typename TileWindowType_, typename BottomTensorView_, typename WindowLengths_>
31 {
32 
35  using BottomTensorDesc = typename BottomTensorView::TensorDesc;
37 
38  static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
39 
41  "wrong! lengths should be static");
42 
44 
// Returns the stored window origin (window_origin_). The return type is deduced with plain
// `auto`, so the caller receives a copy rather than a reference to the member.
45  CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
// Returns the stored window lengths (window_lengths_) by value (plain `auto` return type).
46  CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
// Returns the stored bottom tensor view (bottom_tensor_view_) by value (plain `auto` return
// type), i.e. a copy of the view object; changes made to the copy do not affect this window.
47  CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
49 
// Replaces the window origin and then notifies the concrete window type.
//
// new_window_origin: the origin to store in window_origin_.
//
// After the assignment, `this` is downcast to TileWindowType_ and its
// set_window_origin_extended() is called with the same origin, so a concrete window type
// can refresh any state it derives from the origin.
// NOTE(review): the static_cast is only valid when *this really is a TileWindowType_
// (CRTP-style use of the first template parameter) — confirm at the instantiation sites.
50  CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
51  {
52  window_origin_ = new_window_origin;
53 
54  // Delegate to child if it implements extra logic
55  static_cast<TileWindowType_*>(this)->set_window_origin_extended(new_window_origin);
56  }
57  // Default no-op; can be overridden in child
59 
// Points the bottom tensor view at a different data buffer.
//
// data: pointer written directly into bottom_tensor_view_.buf_.p_data_.
//
// Only the data pointer is changed by this function; no other member of the view or of
// the window is touched here.
// NOTE(review): nothing here checks that `data` matches what the rest of the buffer/view
// expects — presumably that is the caller's responsibility; confirm.
60  CK_TILE_DEVICE constexpr void
61  set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
62  {
63  bottom_tensor_view_.buf_.p_data_ = data;
64  }
65 
66  // move window-origin
68  {
69  window_origin_ += step;
70 
71  // Delegate to child if it implements extra movement logic
72  static_cast<TileWindowType_*>(this)->move_extended(step);
73  }
74 
75  // Default no-op; can be overridden in child
77 
78  // origin ([x0', x1', ...]) of window on bottom tensor
80 
82 
83  // this is the bottom tensor view
84  // [x0', x1', ...] ==> [offset]
86 };
87 
88 template <typename TileWindowType_,
89  typename BottomTensorView_,
90  typename WindowLengths_,
91  typename StaticTileDistribution_>
93  : public tile_window_base<TileWindowType_, BottomTensorView_, WindowLengths_>
94 {
97 
98  using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
99 
100  static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
101 
102  static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
103  static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
104 
106  // using BottomTensorIndex = array<index_t, TileWindowBase::NDimBottomTensor>;
107 
110 
113 
114  static_assert(TileDstr::is_static(), "wrong!");
// The bottom tensor's dimension count must equal the number of bottom dimensions of the
// window adaptor; a mismatch is rejected at compile time.
115  static_assert(TileWindowBase::NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
116  "wrong! inconsistent # of dimensions");
117 
// Returns the stored tile distribution (tile_dstr_) by value (plain `auto` return type).
118  CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
// Forwards to init_raw() on the bottom tensor view held by the base window; this function
// does nothing else. Declared with CK_TILE_HOST_DEVICE, unlike the neighbouring accessors,
// which use CK_TILE_DEVICE.
119  CK_TILE_HOST_DEVICE void init_raw() { this->bottom_tensor_view_.init_raw(); }
120 
122  {
123  return TileDstr::is_static();
124  }
125 
126  // move thread's window adaptor coordinate and bottom tensor coordinate
127  // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
128  template <typename ATopIndex>
130  WindowAdaptorCoord& window_adaptor_thread_coord,
131  BottomTensorCoord& bottom_tensor_thread_coord,
132  const ATopIndex& idx_diff_adaptor_top) const
133  {
134  array<index_t, TileWindowBase::NDimBottomTensor> idx_diff_adaptor_bottom;
135 
136  move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
137  window_adaptor_thread_coord,
138  idx_diff_adaptor_top,
139  idx_diff_adaptor_bottom);
140 
141  move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
142  bottom_tensor_thread_coord,
143  idx_diff_adaptor_bottom);
144  }
145 
146  struct Traits
147  {
148  public:
149  static constexpr index_t PackedSize =
151 
152  static constexpr auto get_vector_dim_y_scalar_per_vector()
153  {
154  const auto [ys_vector_lengths, ys_vector_strides] =
156 
157  index_t VectorDimY_ = 0;
158  index_t ScalarPerVector_ = 1;
159 
160  for(index_t i = 0; i < NDimY; ++i)
161  {
162  if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
163  {
164  ScalarPerVector_ = ys_vector_lengths[i];
165  VectorDimY_ = i;
166  }
167  }
168 
169  return make_tuple(VectorDimY_, ScalarPerVector_);
170  }
171 
172  static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
173  static constexpr index_t ScalarPerVector =
174  get_vector_dim_y_scalar_per_vector().template at<1>();
175  using vector_t =
177 
178  static constexpr auto scalars_per_access_ = [] {
179  constexpr auto scalars_per_access_arr = generate_array(
180  [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
181 
183  constexpr auto NDimY_ = NDimY;
184 
185  return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
186  }();
187 
// Builds the space_filling_curve instantiation over the Y-dimension lengths of the tile
// distribution. Template arguments passed to space_filling_curve, in order:
//   1. the lengths reported by TileDstr's ys-to-d descriptor, converted to a sequence;
//   2. DimAccessOrder = arithmetic_sequence_gen<0, NDimY, 1>::type (presumably the identity
//      order 0..NDimY-1 — confirm against arithmetic_sequence_gen); per the FIXME below,
//      no logic selects a better order yet;
//   3. the type of scalars_per_access_;
//   4. `false` — NOTE(review): meaning not visible here; check the last template parameter
//      of space_filling_curve.
// Returns a default-constructed object of that type; in the visible code only its type is
// consumed (via decltype, as SFC_Ys).
188  static constexpr auto get_space_filling_curve()
189  {
190  constexpr auto thread_tensor_lengths_ys =
191  to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());
192 
193  // FIXME: need logic to judge dim access order
194  using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
195 
196  return space_filling_curve<decltype(thread_tensor_lengths_ys),
197  DimAccessOrder,
198  decltype(scalars_per_access_),
199  false >{};
200  }
201 
202  using SFC_Ys = decltype(get_space_filling_curve());
203 
204  static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
205 
206  static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
207  };
208 
209  // return vector dimension among [y0, y1, ...]
211  {
212  // bottom tensor top dimension vector lengths and strides
213  const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
214  TileWindowBase::BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
215 
216  // window vector lengths/strides
217  const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
218  const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
219 
220  // window adaptor [p0, p1, ..., y0, y1, ...]
221  array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
222  -1};
223  array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
224  -1};
225 
226  constexpr auto window_adaptor_bottom_dims =
227  WindowAdaptor::get_bottom_dimension_hidden_ids();
228 
229  set_container_subset(window_adaptor_vector_lengths,
230  window_adaptor_bottom_dims,
231  window_adaptor_bottom_dim_vector_lengths);
232  set_container_subset(window_adaptor_vector_strides,
233  window_adaptor_bottom_dims,
234  window_adaptor_bottom_dim_vector_strides);
235 
236  const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
237  WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
238  window_adaptor_vector_lengths, window_adaptor_vector_strides);
239 
240  // [y0, y1, ...]
241  constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
243  1>::type{};
244 
245  return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
246  get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
247  }
248 
// Returns Traits::NumAccess, the compile-time access count that Traits takes from
// SFC_Ys::get_num_of_access() and asserts to be greater than zero.
249  CK_TILE_DEVICE constexpr auto get_num_of_access() const { return Traits::NumAccess; }
250  // Tile tensor distribution, which contains:
251  // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
252  // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
254 };
255 
256 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1112
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1052
typename std::remove_reference< T >::type remove_reference_t
Definition: type_traits.hpp:15
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition: tensor_adaptor_coordinate.hpp:97
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:87
Definition: sequence.hpp:284
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:299
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Definition: numeric.hpp:81
Definition: space_filling_curve.hpp:20
Definition: debug.hpp:67
This class provides a description of a tile-window view onto device memory.
Definition: tile_window_base.hpp:31
static constexpr index_t NDimBottomTensor
Definition: tile_window_base.hpp:38
BottomTensorView bottom_tensor_view_
Definition: tile_window_base.hpp:85
constexpr CK_TILE_DEVICE void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition: tile_window_base.hpp:61
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_window_base.hpp:36
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition: tile_window_base.hpp:35
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
BottomTensorIndex window_origin_
Definition: tile_window_base.hpp:79
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition: tile_window_base.hpp:50
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
CK_TILE_DEVICE void move_extended(const BottomTensorIndex &)
Definition: tile_window_base.hpp:76
static constexpr CK_TILE_DEVICE index_t get_num_of_dimension()
Definition: tile_window_base.hpp:48
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_window_base.hpp:67
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_window_base.hpp:33
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_window_base.hpp:34
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition: tile_window_base.hpp:43
WindowLengths window_lengths_
Definition: tile_window_base.hpp:81
CK_TILE_DEVICE void set_window_origin_extended(const BottomTensorIndex &)
Definition: tile_window_base.hpp:58
Definition: tile_window_base.hpp:147
decltype(get_space_filling_curve()) SFC_Ys
Definition: tile_window_base.hpp:202
static constexpr auto get_space_filling_curve()
Definition: tile_window_base.hpp:188
static constexpr index_t ScalarPerVector
Definition: tile_window_base.hpp:173
static constexpr index_t PackedSize
Definition: tile_window_base.hpp:149
static constexpr auto scalars_per_access_
Definition: tile_window_base.hpp:178
static constexpr index_t VectorDimY
Definition: tile_window_base.hpp:172
static constexpr index_t NumAccess
Definition: tile_window_base.hpp:204
static constexpr auto get_vector_dim_y_scalar_per_vector()
Definition: tile_window_base.hpp:152
Definition: tile_window_base.hpp:94
static constexpr index_t NDimY
Definition: tile_window_base.hpp:103
static constexpr index_t NDimWindowAdaptorTop
Definition: tile_window_base.hpp:100
constexpr CK_TILE_DEVICE auto get_num_of_access() const
Definition: tile_window_base.hpp:249
decltype(make_tensor_coordinate(typename TileWindowBase::BottomTensorDesc{}, typename TileWindowBase::BottomTensorIndex{})) BottomTensorCoord
Definition: tile_window_base.hpp:112
static constexpr index_t NDimP
Definition: tile_window_base.hpp:102
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_window_base.hpp:95
static constexpr CK_TILE_DEVICE bool has_static_tile_distribution()
Definition: tile_window_base.hpp:121
CK_TILE_HOST_DEVICE void init_raw()
Definition: tile_window_base.hpp:119
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})) WindowAdaptorCoord
Definition: tile_window_base.hpp:109
typename TileDstr::PsYs2XsAdaptor WindowAdaptor
Definition: tile_window_base.hpp:98
array< index_t, NDimWindowAdaptorTop > AdaptorTopIndex
Definition: tile_window_base.hpp:105
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_window_base.hpp:129
constexpr CK_TILE_DEVICE auto get_tile_distribution() const
Definition: tile_window_base.hpp:118
static constexpr CK_TILE_DEVICE auto get_window_adaptor_ys_safe_vector_length_strides()
Definition: tile_window_base.hpp:210
TileDstr tile_dstr_
Definition: tile_window_base.hpp:253
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10