/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_scatter_gather.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_scatter_gather.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_scatter_gather.hpp Source File
tile_scatter_gather.hpp
Go to the documentation of this file.
1 
2 // SPDX-License-Identifier: MIT
3 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
4 
5 #pragma once
6 
19 
20 namespace ck_tile {
21 
// Tile window that performs scatter/gather addressing through a page-index
// table with an optional per-access valid mask.
// NOTE(review): the class declaration line itself (original line 41) and the
// alias lines 43-47, 51, plus lines 65, 72-73, 75-76, 78-79 were dropped by
// the documentation extraction; only the fragments below remain visible.
33 template <typename BottomTensorView_,
34  typename WindowLengths_,
35  typename StaticTileDistribution_,
36  typename StaticPageIndexArray_,
37  typename StaticValidArray_,
38  index_t HsGatherDim = 0,
39  index_t NumCoord = 1,
40  index_t YsGatherDim = 0>
42 {
48  using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
49  using BottomTensorDesc = typename BottomTensorView::TensorDesc;
50 
52 
53  static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
54  static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
55 
56  static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
57  static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
58 
59  static constexpr auto I0 = number<0>{};
60  static constexpr auto I1 = number<1>{};
// Only a single pre-computed coordinate bundle is currently supported.
61  static_assert(NumCoord == 1);
62 
63  // TODO: check WindowLengths and StaticTileDistribution are consistent
64 
66  "wrong! lengths should be static");
67  static_assert(TileDstr::is_static(), "wrong!");
68 
69  static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
70  "wrong! inconsistent # of diemsnions");
71 
74 
77 
80 
// Compile-time load/store traits: chooses which Y dimension is vectorized,
// how many scalars each vector access carries, and the space-filling curve
// that enumerates the per-thread accesses.
// NOTE(review): the struct's declaration line (original 81) is missing from
// this extraction.
82  {
83  private:
// Pick, among the Y dims, the stride-1 dimension with the largest vector
// length; falls back to dim 0 with 1 scalar when none qualifies.
84  static constexpr auto get_vector_dim_y_scalar_per_vector()
85  {
86  const auto [ys_vector_lengths, ys_vector_strides] =
// NOTE(review): the right-hand-side call (original line 87) was dropped by
// the extraction.
88 
89  index_t VectorDimY_ = 0;
90  index_t ScalarPerVector_ = 1;
91 
92  for(index_t i = 0; i < NDimY; ++i)
93  {
94  if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
95  {
96  ScalarPerVector_ = ys_vector_lengths[i];
97  VectorDimY_ = i;
98  }
99  }
100 
101  return make_tuple(VectorDimY_, ScalarPerVector_);
102  }
103 
104  public:
105  static constexpr index_t PackedSize =
// NOTE(review): the PackedSize initializer (original line 106) is missing.
107  static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
108  static constexpr index_t ScalarPerVector =
109  get_vector_dim_y_scalar_per_vector().template at<1>();
110 
111  // using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
112  // using vector_t = typename vector_type_t::type;
114 
115  private:
// Per-Y-dim scalars accessed per step: ScalarPerVector on the vector dim,
// 1 everywhere else.
116  static constexpr auto scalars_per_access_ = [] {
117  constexpr auto scalars_per_access_arr = generate_array(
118  [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
119 
121  constexpr auto NDimY_ = NDimY;
122 
123  return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
124  }();
125 
// Build the space-filling curve over the per-thread Y lengths using the
// scalars_per_access_ step sizes.
126  static constexpr auto get_space_filling_curve()
127  {
128  constexpr auto tile_dstr = TileDstr{};
129 
130  constexpr auto thread_tensor_lengths_ys =
131  to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
132 
133  // FIXME: need logic to judge dim access order
134  using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
135 
136  return space_filling_curve<decltype(thread_tensor_lengths_ys),
137  DimAccessOrder,
138  decltype(scalars_per_access_)>{};
139  }
140 
141  public:
142  using SFC_Ys = decltype(get_space_filling_curve());
143 
144  static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
145 
146  static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
147  static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
148  };
149 
151 
// Default constructor: members are value-initialized.
152  CK_TILE_DEVICE constexpr tile_scatter_gather() = default;
153 
// Main constructor: captures the bottom tensor view, window geometry, page
// index table and valid mask, then pre-computes NumCoord coordinate bundles
// used by subsequent load/store calls.
// NOTE(review): the tile-distribution parameter and its initializer
// (original lines 157, 163, 166) were dropped by the doc extraction, as were
// the call arguments on lines 179, 184, 191-193 and the coordinate-move
// helper call on line 217.
154  CK_TILE_DEVICE constexpr tile_scatter_gather(const BottomTensorView& bottom_tensor_view,
155  const WindowLengths& window_lengths,
156  const BottomTensorIndex& window_origin,
158  const PageIdxArray& page_idx,
159  const ValidArray& valids)
160  : bottom_tensor_view_{bottom_tensor_view},
161  window_lengths_{window_lengths},
162  window_origin_{window_origin},
164  page_idx_{page_idx},
165  valids_{valids},
167  {
168 #if 0 // debug
169  // TODO: this use more register for FA, but less register for GEMM
170  // need investigation
171  // only support warp-tile and block-tile
172  static_assert(NDimP == 1 or NDimP == 2, "wrong!");
173 
174  WindowAdaptorCoord window_adaptor_thread_coord_tmp;
175 
176  if constexpr(NDimP == 1)
177  {
178  window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
180  }
181  else if constexpr(NDimP == 2)
182  {
183  window_adaptor_thread_coord_tmp =
185  AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
186  }
187 #else
188  // TODO: this use less register for FA, but more register for GEMM
189  // need investigation
190  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
194 #endif
195 
196  BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
197  window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
// The gather dimension of the origin is zeroed: its offset is supplied per
// access from page_idx_ rather than carried by the coordinate.
198  bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
199  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
200  bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
201 
202  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
203  // future load/store() calls (might allocate more registers)
204  using Traits = load_store_traits;
205  using SFC_Ys = typename Traits::SFC_Ys;
206 
207  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
208  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
209  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
210 
211  constexpr auto idx_diff_ys =
212  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
213 
214  constexpr auto idx_diff_ps_ys = container_concat(
215  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
216 
218  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
219 
220  pre_computed_coords_(iCoord) =
221  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
222  });
223  }
224 
226 
// NOTE(review): the declaration line of this predicate (original 227) was
// dropped; from the body it reports whether the tile distribution is static.
228  {
229  return TileDstr::is_static();
230  }
231 
// Simple accessors for the window's state.
232  CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
233 
234  CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
235 
236  CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
237 
238  CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
239 
// Rebind the underlying buffer pointer without touching any other window
// state (coordinates, page indices and valid flags are kept as-is).
240  CK_TILE_DEVICE constexpr void
241  set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
242  {
243  bottom_tensor_view_.buf_.p_data_ = data;
244  }
245 
246  // move thread's window adaptor coordinate and bottom tensor coordinate
247  // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
// NOTE(review): the function-name line (original 249) is missing; this is
// the helper that advances both coordinates by a top-index step.
248  template <typename ATopIndex>
250  WindowAdaptorCoord& window_adaptor_thread_coord,
251  BottomTensorCoord& bottom_tensor_thread_coord,
252  const ATopIndex& idx_diff_adaptor_top) const
253  {
254  array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
255 
// Move the adaptor coordinate and capture the induced bottom-index delta...
256  move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
257  window_adaptor_thread_coord,
258  idx_diff_adaptor_top,
259  idx_diff_adaptor_bottom);
260 
// ...then apply that delta to the bottom tensor coordinate.
261  move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
262  bottom_tensor_thread_coord,
263  idx_diff_adaptor_bottom);
264  }
265 
266  // return vector dimension among [y0, y1, ...]
// Propagates the bottom tensor's safe vector lengths/strides through the
// window adaptor and returns the (lengths, strides) pair restricted to the
// Y dimensions; load_store_traits consumes this result.
// NOTE(review): the declaration line (original 267) is missing from this
// extraction.
268  {
269  // bottom tensor top dimension vector lengths and strides
270  const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
271  BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
272 
273  // window vector lengths/strides
274  const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
275  const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
276 
277  // window adaptor [p0, p1, ..., y0, y1, ...]
// -1 marks hidden dimensions whose vector info is not known yet.
278  array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
279  -1};
280  array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
281  -1};
282 
283  constexpr auto window_adaptor_bottom_dims =
284  WindowAdaptor::get_bottom_dimension_hidden_ids();
285 
286  set_container_subset(window_adaptor_vector_lengths,
287  window_adaptor_bottom_dims,
288  window_adaptor_bottom_dim_vector_lengths);
289  set_container_subset(window_adaptor_vector_strides,
290  window_adaptor_bottom_dims,
291  window_adaptor_bottom_dim_vector_strides);
292 
293  const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
294  WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
295  window_adaptor_vector_lengths, window_adaptor_vector_strides);
296 
297  // [y0, y1, ...]
// NOTE(review): the middle line of this arithmetic_sequence_gen (original
// line 299) is missing from the extraction.
298  constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
300  1>::type{};
301 
302  return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
303  get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
304  }
305 
307 
// Load the whole window into a freshly created static distributed tensor.
// NOTE(review): the function's declaration lines (original 309-310) were
// dropped; only the template header and the body survive the extraction.
308  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
311  {
312  constexpr auto tile_dstr = TileDstr{};
313  auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
// Delegates to the load(dst_tensor, ...) overload below.
314  load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
315  return dst_tensor;
316  }
317 
// Gather-load into an existing distributed tensor: each vectorized access
// looks up its page offset via page_idx_[Y-gather index] and, when a valid
// mask is present, also forwards valids_[idx_gather].
// NOTE(review): the trailing signature lines (original 322-323), line 333,
// and the coordinate-move helper call on line 401 are missing from this
// extraction.
318  template <typename DistributedTensor,
319  index_t i_access_unsupport_ = -1,
320  bool oob_conditional_check = true>
321  CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
324  {
325  using Traits = load_store_traits;
326  using vector_t = typename Traits::vector_t;
327  using SFC_Ys = typename Traits::SFC_Ys;
328 
329  constexpr auto tile_dstr = TileDstr{};
330 
331  // loop over thread tensor space [y0, y1, ...]
332  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
334  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
335  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
336 
337  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
338  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
339 
340  // data index [y0, y1, ...]
341  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
342  constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
343  const auto page_offset = page_idx_[idx_gather];
344 
345  // read from bottom tensor
// Without a valid mask (ValidArray == std::nullptr_t) the plain overload is
// used; otherwise the per-access valid flag is forwarded.
346  const vector_t vec_value = [&]() {
347  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
348  {
349  return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
350  bottom_tensor_thread_coord,
351  page_offset,
352  bool_constant<oob_conditional_check>{});
353  }
354  else
355  {
356  return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
357  bottom_tensor_thread_coord,
358  page_offset,
359  valids_[idx_gather],
360  bool_constant<oob_conditional_check>{});
361  }
362  }();
363 #if 1
364  // write into distributed tensor
365  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
366  constexpr auto idx_ys = generate_tuple(
367  [&](auto jj) {
368  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
369  : idx_ys_start[jj];
370  },
371  number<NDimY>{});
372 
373  constexpr index_t d =
374  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
375  Traits::PackedSize;
376 
377  dst_tensor.get_thread_buffer().template at<d>() =
378  vec_value.template get_as<DataType>()[j / Traits::PackedSize];
379  });
380 #else
381  constexpr index_t d =
382  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
383  static_assert(d % Traits::ScalarPerVector == 0);
384 
385  dst_tensor.get_thread_buffer().template get_as<vector_t>()(
386  number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
387 #endif
388  // move thread coordinate
389  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
390  {
391  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
392 
// The gather dim's step is masked to 0: its position is looked up from
// page_idx_ on the next access rather than moved incrementally.
393  constexpr auto forward_step_scatter = generate_tuple(
394  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
395  number<NDimY>{});
396 
397  constexpr auto idx_diff_ps_ys = container_concat(
398  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
399  forward_step_scatter);
400 
402  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
403  }
404  });
405  });
406  }
407 
// Asynchronous gather-load into an LDS tile window: issues
// async_get_vectorized_elements with the destination SMEM address computed
// from the LDS window's own coordinate.
// NOTE(review): the trailing signature lines (original 412-413), line 431,
// and the coordinate-move helper calls on lines 491/493 are missing from
// this extraction.
408  template <typename LdsTileWindow_,
409  index_t i_access_unsupport_ = -1,
410  bool oob_conditional_check = true>
411  CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
414  {
415  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
416  using LdsDataType = typename LdsTileWindow::DataType;
417  using Traits = load_store_traits;
418  using vector_t = typename Traits::vector_t;
419  using SFC_Ys = typename Traits::SFC_Ys;
420 
421  constexpr auto tile_dstr = TileDstr{};
422 
423  // Precompute invariant values outside loops
424  const auto window_origin = lds_tile.get_window_origin();
425  const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
426  const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
427  auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
428 
429  // loop over thread tensor space [y0, y1, ...]
430  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// Separate coordinate pairs track the source (global) and destination (LDS)
// positions; only the source masks the gather dimension when stepping.
432  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
433  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
434 
435  auto lds_window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
436  auto lds_bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
437 
438  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
439  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
440 
441  // Use precomputed window origin
442  auto lds_bottom_tensor_thread_idx =
443  window_origin + lds_window_adaptor_thread_coord.get_bottom_index();
444  // Use precomputed tensor descriptor
445  const auto lds_coord =
446  make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
447  // Calculate SMEM address using base pointer
448  CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
449 
450  // data index [y0, y1, ...]
451  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
452  constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
453  const auto page_offset = page_idx_[idx_gather];
454 
455  // merge page_offset into bottom_coord
456  auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
457  mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;
458 
459  // read from bottom tensor
460  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
461  this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
462  smem,
463  mixed_bottom_thread_coord,
464  number<0>{},
465  bool_constant<oob_conditional_check>{});
466  else
467  this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
468  smem,
469  mixed_bottom_thread_coord,
470  number<0>{},
471  valids_[idx_gather],
472  bool_constant<oob_conditional_check>{});
473 
474  // move thread coordinate
475  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
476  {
477  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
478 
479  constexpr auto forward_step_scatter = generate_tuple(
480  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
481  number<NDimY>{});
482 
483  constexpr auto idx_diff_ps_ys = container_concat(
484  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
485  forward_step_scatter);
486  // lds_diff doesn't need to mask the difference of the gather-dim.
487  constexpr auto lds_idx_diff_ps_ys = container_concat(
488  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
489  idx_diff_ys);
490 
492  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
494  lds_window_adaptor_thread_coord,
495  lds_bottom_tensor_thread_coord,
496  lds_idx_diff_ps_ys);
497  }
498  });
499  });
500  }
501 
502  // TODO: currently async load only implemented in inline asm
// Raw (inline-asm) asynchronous gather-load into LDS, addressing SMEM via
// the m0 register: m0 is seeded per wave and incremented by size_per_issue
// after each access.
// NOTE(review): the trailing signature lines (original 508-509), line 550,
// and the coordinate-move helper call on line 597 are missing from this
// extraction.
503  template <typename LdsTileWindow_,
504  index_t i_access_unsupport_ = -1,
505  bool oob_conditional_check = true,
506  bool pre_nop = false>
507  CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
510  bool_constant<pre_nop> = {}) const
511  {
512  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
513  // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
514  using LdsDataType = typename LdsTileWindow::DataType;
515  // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
516 
517  // issues * warps * lanes
518  static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
519 
// Byte distances derived from the LDS descriptor: base offset, per-wave
// stride, and per-issue stride.
520  const index_t size_per_buf =
521  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
522  make_tuple(number<0>{}, number<0>{}, number<0>{})) *
523  sizeof(LdsDataType);
524 
525  const index_t size_per_wave =
526  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
527  make_tuple(number<0>{}, number<1>{}, number<0>{})) *
528  sizeof(LdsDataType) -
529  size_per_buf;
530 
531  const index_t size_per_issue =
532  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
533  make_tuple(number<1>{}, number<0>{}, number<0>{})) *
534  sizeof(LdsDataType) -
535  size_per_buf;
536 
537  const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
538  m0_set_with_memory(m0_init_value); // This should be wave independent
539 
540  using Traits = load_store_traits;
541 
542  // using vector_type_t = typename Traits::vector_type_t;
543  using vector_t = typename Traits::vector_t;
544  using SFC_Ys = typename Traits::SFC_Ys;
545 
546  LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
547 
548  // loop over thread tensor space [y0, y1, ...]
549  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
551  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
552  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
553 
554  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
555  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// pre_nop is only applied on the very first access of the first bundle.
556  constexpr auto pre_nop_ = [&]() {
557  if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
558  return bool_constant<true>{};
559  else
560  return bool_constant<false>{};
561  }();
562 
563  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
564  constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
565  const auto page_offset = page_idx_[idx_gather];
566 
567  // read from bottom tensor
568  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
569  {
570  get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
571  smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
572  }
573  else
574  {
575  get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
576  smem,
577  bottom_tensor_thread_coord,
578  page_offset,
579  valids_[idx_gather],
580  0,
581  pre_nop_);
582  }
583 
584  // move thread coordinate
585  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
586  {
587  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
588 
589  constexpr auto forward_step_scatter = generate_tuple(
590  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
591  number<NDimY>{});
592 
593  constexpr auto idx_diff_ps_ys = container_concat(
594  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
595  forward_step_scatter);
596 
598  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
599 
// Advance the m0 SMEM base for the next issue.
600  m0_inc_with_memory(size_per_issue);
601  }
602  });
603  });
604  }
605 
// Scatter-update path: reads each vector from the distributed register
// tensor and applies it to the bottom tensor via update_vectorized_elements,
// using page_idx_ (and optionally valids_) per access.
// NOTE(review): the declaration lines (original 607-609, presumably the
// function name and dstr_tensor parameter) and the coordinate-move helper
// call on line 681 are missing from this extraction.
606  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
610  {
611  using Traits = load_store_traits;
612 
613  // using vector_type_t = typename Traits::vector_type_t;
614  using vector_t = typename Traits::vector_t;
615  using SFC_Ys = typename Traits::SFC_Ys;
616 
617  constexpr auto tile_dstr = TileDstr{};
618 
619  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
620  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
621  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
622 
623  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
624  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
625 
626  // data index [y0, y1, ...]
627  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// NOTE(review): gathers on Y dim 0 here, unlike load() which uses
// YsGatherDim — presumably equivalent when YsGatherDim == 0 (the default);
// verify if YsGatherDim != 0 is ever used.
628  constexpr auto idx_gather = idx_ys_start[number<0>{}];
629  const auto page_offset = page_idx_[idx_gather];
630 
631  // read from distributed tensor
632  vector_t vec_value;
633 
634  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
635  constexpr auto idx_ys = generate_tuple(
636  [&](auto jj) {
637  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
638  : idx_ys_start[jj];
639  },
640  number<NDimY>{});
641 
642  constexpr index_t d =
643  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
644  Traits::PackedSize;
645 
646  vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
647  dstr_tensor.get_thread_buffer().template at<d>();
648  });
649 
650  // write into bottom tensor
651  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
652  {
653  get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
654  bottom_tensor_thread_coord,
655  page_offset,
656  vec_value,
657  bool_constant<oob_conditional_check>{});
658  }
659  else
660  {
661  get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
662  bottom_tensor_thread_coord,
663  page_offset,
664  valids_[idx_gather],
665  vec_value,
666  bool_constant<oob_conditional_check>{});
667  }
668 
669  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
670  {
671  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
672 
673  constexpr auto forward_step_scatter = generate_tuple(
674  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
675  number<NDimY>{});
676 
677  constexpr auto idx_diff_ps_ys = container_concat(
678  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
679  forward_step_scatter);
680 
682  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
683  }
684  });
685  });
686  }
687 
// Scatter-store path: reads each vector from the distributed register
// tensor and writes it to the bottom tensor via set_vectorized_elements,
// using page_idx_ (and optionally valids_) per access.
// NOTE(review): the declaration lines (original 689-691, presumably the
// function name and dstr_tensor parameter) and the coordinate-move helper
// call on line 772 are missing from this extraction.
688  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
692  {
693  using Traits = load_store_traits;
694 
695  // using vector_type_t = typename Traits::vector_type_t;
696  using vector_t = typename Traits::vector_t;
697  using SFC_Ys = typename Traits::SFC_Ys;
698 
699  constexpr auto tile_dstr = TileDstr{};
700  // printf("off %d\n", page_idx_[I0]);
701  // loop over thread tensor space [y0, y1, ...]
702  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
703  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
704  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
705 
706  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
707  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
708 
709  // data index [y0, y1, ...]
710  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// NOTE(review): gathers on Y dim 0 here, unlike load() which uses
// YsGatherDim — presumably equivalent when YsGatherDim == 0 (the default);
// verify if YsGatherDim != 0 is ever used.
711  constexpr auto idx_gather = idx_ys_start[number<0>{}];
712  const auto page_offset = page_idx_[idx_gather];
713 
714  // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
715  // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
716 
717  // read from distributed tensor
718  // vector_type_t vec;
719  vector_t vec_value;
720 
721  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
722  constexpr auto idx_ys = generate_tuple(
723  [&](auto jj) {
724  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
725  : idx_ys_start[jj];
726  },
727  number<NDimY>{});
728 
729  constexpr index_t d =
730  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
731  Traits::PackedSize;
732  // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j);
733  vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
734  dstr_tensor.get_thread_buffer().template at<d>();
735  });
736 
737  // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
738 
739  // write into bottom tensor
740  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
741  {
742  get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
743  bottom_tensor_thread_coord,
744  page_offset,
745  vec_value,
746  bool_constant<oob_conditional_check>{});
747  }
748  else
749  {
750  get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
751  bottom_tensor_thread_coord,
752  page_offset,
753  valids_[idx_gather],
754  vec_value,
755  bool_constant<oob_conditional_check>{});
756  }
757 
758  // printf("coord_offset:%d, scatter_offset:%d \n",
759  // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
760  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
761  {
762  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
763 
764  constexpr auto forward_step_scatter = generate_tuple(
765  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
766  number<NDimY>{});
767 
768  constexpr auto idx_diff_ps_ys = container_concat(
769  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
770  forward_step_scatter);
771 
773  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
774  }
775  });
776  });
777  }
778 
779  // move thread's bottom tensor coordinate
780  // [x0', x1', ... ] ==> [offset]
781  // also move window-origin
// NOTE(review): the function declaration line (original 782) is missing;
// from the body this moves the window by `step`, masking the gather
// dimension of the step (its offset comes from page_idx_ per access).
783  {
784  window_origin_ += step;
785  BottomTensorIndex step_new = step;
786  step_new(HsGatherDim) = 0;
787  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
788  move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
789  pre_computed_coords_(iCoord)(I1),
790  step_new);
791  });
792  }
793 
// Replace the page-index table used to translate the gather dimension into
// bottom-tensor offsets on subsequent accesses.
794  CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
795 
796  CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
797  {
798  if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == false)
799  {
800  valids_ = new_valids;
801  }
802  }
803 
// Convenience: update both the page-index table and the valid mask in one
// call. NOTE(review): the declaration line (original 804, presumably
// `update_page_idx_valids(const PageIdxArray& new_idx,`) is missing from
// this extraction.
805  const ValidArray& new_valids)
806  {
807  update_page_idx(new_idx);
808  update_valids(new_valids);
809  }
810 
// Reposition the window at an absolute origin and rebuild the pre-computed
// (window-adaptor, bottom-tensor) coordinate bundles, mirroring the
// constructor's coordinate setup.
// NOTE(review): the AdaptorTopIndex argument line (original 839) and the
// coordinate-move helper call on line 864 are missing from this extraction.
811  CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
812  {
813  window_origin_ = new_window_origin;
814 
815 #if 0 // debug
816  // TODO: this use more register for FA, but less register for GEMM
817  // need investigation
818  // only support warp-tile and block-tile
819  static_assert(NDimP == 1 or NDimP == 2, "wrong!");
820 
821  WindowAdaptorCoord window_adaptor_thread_coord_tmp;
822 
823  if constexpr(NDimP == 1)
824  {
825  window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
826  tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
827  }
828  else if constexpr(NDimP == 2)
829  {
830  window_adaptor_thread_coord_tmp =
831  make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
832  AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
833  }
834 #else
835  // TODO: this use less register for FA, but more register for GEMM
836  // need investigation
837  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
838  tile_dstr_.get_ps_ys_to_xs_adaptor(),
840 #endif
841 
842  BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
843  window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
844 
// Gather dimension zeroed: its offset is supplied per access via page_idx_.
845  bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
846  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
847  bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
848 
849  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
850  // future load/store() calls (might allocate more registers)
851  using Traits = load_store_traits;
852  using SFC_Ys = typename Traits::SFC_Ys;
853 
854  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
855  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
856  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
857 
858  constexpr auto idx_diff_ys =
859  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
860 
861  constexpr auto idx_diff_ps_ys = container_concat(
862  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
863 
865  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
866 
867  pre_computed_coords_(iCoord) =
868  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
869  });
870  }
871 
873 
// NOTE(review): the data-member declaration lines themselves (original 872,
// 876, 879, 882, 887, 889-890, 895 — bottom_tensor_view_, window_lengths_,
// window_origin_, tile_dstr_, page_idx_, valids_, pre_computed_coords_)
// were dropped by the doc extraction; only the explanatory comments remain.
874  // this is the bottom tensor view
875  // [x0', x1', ...] ==> [offset]
877 
878  //
880 
881  // origin ([x0', x1', ...]) of window on bottom tensor
883 
884  // Tile tensor distribution, which contains:
885  // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
886  // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
888 
891 
892  // this contains:
893  // per-thread coordinate for window adaptor
894  // per-thread coordinate for bottom tensor
896 };
897 
898 // TODO: use strategy
// Factory: build a tile_scatter_gather WITHOUT a valid mask — the ValidArray
// template argument is fixed to std::nullptr_t and nullptr is passed.
// NOTE(review): the function-name line (original 906, presumably
// `make_tile_scatter_gather(const TensorView_& tensor_view,`) is missing.
899 template <typename TensorView_,
900  typename WindowLengths_,
901  typename StaticTileDistribution_,
902  typename StaticPageIndexArray_,
903  index_t HsGatherDim = 0,
904  index_t NumCoord = 1>
905 CK_TILE_DEVICE constexpr auto
907  const WindowLengths_& window_lengths,
908  const multi_index<TensorView_::get_num_of_dimension()>& origin,
909  const StaticTileDistribution_& tile_distribution,
910  const StaticPageIndexArray_& page_idx,
911  number<HsGatherDim> = {},
912  number<NumCoord> = {})
913 {
914  return tile_scatter_gather<remove_cvref_t<TensorView_>,
915  remove_cvref_t<WindowLengths_>,
916  remove_cvref_t<StaticTileDistribution_>,
917  remove_cvref_t<StaticPageIndexArray_>,
918  std::nullptr_t,
919  HsGatherDim,
920  NumCoord>{
921  tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
922 }
923 
// Factory overload (no valid mask): rebuild from an existing window plus an
// explicit origin. NOTE(review): the return-type/name lines (original
// 929-930) and the callee line 936 are missing; from the surviving argument
// list this forwards to the primary factory above.
924 template <typename TensorView,
925  typename WindowLengths,
926  typename StaticTileDistribution,
927  typename StaticPageIndexArray,
928  index_t HsGatherDim>
931  const multi_index<TensorView::get_num_of_dimension()>& origin,
932  const StaticTileDistribution& tile_distribution,
933  const StaticPageIndexArray& page_idx,
934  number<HsGatherDim> = {})
935 {
937  tile_window.get_window_lengths(),
938  origin,
939  tile_distribution,
940  page_idx,
941  number<HsGatherDim>{});
942 }
943 
// Factory overload (no valid mask): rebuild from an existing window, reusing
// that window's own origin. NOTE(review): the return-type/name lines
// (original 949-950) and the callee line 955 are missing from this
// extraction.
944 template <typename TensorView,
945  typename WindowLengths,
946  typename StaticTileDistribution,
947  typename StaticPageIndexArray,
948  index_t HsGatherDim>
951  const StaticTileDistribution& tile_distribution,
952  const StaticPageIndexArray& page_idx,
953  number<HsGatherDim> = {})
954 {
956  tile_window.get_window_lengths(),
957  tile_window.get_window_origin(),
958  tile_distribution,
959  page_idx,
960  number<HsGatherDim>{});
961 }
962 
// Factory: build a tile_scatter_gather WITH a per-access valid mask.
// NOTE(review): the function-name line (original 971) is missing from this
// extraction.
963 template <typename TensorView_,
964  typename WindowLengths_,
965  typename StaticTileDistribution_,
966  typename StaticPageIndexArray_,
967  typename StaticValidArray_,
968  index_t HsGatherDim = 0,
969  index_t NumCoord = 1>
970 CK_TILE_DEVICE constexpr auto
972  const WindowLengths_& window_lengths,
973  const multi_index<TensorView_::get_num_of_dimension()>& origin,
974  const StaticTileDistribution_& tile_distribution,
975  const StaticPageIndexArray_& page_idx,
976  const StaticValidArray_& valids,
977  number<HsGatherDim> = {},
978  number<NumCoord> = {})
979 {
980  return tile_scatter_gather<remove_cvref_t<TensorView_>,
981  remove_cvref_t<WindowLengths_>,
982  remove_cvref_t<StaticTileDistribution_>,
983  remove_cvref_t<StaticPageIndexArray_>,
984  remove_cvref_t<StaticValidArray_>,
985  HsGatherDim,
986  NumCoord>{
987  tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
988 }
989 
// Factory overload (with valid mask): rebuild from an existing window plus
// an explicit origin; forwards to the with-valids factory above.
// NOTE(review): the return-type/name lines (original 996-997) are missing
// from this extraction.
990 template <typename TensorView,
991  typename WindowLengths,
992  typename StaticTileDistribution,
993  typename StaticPageIndexArray,
994  typename StaticValidArray,
995  index_t HsGatherDim>
998  const multi_index<TensorView::get_num_of_dimension()>& origin,
999  const StaticTileDistribution& tile_distribution,
1000  const StaticPageIndexArray& page_idx,
1001  const StaticValidArray& valids,
1002  number<HsGatherDim> = {})
1003 {
1004  return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
1005  tile_window.get_window_lengths(),
1006  origin,
1007  tile_distribution,
1008  page_idx,
1009  valids,
1010  number<HsGatherDim>{});
1011 }
1012 
// Factory overload (with valid mask): rebuild from an existing window,
// reusing that window's own origin; forwards to the with-valids factory.
// NOTE(review): the return-type/name lines (original 1019-1020) are missing
// from this extraction.
1013 template <typename TensorView,
1014  typename WindowLengths,
1015  typename StaticTileDistribution,
1016  typename StaticPageIndexArray,
1017  typename StaticValidArray,
1018  index_t HsGatherDim>
1021  const StaticTileDistribution& tile_distribution,
1022  const StaticPageIndexArray& page_idx,
1023  const StaticValidArray& valids,
1024  number<HsGatherDim> = {})
1025 {
1026  return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
1027  tile_window.get_window_lengths(),
1028  tile_window.get_window_origin(),
1029  tile_distribution,
1030  page_idx,
1031  valids,
1032  number<HsGatherDim>{});
1033 }
1034 
1035 template <typename NewTensorView_,
1036  typename OldTensorView_,
1037  typename WindowLengths_,
1038  typename StaticTileDistribution_,
1039  typename StaticPageIndexArray_,
1040  typename StaticValidArray_,
1041  index_t HsGatherDim = 0,
1042  index_t NumCoord = 1>
1043 CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_& new_tensor_view,
1044  const tile_scatter_gather<OldTensorView_,
1045  WindowLengths_,
1046  StaticTileDistribution_,
1047  StaticPageIndexArray_,
1048  StaticValidArray_,
1049  HsGatherDim,
1050  NumCoord>& tile_window)
1051 {
1052  return make_tile_scatter_gather(new_tensor_view,
1053  tile_window.window_lengths_,
1054  tile_window.window_origin_,
1055  tile_window.tile_dstr_,
1056  tile_window.page_idx_,
1057  tile_window.valids_);
1058 }
1059 
1060 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_ &new_tensor_view, const tile_scatter_gather< OldTensorView_, WindowLengths_, StaticTileDistribution_, StaticPageIndexArray_, StaticValidArray_, HsGatherDim, NumCoord > &tile_window)
Definition: tile_scatter_gather.hpp:1043
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constant< b > bool_constant
Definition: integral_constant.hpp:43
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1115
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constant< v > number
Definition: integral_constant.hpp:37
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1055
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
typename std::remove_reference< T >::type remove_reference_t
Definition: type_traits.hpp:15
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_DEVICE auto make_tile_scatter_gather(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, const StaticPageIndexArray_ &page_idx, number< HsGatherDim >={}, number< NumCoord >={})
Definition: tile_scatter_gather.hpp:906
constexpr CK_TILE_HOST_DEVICE void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition: tensor_adaptor_coordinate.hpp:97
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:87
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
Definition: sequence.hpp:287
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:302
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Definition: numeric.hpp:81
Definition: space_filling_curve.hpp:20
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
Definition: debug.hpp:67
Definition: tile_distribution.hpp:72
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
Definition: tile_scatter_gather.hpp:82
static constexpr index_t PackedSize
Definition: tile_scatter_gather.hpp:105
static constexpr index_t NumAccess
Definition: tile_scatter_gather.hpp:144
decltype(get_space_filling_curve()) SFC_Ys
Definition: tile_scatter_gather.hpp:142
static constexpr index_t VectorDimY
Definition: tile_scatter_gather.hpp:107
static constexpr index_t ScalarPerVector
Definition: tile_scatter_gather.hpp:108
This class provides tile (windowed) view and access to the device memory.
Definition: tile_scatter_gather.hpp:42
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_scatter_gather.hpp:782
static constexpr index_t NumAccessPerCoord
Definition: tile_scatter_gather.hpp:150
static constexpr auto I1
Definition: tile_scatter_gather.hpp:60
constexpr CK_TILE_DEVICE tile_scatter_gather(const BottomTensorView &bottom_tensor_view, const WindowLengths &window_lengths, const BottomTensorIndex &window_origin, const TileDstr &tile_distribution, const PageIdxArray &page_idx, const ValidArray &valids)
Definition: tile_scatter_gather.hpp:154
BottomTensorIndex window_origin_
Definition: tile_scatter_gather.hpp:882
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:309
WindowLengths window_lengths_
Definition: tile_scatter_gather.hpp:879
constexpr CK_TILE_DEVICE auto get_tile_distribution() const
Definition: tile_scatter_gather.hpp:234
constexpr CK_TILE_DEVICE auto get_num_of_access() const
Definition: tile_scatter_gather.hpp:306
static constexpr index_t NDimBottomTensor
Definition: tile_scatter_gather.hpp:54
static constexpr CK_TILE_DEVICE auto get_window_adaptor_ys_safe_vector_length_strides()
Definition: tile_scatter_gather.hpp:267
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition: tile_scatter_gather.hpp:73
PageIdxArray page_idx_
Definition: tile_scatter_gather.hpp:889
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:411
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_scatter_gather.hpp:44
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition: tile_scatter_gather.hpp:811
array< tuple< WindowAdaptorCoord, BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_scatter_gather.hpp:895
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_scatter_gather.hpp:238
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_scatter_gather.hpp:45
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_scatter_gather.hpp:249
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:321
CK_TILE_DEVICE void store(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:689
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray &new_idx, const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:804
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition: tile_scatter_gather.hpp:49
TileDstr tile_dstr_
Definition: tile_scatter_gather.hpp:887
ValidArray valids_
Definition: tile_scatter_gather.hpp:890
static constexpr index_t NDimY
Definition: tile_scatter_gather.hpp:57
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_scatter_gather.hpp:51
static constexpr index_t NDimWindowAdaptorTop
Definition: tile_scatter_gather.hpp:53
static constexpr CK_TILE_DEVICE bool has_static_tile_distribution()
Definition: tile_scatter_gather.hpp:227
remove_cvref_t< StaticValidArray_ > ValidArray
Definition: tile_scatter_gather.hpp:47
static constexpr index_t NDimP
Definition: tile_scatter_gather.hpp:56
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_scatter_gather.hpp:43
constexpr CK_TILE_DEVICE tile_scatter_gather()=default
CK_TILE_DEVICE void update(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:607
remove_cvref_t< StaticPageIndexArray_ > PageIdxArray
Definition: tile_scatter_gather.hpp:46
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_scatter_gather.hpp:232
CK_TILE_HOST_DEVICE void init_raw()
Definition: tile_scatter_gather.hpp:872
static constexpr auto I0
Definition: tile_scatter_gather.hpp:59
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})) BottomTensorCoord
Definition: tile_scatter_gather.hpp:79
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_scatter_gather.hpp:236
typename TileDstr::PsYs2XsAdaptor WindowAdaptor
Definition: tile_scatter_gather.hpp:48
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})) WindowAdaptorCoord
Definition: tile_scatter_gather.hpp:76
constexpr CK_TILE_DEVICE void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition: tile_scatter_gather.hpp:241
BottomTensorView bottom_tensor_view_
Definition: tile_scatter_gather.hpp:876
CK_TILE_DEVICE void update_valids(const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:796
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_scatter_gather.hpp:507
array< index_t, NDimWindowAdaptorTop > AdaptorTopIndex
Definition: tile_scatter_gather.hpp:72
CK_TILE_DEVICE void update_page_idx(const PageIdxArray &new_idx)
Definition: tile_scatter_gather.hpp:794
static constexpr CK_TILE_DEVICE index_t get_num_of_dimension()
Definition: tile_scatter_gather.hpp:225
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
BottomTensorIndex window_origin_
Definition: tile_window_base.hpp:79
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
WindowLengths window_lengths_
Definition: tile_window_base.hpp:81
This class provides description of tile windowed view on the device memory.
Definition: tile_window.hpp:1016
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10