Composable Kernel: include/ck_tile/core/tensor/tile_window.hpp — source file listing
tile_window.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
18 
19 namespace ck_tile {
20 
// Tile window with a static (compile-time) tile distribution.
// NOTE(review): the class-name line (original line 36) was lost in this
// extraction; make_tile_window() at the bottom of the file constructs
// tile_window_with_static_distribution with exactly these template parameters.
32 template <typename BottomTensorView_,
33  typename WindowLengths_,
34  typename StaticTileDistribution_,
35  index_t NumCoord>
37 {
41 
// Adaptor mapping [p0, p1, ..., y0, y1, ...] (thread partition + per-thread
// tile indices) to the bottom tensor's [x0, x1, ...] dimensions.
42  using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
43  using BottomTensorDesc = typename BottomTensorView::TensorDesc;
44 
46 
47  static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
48  static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
49 
// NDimP: # of partition (thread-mapping) dims; NDimY: # of per-thread dims.
50  static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
51  static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
52 
53  static constexpr auto I0 = number<0>{};
54  static constexpr auto I1 = number<1>{};
// Only a single pre-computed coordinate bundle is currently supported.
55  static_assert(NumCoord == 1);
56 
57  // TODO: check WindowLengths and StaticTileDistribution are consistent
58 
60  "wrong! lengths should be static");
61  static_assert(TileDstr::is_static(), "wrong!");
62 
63  static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
// NOTE(review): typo in the message below — "diemsnions" should read
// "dimensions" (string literal; fix when next editing this assert).
64  "wrong! inconsistent # of diemsnions");
65 
68 
71 
74 
// Compile-time traits shared by the load/store/update paths: selects the Y
// dimension that is unit-stride in memory as the vector dimension, derives the
// vector width (scalars per access), and builds the space-filling curve that
// orders this thread's accesses over [y0, y1, ...].
// NOTE(review): some numbered lines (75, 81-82, 106, 113) are missing from this
// listing; the struct's name is load_store_traits, per `using Traits =
// load_store_traits` in the member functions below.
76  {
77  private:
// Scan the per-Y-dim safe vector lengths/strides and pick the dim with
// stride 1 and the largest length as (VectorDimY, ScalarPerVector).
78  static constexpr auto get_vector_dim_y_scalar_per_vector()
79  {
80  const auto [ys_vector_lengths, ys_vector_strides] =
83 
// Defaults: vectorize along y0 with scalar (width-1) accesses.
84  index_t VectorDimY_ = 0;
85  index_t ScalarPerVector_ = 1;
86 
87  for(index_t i = 0; i < NDimY; ++i)
88  {
89  if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
90  {
91  ScalarPerVector_ = ys_vector_lengths[i];
92  VectorDimY_ = i;
93  }
94  }
95 
96  return make_tuple(VectorDimY_, ScalarPerVector_);
97  }
98 
99  public:
100  static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
101  static constexpr index_t ScalarPerVector =
102  get_vector_dim_y_scalar_per_vector().template at<1>();
103 
104  // using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
105  // using vector_t = typename vector_type_t::type;
107 
108  private:
// Per-Y-dim access granularity: ScalarPerVector on the vector dim, 1 elsewhere.
109  static constexpr auto scalars_per_access_ = [] {
110  constexpr auto scalars_per_access_arr = generate_array(
111  [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
112 
114  constexpr auto NDimY_ = NDimY;
115 
116  return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
117  }();
118 
119  static constexpr auto get_space_filling_curve()
120  {
121  constexpr auto tile_dstr = TileDstr{};
122 
123  constexpr auto thread_tensor_lengths_ys =
124  to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
125 
126  // FIXME: need logic to judge dim access order
127  using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
128 
129  return space_filling_curve<decltype(thread_tensor_lengths_ys),
130  DimAccessOrder,
131  decltype(scalars_per_access_)>{};
132  }
133 
134  public:
135  using SFC_Ys = decltype(get_space_filling_curve());
136 
// Total number of vectorized accesses this thread performs per load/store.
137  static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
138 
139  static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
140  static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
141  };
142 
144 
146 
// Constructor: stores the bottom tensor view / window lengths / origin, then
// pre-computes per-thread (window-adaptor, bottom-tensor) coordinate bundles.
// NOTE(review): the constructor's name line (147) and parts of the initializer
// list (151, 155-156) are missing from this listing.
148  const BottomTensorView& bottom_tensor_view,
149  const WindowLengths& window_lengths,
150  const BottomTensorIndex& window_origin,
152  : bottom_tensor_view_{bottom_tensor_view},
153  window_lengths_{window_lengths},
154  window_origin_{window_origin},
157  {
158 #if 0 // debug
159  // TODO: this use more register for FA, but less register for GEMM
160  // need investigation
161  // only support warp-tile and block-tile
162  static_assert(NDimP == 1 or NDimP == 2, "wrong!");
163 
164  WindowAdaptorCoord window_adaptor_thread_coord_tmp;
165 
166  if constexpr(NDimP == 1)
167  {
168  window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
170  }
171  else if constexpr(NDimP == 2)
172  {
173  window_adaptor_thread_coord_tmp =
175  AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
176  }
177 #else
178  // TODO: this use less register for FA, but more register for GEMM
179  // need investigation
// NOTE(review): the argument lines (181-183) of this call were dropped by the
// extraction; compare with the intact #if 0 branch above and set_window_origin().
180  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
184 #endif
185 
// This thread's starting index in the bottom tensor = window origin plus the
// thread's offset produced by the adaptor.
186  BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
187  window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
188 
189  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
190  bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
191 
192  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
193  // future load/store() calls (might allocate more registers)
194  using Traits = load_store_traits;
195  using SFC_Ys = typename Traits::SFC_Ys;
196 
197  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
198  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
199  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
200 
// Advance each bundle to the start of its slice of the space-filling curve.
201  constexpr auto idx_diff_ys =
202  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
203 
// P-dim deltas are zero: only the Y (per-thread) coordinates move.
204  constexpr auto idx_diff_ps_ys = container_concat(
205  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
206 
208  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
209 
210  pre_computed_coords_(iCoord) =
211  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
212  });
213  }
214 
216 
// NOTE(review): the header of this function (line 217) is missing from the
// listing; its body simply forwards to TileDstr::is_static() (presumably a
// static "has static distribution" query — confirm against the repository).
218  {
219  return TileDstr::is_static();
220  }
221 
// Simple accessors; all return copies of the stored members.
222  CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
223 
224  CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
225 
226  CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
227 
228  CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
229 
// Re-points the bottom tensor view's buffer at new storage without touching
// the window geometry or the pre-computed coordinates.
230  CK_TILE_DEVICE constexpr void
231  set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
232  {
233  bottom_tensor_view_.buf_.p_data_ = data;
234  }
235 
236  // move thread's window adaptor coordinate and bottom tensor coordinate
237  // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
// Applies a top-index delta through the ps_ys->xs adaptor, then pushes the
// resulting bottom-index delta into the bottom tensor coordinate (offset).
// NOTE(review): the function's signature line (239) is missing from this
// listing; callers refer to it after idx_diff_ps_ys computations.
238  template <typename ATopIndex>
240  WindowAdaptorCoord& window_adaptor_thread_coord,
241  BottomTensorCoord& bottom_tensor_thread_coord,
242  const ATopIndex& idx_diff_adaptor_top) const
243  {
244  array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
245 
246  move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
247  window_adaptor_thread_coord,
248  idx_diff_adaptor_top,
249  idx_diff_adaptor_bottom);
250 
251  move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
252  bottom_tensor_thread_coord,
253  idx_diff_adaptor_bottom);
254  }
255 
256  // return vector dimension among [y0, y1, ...]
// Propagates the bottom tensor's safe vector length/stride info backwards
// through the window adaptor and extracts the per-Y-dimension result.
// NOTE(review): the function's signature line (257) is missing from this
// listing; load_store_traits' vector-dim selection consumes this result.
258  {
259  // bottom tensor top dimension vector lengths and strides
260  const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
261  BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
262 
263  // window vector lengths/strides
264  const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
265  const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
266 
267  // window adaptor [p0, p1, ..., y0, y1, ...]
// -1 marks "unknown" for hidden dims not covered by the bottom dims below.
268  array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
269  -1};
270  array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
271  -1};
272 
273  constexpr auto window_adaptor_bottom_dims =
274  WindowAdaptor::get_bottom_dimension_hidden_ids();
275 
276  set_container_subset(window_adaptor_vector_lengths,
277  window_adaptor_bottom_dims,
278  window_adaptor_bottom_dim_vector_lengths);
279  set_container_subset(window_adaptor_vector_strides,
280  window_adaptor_bottom_dims,
281  window_adaptor_bottom_dim_vector_strides);
282 
283  const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
284  WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
285  window_adaptor_vector_lengths, window_adaptor_vector_strides);
286 
287  // [y0, y1, ...]
// NOTE(review): line 289 (the sequence end bound) is missing from the listing.
288  constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
290  1>::type{};
291 
292  return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
293  get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
294  }
295 
297 
// Convenience overload: allocates a static distributed tensor, fills it via
// the in-place load() below, and returns it by value.
// NOTE(review): the return-type/name lines (299-300) are missing from this
// listing; the body shows it returns the freshly made dst_tensor.
298  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
301  {
302  constexpr auto tile_dstr = TileDstr{};
303  auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
304  load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
305  return dst_tensor;
306  }
307 
// Loads the window's data from the bottom tensor into dst_tensor's thread
// buffer, walking the space-filling curve in ScalarPerVector-wide accesses.
// NOTE(review): lines 312-313 (tail of the parameter list) and 370 (the
// coordinate-move call line) are missing from this listing.
308  template <typename DistributedTensor,
309  index_t i_access_unsupport_ = -1,
310  bool oob_conditional_check = true>
311  CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
314  {
315  using Traits = load_store_traits;
316  using vector_t = typename Traits::vector_t;
317  using SFC_Ys = typename Traits::SFC_Ys;
318 
319  constexpr auto tile_dstr = TileDstr{};
320 
321  // loop over thread tensor space [y0, y1, ...]
322  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// Start from the pre-computed coordinate bundle for this slice.
324  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
325  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
326 
327  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
328  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
329 
330  // data index [y0, y1, ...]
331  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
332 
333  // read from bottom tensor
334  const vector_t vec_value =
335  get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
336  bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
337 #if 1
338  // write into distributed tensor
// Scatter the loaded vector scalar-by-scalar into the thread buffer at
// the offsets given by the ys->d descriptor.
339  static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
340  constexpr auto idx_ys = generate_tuple(
341  [&](auto jj) {
342  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
343  : idx_ys_start[jj];
344  },
345  number<NDimY>{});
346 
347  constexpr index_t d =
348  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
349 
350  dst_tensor.get_thread_buffer().template at<d>() =
351  vec_value.template get_as<DataType>()[j];
352  });
353 #else
// Disabled alternative: whole-vector store into the thread buffer.
354  constexpr index_t d =
355  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
356  static_assert(d % Traits::ScalarPerVector == 0);
357 
358  dst_tensor.get_thread_buffer().template get_as<vector_t>()(
359  number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
360 #endif
361  // move thread coordinate
362  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
363  {
364  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
365 
366  constexpr auto idx_diff_ps_ys = container_concat(
367  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
368  idx_diff_ys);
369 
371  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
372  }
373  });
374  });
375  }
376 
// Raw variant of load(): writes whole vectors directly into the destination's
// thread buffer (reinterpreted as an array of vector_t), optionally issuing a
// pre-nop on the very first access.
// NOTE(review): lines 382-383 (tail of the parameter list) and 445 (the
// coordinate-move call line) are missing from this listing.
377  template <typename DstTile,
378  index_t i_access_unsupport_ = -1,
379  bool oob_conditional_check = true,
380  bool pre_nop = false>
381  CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
384  bool_constant<pre_nop> = {}) const
385  {
386  using Traits = load_store_traits;
387 
388  // using vector_type_t = typename Traits::vector_type_t;
389  using vector_t = typename Traits::vector_t;
390  using SFC_Ys = typename Traits::SFC_Ys;
391  static constexpr index_t YElementSize =
392  TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
393  static_assert(YElementSize % Traits::ScalarPerVector == 0);
394  using vectorized_tbuf = array<vector_t, YElementSize / Traits::ScalarPerVector>;
395  // StaticBuffer<address_space_enum::vgpr,
396  // vector_t,
397  // YElementSize / Traits::ScalarPerVector,
398  // true>;
399 
400  constexpr auto tile_dstr = TileDstr{};
401 
// View the scalar thread buffer as vectors so each access stores one vector_t.
402  auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
403 
404  // loop over thread tensor space [y0, y1, ...]
405  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
407  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
408  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
409 
410  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
411  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// pre-nop is only applied to the very first access of the whole load.
412  constexpr auto pre_nop_ = [&]() {
413  if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
414  return bool_constant<true>{};
415  else
416  return bool_constant<false>{};
417  }();
418 
419  // data index [y0, y1, ...]
420  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
421  constexpr index_t d =
422  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
423  static_assert(d % Traits::ScalarPerVector == 0);
424 
425  get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
426  dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
427  bottom_tensor_thread_coord,
428  0 ,
429  bool_constant<oob_conditional_check>{},
430  pre_nop_);
431 #if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
432  CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
// Empty asm acts as a compiler scheduling barrier for the scratch-memory
// workaround flagged above.
433  asm volatile(
434  ""); // this is starting from rocm-6.2, but same sympton, reuse this flag
435 #endif
436  // move thread coordinate
437  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
438  {
439  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
440 
441  constexpr auto idx_diff_ps_ys = container_concat(
442  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
443  idx_diff_ys);
444 
446  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
447  }
448  });
449  });
450  }
451 
452  // TODO: currently async load only implemented in inline asm
// Asynchronously loads the window into an LDS tile window (global -> LDS),
// programming the M0 register with per-wave/per-issue byte offsets.
// NOTE(review): lines 458-459 (tail of the parameter list) and 526 (the
// coordinate-move call line) are missing from this listing.
453  template <typename LdsTileWindow_,
454  index_t i_access_unsupport_ = -1,
455  bool oob_conditional_check = true,
456  bool pre_nop = false>
457  CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
460  bool_constant<pre_nop> = {}) const
461  {
462  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
463  // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
464  using LdsDataType = typename LdsTileWindow::DataType;
465  // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
466 
467  // issues * warps * lanes
468  static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
469 
// Byte offsets derived from the LDS descriptor: base of the buffer, stride
// between waves, and stride between successive issues.
470  const index_t size_per_buf =
471  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
472  make_tuple(number<0>{}, number<0>{}, number<0>{})) *
473  sizeof(LdsDataType);
474 
475  const index_t size_per_wave =
476  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
477  make_tuple(number<0>{}, number<1>{}, number<0>{})) *
478  sizeof(LdsDataType) -
479  size_per_buf;
480 
481  const index_t size_per_issue =
482  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
483  make_tuple(number<1>{}, number<0>{}, number<0>{})) *
484  sizeof(LdsDataType) -
485  size_per_buf;
486 
487  const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
488  m0_set_with_memory(m0_init_value); // This should be wave independent
489 
490  using Traits = load_store_traits;
491 
492  // using vector_type_t = typename Traits::vector_type_t;
493  using vector_t = typename Traits::vector_t;
494  using SFC_Ys = typename Traits::SFC_Ys;
495 
496  LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
497 
498  // loop over thread tensor space [y0, y1, ...]
499  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
501  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
502  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
503 
504  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
505  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// pre-nop only on the very first access, mirroring load_raw().
506  constexpr auto pre_nop_ = [&]() {
507  if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
508  return bool_constant<true>{};
509  else
510  return bool_constant<false>{};
511  }();
512 
513  // read from bottom tensor
514  get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
515  smem, bottom_tensor_thread_coord, 0, pre_nop_);
516 
517  // move thread coordinate
518  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
519  {
520  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
521 
522  constexpr auto idx_diff_ps_ys = container_concat(
523  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
524  idx_diff_ys);
525 
527  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
528 
// Advance the M0-based LDS destination for the next issue.
529  m0_inc_with_memory(size_per_issue);
530  }
531  });
532  });
533  }
534 
// Intrinsic-based async load into LDS: instead of M0 bumping, the LDS pointer
// itself is advanced by size_per_issue per access (offsets are in elements
// here, not bytes, unlike async_load_raw above).
// NOTE(review): lines 539-540 (tail of the parameter list) and 598 (the
// coordinate-move call line) are missing from this listing.
535  template <typename LdsTileWindow_,
536  index_t i_access_unsupport_ = -1,
537  bool oob_conditional_check = true>
538  CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
541  {
542  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
543  using LdsDataType = typename LdsTileWindow::DataType;
544 
545  // issues * warps * lanes
546  static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
547 
548  // TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
549  // dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
550  // check?)
551  constexpr index_t size_per_buf =
552  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
553  make_tuple(number<0>{}, number<0>{}, number<0>{}));
554 
555  constexpr index_t size_per_wave =
556  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
557  make_tuple(number<0>{}, number<1>{}, number<0>{})) -
558  size_per_buf;
559 
560  constexpr index_t size_per_issue =
561  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
562  make_tuple(number<1>{}, number<0>{}, number<0>{})) -
563  size_per_buf;
564 
565  const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
566 
567  using Traits = load_store_traits;
568 
569  using vector_t = typename Traits::vector_t;
570  using SFC_Ys = typename Traits::SFC_Ys;
571 
572  // TODO: we force CK_TILE_LDS_ADDR
573  CK_TILE_LDS_ADDR LdsDataType* smem =
574  lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
575 
576  // loop over thread tensor space [y0, y1, ...]
577  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
579  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
580  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
581 
582  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
583  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
584 
585  // read from bottom tensor
586  get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
587  smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
588 
589  // move thread coordinate
590  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
591  {
592  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
593 
594  constexpr auto idx_diff_ps_ys = container_concat(
595  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
596  idx_diff_ys);
597 
599  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
600 
601  smem += size_per_issue; // Note we manually increase the per-issue offset
602  }
603  });
604  });
605  }
606 
// Stores a distributed tensor's thread buffer into the bottom tensor,
// gathering ScalarPerVector scalars into a vector per access.
// NOTE(review): the signature lines (608-610, including the dstr_tensor
// parameter) and line 668 (the coordinate-move call) are missing from this
// listing; the body reads from `dstr_tensor`.
607  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
611  {
612  using Traits = load_store_traits;
613 
614  // using vector_type_t = typename Traits::vector_type_t;
615  using vector_t = typename Traits::vector_t;
616  using SFC_Ys = typename Traits::SFC_Ys;
617 
618  constexpr auto tile_dstr = TileDstr{};
619 
620  // loop over thread tensor space [y0, y1, ...]
621  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
622  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
623  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
624 
625  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
626  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
627 
628  // data index [y0, y1, ...]
629  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
630 
631  // read from distributed tensor
632  // vector_type_t vec;
633  vector_t vec_value;
634 
// Gather the vector from the thread buffer, scalar by scalar.
635  static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
636  constexpr auto idx_ys = generate_tuple(
637  [&](auto jj) {
638  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
639  : idx_ys_start[jj];
640  },
641  number<NDimY>{});
642 
643  constexpr index_t d =
644  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
645 
646  vec_value.template get_as<DataType>()(j) =
647  dstr_tensor.get_thread_buffer().template at<d>();
648  });
649 
650  // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
651 
652  // write into bottom tensor
653  get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
654  bottom_tensor_thread_coord,
655  0,
656  vec_value,
657  bool_constant<oob_conditional_check>{});
658 
659  // move thread coordinate
660  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
661  {
662  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
663 
664  constexpr auto idx_diff_ps_ys = container_concat(
665  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
666  idx_diff_ys);
667 
669  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
670  }
671  });
672  });
673  }
674 
// Raw store variant: same gather-then-write pattern as store(), but writes via
// set_vectorized_elements_raw with OOB checking forced on.
// NOTE(review): line 676 (the signature head), line 715 (the
// get_bottom_tensor_view() call head) and line 728 (the coordinate-move call)
// are missing from this listing.
675  template <index_t i_access_unsupport_ = -1>
677  number<i_access_unsupport_> = {}) const
678  {
679  using Traits = load_store_traits;
680 
681  using vector_t = typename Traits::vector_t;
682  using SFC_Ys = typename Traits::SFC_Ys;
683 
684  constexpr auto tile_dstr = TileDstr{};
685  static constexpr bool oob_conditional_check = true;
686 
687  // loop over thread tensor space [y0, y1, ...]
688  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
690  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
691  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
692 
693  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
694  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
695 
696  // data index [y0, y1, ...]
697  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
698 
699  // read from distributed tensor
700  vector_t vec_value;
701  static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
702  constexpr auto idx_ys = generate_tuple(
703  [&](auto jj) {
704  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
705  : idx_ys_start[jj];
706  },
707  number<NDimY>{});
708  constexpr index_t d =
709  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
710  vec_value.template get_as<DataType>()(j) =
711  dstr_tensor.get_thread_buffer().template at<d>();
712  });
713 
714  // write into bottom tensor
716  .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
717  bottom_tensor_thread_coord, 0, vec_value);
718 
719  // move thread coordinate
720  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
721  {
722  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
723 
724  constexpr auto idx_diff_ps_ys = container_concat(
725  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
726  idx_diff_ys);
727 
729  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
730  }
731  });
732  });
733  }
734 
// Read-modify-write: merges a distributed tensor's thread buffer into the
// bottom tensor via update_vectorized_elements (same traversal as store()).
// NOTE(review): the signature lines (736-738, including the dstr_tensor
// parameter) and line 793 (the coordinate-move call) are missing from this
// listing.
735  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
739  {
740  using Traits = load_store_traits;
741 
742  using vector_t = typename Traits::vector_t;
743  using SFC_Ys = typename Traits::SFC_Ys;
744 
745  constexpr auto tile_dstr = TileDstr{};
746 
747  // loop over thread tensor space [y0, y1, ...]
748  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
750  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
751  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
752 
753  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
754  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
755 
756  // data index [y0, y1, ...]
757  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
758 
759  // read from distributed tensor
760  vector_t vec_value;
761 
762  static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
763  constexpr auto idx_ys = generate_tuple(
764  [&](auto jj) {
765  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
766  : idx_ys_start[jj];
767  },
768  number<NDimY>{});
769 
770  constexpr index_t d =
771  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
772 
773  vec_value.template get_as<DataType>()(j) =
774  dstr_tensor.get_thread_buffer().template at<d>();
775  });
776 
777  // write into bottom tensor
778  get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
779  bottom_tensor_thread_coord,
780  0,
781  vec_value,
782  bool_constant<oob_conditional_check>{});
783 
784  // move thread coordinate
785  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
786  {
787  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
788 
789  constexpr auto idx_diff_ps_ys = container_concat(
790  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
791  idx_diff_ys);
792 
794  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
795  }
796  });
797  });
798  }
799 
// Raw read-modify-write: like update(), but via update_vectorized_elements_raw
// and with an explicit pre_nop flag forwarded to the buffer access.
// NOTE(review): the signature lines (801-803, including the dstr_tensor
// parameter) and line 860 (the coordinate-move call) are missing from this
// listing.
800  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
804  bool_constant<pre_nop> = {}) const
805  {
806  using Traits = load_store_traits;
807 
808  using vector_t = typename Traits::vector_t;
809  using SFC_Ys = typename Traits::SFC_Ys;
810 
811  constexpr auto tile_dstr = TileDstr{};
812 
813  // loop over thread tensor space [y0, y1, ...]
814  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
816  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
817  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
818 
819  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
820  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
821 
822  // data index [y0, y1, ...]
823  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
824 
825  // read from distributed tensor
826  vector_t vec_value;
827 
828  static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
829  constexpr auto idx_ys = generate_tuple(
830  [&](auto jj) {
831  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
832  : idx_ys_start[jj];
833  },
834  number<NDimY>{});
835 
836  constexpr index_t d =
837  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
838 
839  vec_value.template get_as<DataType>()(j) =
840  dstr_tensor.get_thread_buffer().template at<d>();
841  });
842 
843  // write into bottom tensor
844  get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
845  bottom_tensor_thread_coord,
846  0,
847  vec_value,
848  bool_constant<oob_conditional_check>{},
849  bool_constant<pre_nop>{});
850 
851  // move thread coordinate
852  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
853  {
854  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
855 
856  constexpr auto idx_diff_ps_ys = container_concat(
857  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
858  idx_diff_ys);
859 
861  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
862  }
863  });
864  });
865  }
866 
867  // move thread's bottom tensor coordinate
868  // [x0', x1', ... ] ==> [offset]
869  // also move window-origin
// Shifts the window by `step`: updates the stored origin and every
// pre-computed bottom-tensor coordinate (window-adaptor coords are unchanged).
// NOTE(review): the function's signature line (870) is missing from this
// listing; move_tile_window() at file scope forwards to it as window.move(step).
871  {
872  window_origin_ += step;
873 
874  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
875  move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
876  pre_computed_coords_(iCoord)(I1),
877  step);
878  });
879  }
880 
// Re-anchors the window at a new origin and rebuilds the pre-computed
// (window-adaptor, bottom-tensor) coordinate bundles from scratch —
// mirrors the tail of the constructor.
881  CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
882  {
883  window_origin_ = new_window_origin;
884 
885 #if 0 // debug
886  // TODO: this use more register for FA, but less register for GEMM
887  // need investigation
888  // only support warp-tile and block-tile
889  static_assert(NDimP == 1 or NDimP == 2, "wrong!");
890 
891  WindowAdaptorCoord window_adaptor_thread_coord_tmp;
892 
893  if constexpr(NDimP == 1)
894  {
895  window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
896  tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
897  }
898  else if constexpr(NDimP == 2)
899  {
900  window_adaptor_thread_coord_tmp =
901  make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
902  AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
903  }
904 #else
905  // TODO: this use less register for FA, but more register for GEMM
906  // need investigation
// NOTE(review): the second argument line (909) of this call is missing from
// the listing; the intact #if 0 branch above shows the intended top index.
907  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
908  tile_dstr_.get_ps_ys_to_xs_adaptor(),
910 #endif
911 
912  BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
913  window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
914 
915  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
916  bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
917 
918  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
919  // future load/store() calls (might allocate more registers)
920  using Traits = load_store_traits;
921  using SFC_Ys = typename Traits::SFC_Ys;
922 
923  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
924  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
925  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
926 
927  constexpr auto idx_diff_ys =
928  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
929 
930  constexpr auto idx_diff_ps_ys = container_concat(
931  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
932 
934  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
935 
936  pre_computed_coords_(iCoord) =
937  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
938  });
939  }
940 
// Data members. NOTE(review): the member declaration lines themselves
// (945, 948, 951, 956, 961) are missing from this listing — only their
// explanatory comments survived. The member functions above reference
// bottom_tensor_view_, window_lengths_, window_origin_, tile_dstr_ and
// pre_computed_coords_.
942 
943  // this is the bottom tensor view
944  // [x0', x1', ...] ==> [offset]
946 
947  //
949 
950  // origin ([x0', x1', ...]) of window on bottom tensor
952 
953  // Tile tensor distribution, which contains:
954  // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
955  // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
957 
958  // this contains:
959  // per-thread coordinate for window adaptor
960  // per-thread coordinate for bottom tensor
962 };
963 
964 // TODO: use strategy
965 template <typename TensorView_,
966  typename WindowLengths_,
967  typename StaticTileDistribution_,
968  index_t NumCoord = 1>
969 CK_TILE_DEVICE constexpr auto
970 make_tile_window(const TensorView_& tensor_view,
971  const WindowLengths_& window_lengths,
972  const multi_index<TensorView_::get_num_of_dimension()>& origin,
973  const StaticTileDistribution_& tile_distribution,
974  number<NumCoord> = {})
975 {
976  return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
977  remove_cvref_t<WindowLengths_>,
978  remove_cvref_t<StaticTileDistribution_>,
979  NumCoord>{
980  tensor_view, window_lengths, origin, tile_distribution};
981 }
982 
983 // this version can't be called in a constexpr context
984 template <typename TensorView_,
985  typename WindowLengths_,
986  typename StaticTileDistribution_,
987  index_t NumCoord = 1>
988 CK_TILE_DEVICE auto
990  const WindowLengths_& window_lengths,
991  const multi_index<TensorView_::get_num_of_dimension()>& origin,
992  const StaticTileDistribution_& tile_distribution,
993  number<NumCoord> = {})
994 {
995  auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
996  remove_cvref_t<WindowLengths_>,
997  remove_cvref_t<StaticTileDistribution_>,
998  NumCoord>{
999  tensor_view, window_lengths, origin, tile_distribution};
1000  w.init_raw();
1001  return w;
1002 }
1003 
1004 template <typename TensorView_,
1005  typename WindowLengths_,
1006  typename StaticTileDistribution_,
1007  index_t NumCoord>
1010  WindowLengths_,
1011  StaticTileDistribution_,
1012  NumCoord>& window,
1013  const typename tile_window_with_static_distribution<TensorView_,
1014  WindowLengths_,
1015  StaticTileDistribution_,
1016  NumCoord>::BottomTensorIndex& step)
1017 {
1018  window.move(step);
1019 }
1020 
1029 template <typename BottomTensorView_, typename WindowLengths_>
1031 {
1034  using BottomTensorDesc = typename BottomTensorView::TensorDesc;
1035  using DataType = typename BottomTensorView::DataType;
1036 
1037  static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
1038 
1040  "wrong! lengths should be static");
1041 
1043 
1045 
1047  const BottomTensorView& bottom_tensor_view,
1048  const WindowLengths& window_lengths,
1049  const BottomTensorIndex& window_origin)
1050  : bottom_tensor_view_{bottom_tensor_view},
1051  window_lengths_{window_lengths},
1052  window_origin_{window_origin}
1053  {
1054  }
1055 
1057 
1058  CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
1059 
1060  CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
1061 
1062  CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
1063 
1064  CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
1065  {
1066  window_origin_ = new_window_origin;
1067  }
1068 
1069  CK_TILE_DEVICE constexpr void
1070  set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
1071  {
1072  bottom_tensor_view_.buf_.p_data_ = data;
1073  }
1074 
1075  // move window-origin
1076  CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }
1077 
1078  // this is the bottom tensor view
1079  // [x0', x1', ...] ==> [offset]
1081 
1082  //
1084 
1085  // origin ([x0', x1', ...]) of window on bottom tensor
1087 };
1088 
1089 template <typename TensorView_, typename WindowLengths_>
1090 CK_TILE_DEVICE constexpr auto
1091 make_tile_window(const TensorView_& tensor_view,
1092  const WindowLengths_& window_lengths,
1093  const multi_index<TensorView_::get_num_of_dimension()>& origin)
1094 {
1096  "wrong! lengths should be static");
1097 
1100  tensor_view, window_lengths, origin};
1101 }
1102 
1103 // duplicate tile window and replace its origin
1104 template <typename TensorView, typename WindowLengths>
1105 CK_TILE_DEVICE constexpr auto
1107  const multi_index<TensorView::get_num_of_dimension()>& origin)
1108 {
1110  tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), origin};
1111 }
1112 
1113 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
1114 CK_TILE_DEVICE constexpr auto
1116  const multi_index<TensorView::get_num_of_dimension()>& origin,
1117  const StaticTileDistribution& tile_distribution)
1118 {
1119  return make_tile_window(tile_window.get_bottom_tensor_view(),
1120  tile_window.get_window_lengths(),
1121  origin,
1123 }
1124 
1125 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
1126 CK_TILE_DEVICE constexpr auto
1128  const StaticTileDistribution& tile_distribution)
1129 {
1130  return make_tile_window(tile_window.get_bottom_tensor_view(),
1131  tile_window.get_window_lengths(),
1132  tile_window.get_window_origin(),
1134 }
1135 
1136 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
1137 CK_TILE_DEVICE constexpr auto
1139  const StaticTileDistribution& tile_distribution)
1140 {
1141  auto w = make_tile_window(tile_window.get_bottom_tensor_view(),
1142  tile_window.get_window_lengths(),
1143  tile_window.get_window_origin(),
1145  w.init_raw();
1146  return w;
1147 }
1148 
1149 template <typename TensorView_, typename WindowLengths_>
1153  step)
1154 {
1155  window.move(step);
1156 }
1157 
1158 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_LDS_ADDR
Definition: config.hpp:56
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
tuple_array< T, N > thread_buffer
Definition: thread_buffer.hpp:14
CK_TILE_DEVICE auto make_tile_window_raw(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, number< NumCoord >={})
Definition: tile_window.hpp:989
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constant< b > bool_constant
Definition: integral_constant.hpp:39
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1106
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constant< v > number
Definition: integral_constant.hpp:33
CK_TILE_DEVICE index_t get_warp_id()
Definition: arch.hpp:71
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1046
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:92
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:400
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
typename std::remove_reference< T >::type remove_reference_t
Definition: type_traits.hpp:14
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_HOST_DEVICE void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition: tensor_adaptor_coordinate.hpp:97
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:86
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
Definition: sequence.hpp:278
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:293
Definition: array.hpp:24
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:75
Definition: space_filling_curve.hpp:20
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:56
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
Definition: tile_distribution.hpp:72
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
static constexpr index_t NumAccess
Definition: tile_window.hpp:137
static constexpr index_t VectorDimY
Definition: tile_window.hpp:100
thread_buffer< DataType, ScalarPerVector > vector_t
Definition: tile_window.hpp:106
decltype(get_space_filling_curve()) SFC_Ys
Definition: tile_window.hpp:135
static constexpr index_t ScalarPerVector
Definition: tile_window.hpp:101
This class provides tile (windowed) view and access to the device memory.
Definition: tile_window.hpp:37
static constexpr CK_TILE_DEVICE bool has_static_tile_distribution()
Definition: tile_window.hpp:217
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})) BottomTensorCoord
Definition: tile_window.hpp:73
static constexpr index_t NDimWindowAdaptorTop
Definition: tile_window.hpp:47
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition: tile_window.hpp:67
static constexpr CK_TILE_DEVICE index_t get_num_of_dimension()
Definition: tile_window.hpp:215
constexpr CK_TILE_DEVICE auto get_tile_distribution() const
Definition: tile_window.hpp:224
static constexpr index_t NDimY
Definition: tile_window.hpp:51
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_window.hpp:39
CK_TILE_DEVICE void update_raw(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:801
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_window.hpp:870
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window.hpp:226
constexpr CK_TILE_DEVICE tile_window_with_static_distribution()=default
constexpr CK_TILE_DEVICE auto get_num_of_access() const
Definition: tile_window.hpp:296
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window.hpp:228
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:538
static constexpr CK_TILE_DEVICE auto get_window_adaptor_ys_safe_vector_length_strides()
Definition: tile_window.hpp:257
CK_TILE_DEVICE void store_raw(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}) const
Definition: tile_window.hpp:676
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:311
CK_TILE_DEVICE void store(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:608
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:299
static constexpr index_t NDimBottomTensor
Definition: tile_window.hpp:48
BottomTensorIndex window_origin_
Definition: tile_window.hpp:951
CK_TILE_HOST_DEVICE void init_raw()
Definition: tile_window.hpp:941
CK_TILE_DEVICE void update(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:736
static constexpr auto I0
Definition: tile_window.hpp:53
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window.hpp:222
array< tuple< WindowAdaptorCoord, BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_window.hpp:961
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_window.hpp:40
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_window.hpp:45
CK_TILE_DEVICE void load_raw(DstTile &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:381
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_window.hpp:38
static constexpr auto I1
Definition: tile_window.hpp:54
typename TileDstr::PsYs2XsAdaptor WindowAdaptor
Definition: tile_window.hpp:42
WindowLengths window_lengths_
Definition: tile_window.hpp:948
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition: tile_window.hpp:43
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:457
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition: tile_window.hpp:881
TileDstr tile_dstr_
Definition: tile_window.hpp:956
constexpr CK_TILE_DEVICE void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition: tile_window.hpp:231
BottomTensorView bottom_tensor_view_
Definition: tile_window.hpp:945
constexpr CK_TILE_DEVICE tile_window_with_static_distribution(const BottomTensorView &bottom_tensor_view, const WindowLengths &window_lengths, const BottomTensorIndex &window_origin, const TileDstr &tile_distribution)
Definition: tile_window.hpp:147
static constexpr index_t NumAccessPerCoord
Definition: tile_window.hpp:143
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_window.hpp:239
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})) WindowAdaptorCoord
Definition: tile_window.hpp:70
static constexpr index_t NDimP
Definition: tile_window.hpp:50
array< index_t, NDimWindowAdaptorTop > AdaptorTopIndex
Definition: tile_window.hpp:66
This class provides description of tile windowed view on the device memory.
Definition: tile_window.hpp:1031
static constexpr index_t NDimBottomTensor
Definition: tile_window.hpp:1037
BottomTensorView bottom_tensor_view_
Definition: tile_window.hpp:1080
static constexpr CK_TILE_DEVICE index_t get_num_of_dimension()
Definition: tile_window.hpp:1056
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window.hpp:1058
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window.hpp:1060
constexpr CK_TILE_DEVICE tile_window_with_static_lengths(const BottomTensorView &bottom_tensor_view, const WindowLengths &window_lengths, const BottomTensorIndex &window_origin)
Definition: tile_window.hpp:1046
constexpr CK_TILE_DEVICE tile_window_with_static_lengths()=default
typename BottomTensorView::DataType DataType
Definition: tile_window.hpp:1035
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_window.hpp:1076
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition: tile_window.hpp:1034
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window.hpp:1062
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_window.hpp:1032
BottomTensorIndex window_origin_
Definition: tile_window.hpp:1086
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition: tile_window.hpp:1064
WindowLengths window_lengths_
Definition: tile_window.hpp:1083
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_window.hpp:1033
constexpr CK_TILE_DEVICE void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition: tile_window.hpp:1070
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10