Composable Kernel — source listing of include/ck_tile/core/tensor/tile_scatter_gather.hpp
(tile_scatter_gather.hpp, develop branch; extracted from generated documentation)
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
18 
19 namespace ck_tile {
20 
32 template <typename BottomTensorView_,
33  typename WindowLengths_,
34  typename StaticTileDistribution_,
35  typename StaticPageIndexArray_,
36  typename StaticValidArray_,
37  index_t HsGatherDim = 0,
38  index_t NumCoord = 1,
39  index_t YsGatherDim = 0>
41 {
47  using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
48  using BottomTensorDesc = typename BottomTensorView::TensorDesc;
49 
51 
52  static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
53  static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
54 
55  static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
56  static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
57 
58  static constexpr auto I0 = number<0>{};
59  static constexpr auto I1 = number<1>{};
60  static_assert(NumCoord == 1);
61 
62  // TODO: check WindowLengths and StaticTileDistribution are consistent
63 
65  "wrong! lengths should be static");
66  static_assert(TileDstr::is_static(), "wrong!");
67 
68  static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
69  "wrong! inconsistent # of diemsnions");
70 
73 
76 
79 
81  {
82  private:
83  static constexpr auto get_vector_dim_y_scalar_per_vector()
84  {
85  const auto [ys_vector_lengths, ys_vector_strides] =
87 
88  index_t VectorDimY_ = 0;
89  index_t ScalarPerVector_ = 1;
90 
91  for(index_t i = 0; i < NDimY; ++i)
92  {
93  if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
94  {
95  ScalarPerVector_ = ys_vector_lengths[i];
96  VectorDimY_ = i;
97  }
98  }
99 
100  return make_tuple(VectorDimY_, ScalarPerVector_);
101  }
102 
103  public:
104  static constexpr index_t PackedSize =
106  static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
107  static constexpr index_t ScalarPerVector =
108  get_vector_dim_y_scalar_per_vector().template at<1>();
109 
110  // using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
111  // using vector_t = typename vector_type_t::type;
113 
114  private:
115  static constexpr auto scalars_per_access_ = [] {
116  constexpr auto scalars_per_access_arr = generate_array(
117  [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
118 
120  constexpr auto NDimY_ = NDimY;
121 
122  return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
123  }();
124 
125  static constexpr auto get_space_filling_curve()
126  {
127  constexpr auto tile_dstr = TileDstr{};
128 
129  constexpr auto thread_tensor_lengths_ys =
130  to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
131 
132  // FIXME: need logic to judge dim access order
133  using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
134 
135  return space_filling_curve<decltype(thread_tensor_lengths_ys),
136  DimAccessOrder,
137  decltype(scalars_per_access_)>{};
138  }
139 
140  public:
141  using SFC_Ys = decltype(get_space_filling_curve());
142 
143  static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
144 
145  static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
146  static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
147  };
148 
150 
151  CK_TILE_DEVICE constexpr tile_scatter_gather() = default;
152 
153  CK_TILE_DEVICE constexpr tile_scatter_gather(const BottomTensorView& bottom_tensor_view,
154  const WindowLengths& window_lengths,
155  const BottomTensorIndex& window_origin,
157  const PageIdxArray& page_idx,
158  const ValidArray& valids)
159  : bottom_tensor_view_{bottom_tensor_view},
160  window_lengths_{window_lengths},
161  window_origin_{window_origin},
163  page_idx_{page_idx},
164  valids_{valids},
166  {
167 #if 0 // debug
168  // TODO: this use more register for FA, but less register for GEMM
169  // need investigation
170  // only support warp-tile and block-tile
171  static_assert(NDimP == 1 or NDimP == 2, "wrong!");
172 
173  WindowAdaptorCoord window_adaptor_thread_coord_tmp;
174 
175  if constexpr(NDimP == 1)
176  {
177  window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
179  }
180  else if constexpr(NDimP == 2)
181  {
182  window_adaptor_thread_coord_tmp =
184  AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
185  }
186 #else
187  // TODO: this use less register for FA, but more register for GEMM
188  // need investigation
189  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
192 #endif
193 
194  BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
195  window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
196  bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
197  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
198  bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
199 
200  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
201  // future load/store() calls (might allocate more registers)
202  using Traits = load_store_traits;
203  using SFC_Ys = typename Traits::SFC_Ys;
204 
205  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
206  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
207  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
208 
209  constexpr auto idx_diff_ys =
210  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
211 
212  constexpr auto idx_diff_ps_ys = container_concat(
213  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
214 
216  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
217 
218  pre_computed_coords_(iCoord) =
219  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
220  });
221  }
222 
224 
226  {
227  return TileDstr::is_static();
228  }
229 
230  CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
231 
232  CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
233 
234  CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
235 
236  CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
237 
238  CK_TILE_DEVICE constexpr void
239  set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
240  {
241  bottom_tensor_view_.buf_.p_data_ = data;
242  }
243 
244  // move thread's window adaptor coordinate and bottom tensor coordinate
245  // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
246  template <typename ATopIndex>
248  WindowAdaptorCoord& window_adaptor_thread_coord,
249  BottomTensorCoord& bottom_tensor_thread_coord,
250  const ATopIndex& idx_diff_adaptor_top) const
251  {
252  array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
253 
254  move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
255  window_adaptor_thread_coord,
256  idx_diff_adaptor_top,
257  idx_diff_adaptor_bottom);
258 
259  move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
260  bottom_tensor_thread_coord,
261  idx_diff_adaptor_bottom);
262  }
263 
264  // return vector dimension among [y0, y1, ...]
266  {
267  // bottom tensor top dimension vector lengths and strides
268  const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
269  BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
270 
271  // window vector lengths/strides
272  const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
273  const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
274 
275  // window adaptor [p0, p1, ..., y0, y1, ...]
276  array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
277  -1};
278  array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
279  -1};
280 
281  constexpr auto window_adaptor_bottom_dims =
282  WindowAdaptor::get_bottom_dimension_hidden_ids();
283 
284  set_container_subset(window_adaptor_vector_lengths,
285  window_adaptor_bottom_dims,
286  window_adaptor_bottom_dim_vector_lengths);
287  set_container_subset(window_adaptor_vector_strides,
288  window_adaptor_bottom_dims,
289  window_adaptor_bottom_dim_vector_strides);
290 
291  const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
292  WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
293  window_adaptor_vector_lengths, window_adaptor_vector_strides);
294 
295  // [y0, y1, ...]
296  constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
298  1>::type{};
299 
300  return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
301  get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
302  }
303 
305 
306  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
309  {
310  constexpr auto tile_dstr = TileDstr{};
311  auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
312  load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
313  return dst_tensor;
314  }
315 
316  template <typename DistributedTensor,
317  index_t i_access_unsupport_ = -1,
318  bool oob_conditional_check = true>
319  CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
322  {
323  using Traits = load_store_traits;
324  using vector_t = typename Traits::vector_t;
325  using SFC_Ys = typename Traits::SFC_Ys;
326 
327  constexpr auto tile_dstr = TileDstr{};
328 
329  // loop over thread tensor space [y0, y1, ...]
330  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
332  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
333  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
334 
335  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
336  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
337 
338  // data index [y0, y1, ...]
339  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
340  constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
341  const auto page_offset = page_idx_[idx_gather];
342 
343  // read from bottom tensor
344  const vector_t vec_value = [&]() {
345  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
346  {
347  return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
348  bottom_tensor_thread_coord,
349  page_offset,
350  bool_constant<oob_conditional_check>{});
351  }
352  else
353  {
354  return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
355  bottom_tensor_thread_coord,
356  page_offset,
357  valids_[idx_gather],
358  bool_constant<oob_conditional_check>{});
359  }
360  }();
361 #if 1
362  // write into distributed tensor
363  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
364  constexpr auto idx_ys = generate_tuple(
365  [&](auto jj) {
366  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
367  : idx_ys_start[jj];
368  },
369  number<NDimY>{});
370 
371  constexpr index_t d =
372  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
373  Traits::PackedSize;
374 
375  dst_tensor.get_thread_buffer().template at<d>() =
376  vec_value.template get_as<DataType>()[j / Traits::PackedSize];
377  });
378 #else
379  constexpr index_t d =
380  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
381  static_assert(d % Traits::ScalarPerVector == 0);
382 
383  dst_tensor.get_thread_buffer().template get_as<vector_t>()(
384  number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
385 #endif
386  // move thread coordinate
387  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
388  {
389  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
390 
391  constexpr auto forward_step_scatter = generate_tuple(
392  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
393  number<NDimY>{});
394 
395  constexpr auto idx_diff_ps_ys = container_concat(
396  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
397  forward_step_scatter);
398 
400  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
401  }
402  });
403  });
404  }
405 
406  template <typename LdsTileWindow_,
407  index_t i_access_unsupport_ = -1,
408  bool oob_conditional_check = true>
409  CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
412  {
413  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
414  using LdsDataType = typename LdsTileWindow::DataType;
415  using Traits = load_store_traits;
416  using vector_t = typename Traits::vector_t;
417  using SFC_Ys = typename Traits::SFC_Ys;
418 
419  constexpr auto tile_dstr = TileDstr{};
420 
421  // Precompute invariant values outside loops
422  const auto window_origin = lds_tile.get_window_origin();
423  const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
424  const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
425  auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
426 
427  // loop over thread tensor space [y0, y1, ...]
428  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
430  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
431  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
432 
433  auto lds_window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
434  auto lds_bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
435 
436  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
437  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
438 
439  // Use precomputed window origin
440  auto lds_bottom_tensor_thread_idx =
441  window_origin + lds_window_adaptor_thread_coord.get_bottom_index();
442  // Use precomputed tensor descriptor
443  const auto lds_coord =
444  make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
445  // Calculate SMEM address using base pointer
446  CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
447 
448  // data index [y0, y1, ...]
449  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
450  constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
451  const auto page_offset = page_idx_[idx_gather];
452 
453  // merge page_offset into bottom_coord
454  auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
455  mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;
456 
457  // read from bottom tensor
458  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
459  this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
460  smem,
461  mixed_bottom_thread_coord,
462  number<0>{},
463  bool_constant<oob_conditional_check>{});
464  else
465  this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
466  smem,
467  mixed_bottom_thread_coord,
468  number<0>{},
469  valids_[idx_gather],
470  bool_constant<oob_conditional_check>{});
471 
472  // move thread coordinate
473  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
474  {
475  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
476 
477  constexpr auto forward_step_scatter = generate_tuple(
478  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
479  number<NDimY>{});
480 
481  constexpr auto idx_diff_ps_ys = container_concat(
482  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
483  forward_step_scatter);
484  // lds_diff doesn't need to mask the difference of the gather-dim.
485  constexpr auto lds_idx_diff_ps_ys = container_concat(
486  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
487  idx_diff_ys);
488 
490  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
492  lds_window_adaptor_thread_coord,
493  lds_bottom_tensor_thread_coord,
494  lds_idx_diff_ps_ys);
495  }
496  });
497  });
498  }
499 
500  // TODO: currently async load only implemented in inline asm
501  template <typename LdsTileWindow_,
502  index_t i_access_unsupport_ = -1,
503  bool oob_conditional_check = true,
504  bool pre_nop = false>
505  CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
508  bool_constant<pre_nop> = {}) const
509  {
510  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
511  // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
512  using LdsDataType = typename LdsTileWindow::DataType;
513  // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
514 
515  // issues * warps * lanes
516  static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
517 
518  const index_t size_per_buf =
519  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
520  make_tuple(number<0>{}, number<0>{}, number<0>{})) *
521  sizeof(LdsDataType);
522 
523  const index_t size_per_wave =
524  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
525  make_tuple(number<0>{}, number<1>{}, number<0>{})) *
526  sizeof(LdsDataType) -
527  size_per_buf;
528 
529  const index_t size_per_issue =
530  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
531  make_tuple(number<1>{}, number<0>{}, number<0>{})) *
532  sizeof(LdsDataType) -
533  size_per_buf;
534 
535  const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
537  amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
538 
539  using Traits = load_store_traits;
540 
541  // using vector_type_t = typename Traits::vector_type_t;
542  using vector_t = typename Traits::vector_t;
543  using SFC_Ys = typename Traits::SFC_Ys;
544 
545  LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
546 
547  // loop over thread tensor space [y0, y1, ...]
548  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
550  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
551  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
552 
553  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
554  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
555  constexpr auto pre_nop_ = [&]() {
556  if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
557  return bool_constant<true>{};
558  else
559  return bool_constant<false>{};
560  }();
561 
562  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
563  constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
564  const auto page_offset = page_idx_[idx_gather];
565 
566  // read from bottom tensor
567  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
568  {
569  get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
570  smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
571  }
572  else
573  {
574  get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
575  smem,
576  bottom_tensor_thread_coord,
577  page_offset,
578  valids_[idx_gather],
579  0,
580  pre_nop_);
581  }
582 
583  // move thread coordinate
584  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
585  {
586  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
587 
588  constexpr auto forward_step_scatter = generate_tuple(
589  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
590  number<NDimY>{});
591 
592  constexpr auto idx_diff_ps_ys = container_concat(
593  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
594  forward_step_scatter);
595 
597  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
598 
599  m0_inc_with_memory(size_per_issue);
600  }
601  });
602  });
603  }
604 
605  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
609  {
610  using Traits = load_store_traits;
611 
612  // using vector_type_t = typename Traits::vector_type_t;
613  using vector_t = typename Traits::vector_t;
614  using SFC_Ys = typename Traits::SFC_Ys;
615 
616  constexpr auto tile_dstr = TileDstr{};
617 
618  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
619  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
620  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
621 
622  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
623  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
624 
625  // data index [y0, y1, ...]
626  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
627  constexpr auto idx_gather = idx_ys_start[number<0>{}];
628  const auto page_offset = page_idx_[idx_gather];
629 
630  // read from distributed tensor
631  vector_t vec_value;
632 
633  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
634  constexpr auto idx_ys = generate_tuple(
635  [&](auto jj) {
636  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
637  : idx_ys_start[jj];
638  },
639  number<NDimY>{});
640 
641  constexpr index_t d =
642  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
643  Traits::PackedSize;
644 
645  vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
646  dstr_tensor.get_thread_buffer().template at<d>();
647  });
648 
649  // write into bottom tensor
650  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
651  {
652  get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
653  bottom_tensor_thread_coord,
654  page_offset,
655  vec_value,
656  bool_constant<oob_conditional_check>{});
657  }
658  else
659  {
660  get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
661  bottom_tensor_thread_coord,
662  page_offset,
663  valids_[idx_gather],
664  vec_value,
665  bool_constant<oob_conditional_check>{});
666  }
667 
668  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
669  {
670  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
671 
672  constexpr auto forward_step_scatter = generate_tuple(
673  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
674  number<NDimY>{});
675 
676  constexpr auto idx_diff_ps_ys = container_concat(
677  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
678  forward_step_scatter);
679 
681  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
682  }
683  });
684  });
685  }
686 
687  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
691  {
692  using Traits = load_store_traits;
693 
694  // using vector_type_t = typename Traits::vector_type_t;
695  using vector_t = typename Traits::vector_t;
696  using SFC_Ys = typename Traits::SFC_Ys;
697 
698  constexpr auto tile_dstr = TileDstr{};
699  // printf("off %d\n", page_idx_[I0]);
700  // loop over thread tensor space [y0, y1, ...]
701  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
702  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
703  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
704 
705  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
706  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
707 
708  // data index [y0, y1, ...]
709  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
710  constexpr auto idx_gather = idx_ys_start[number<0>{}];
711  const auto page_offset = page_idx_[idx_gather];
712 
713  // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
714  // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
715 
716  // read from distributed tensor
717  // vector_type_t vec;
718  vector_t vec_value;
719 
720  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
721  constexpr auto idx_ys = generate_tuple(
722  [&](auto jj) {
723  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
724  : idx_ys_start[jj];
725  },
726  number<NDimY>{});
727 
728  constexpr index_t d =
729  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
730  Traits::PackedSize;
731  // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j);
732  vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
733  dstr_tensor.get_thread_buffer().template at<d>();
734  });
735 
736  // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
737 
738  // write into bottom tensor
739  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
740  {
741  get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
742  bottom_tensor_thread_coord,
743  page_offset,
744  vec_value,
745  bool_constant<oob_conditional_check>{});
746  }
747  else
748  {
749  get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
750  bottom_tensor_thread_coord,
751  page_offset,
752  valids_[idx_gather],
753  vec_value,
754  bool_constant<oob_conditional_check>{});
755  }
756 
757  // printf("coord_offset:%d, scatter_offset:%d \n",
758  // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
759  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
760  {
761  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
762 
763  constexpr auto forward_step_scatter = generate_tuple(
764  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
765  number<NDimY>{});
766 
767  constexpr auto idx_diff_ps_ys = container_concat(
768  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
769  forward_step_scatter);
770 
772  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
773  }
774  });
775  });
776  }
777 
778  // move thread's botom tensor coordiante
779  // [x0', x1', ... ] ==> [offset]
780  // also move window-origin
782  {
783  window_origin_ += step;
784  BottomTensorIndex step_new = step;
785  step_new(HsGatherDim) = 0;
786  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
787  move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
788  pre_computed_coords_(iCoord)(I1),
789  step_new);
790  });
791  }
792 
793  CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
794 
795  CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
796  {
797  if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == false)
798  {
799  valids_ = new_valids;
800  }
801  }
802 
804  const ValidArray& new_valids)
805  {
806  update_page_idx(new_idx);
807  update_valids(new_valids);
808  }
809 
810  CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
811  {
812  window_origin_ = new_window_origin;
813 
814 #if 0 // debug
815  // TODO: this use more register for FA, but less register for GEMM
816  // need investigation
817  // only support warp-tile and block-tile
818  static_assert(NDimP == 1 or NDimP == 2, "wrong!");
819 
820  WindowAdaptorCoord window_adaptor_thread_coord_tmp;
821 
822  if constexpr(NDimP == 1)
823  {
824  window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
825  tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
826  }
827  else if constexpr(NDimP == 2)
828  {
829  window_adaptor_thread_coord_tmp =
830  make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
831  AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
832  }
833 #else
834  // TODO: this use less register for FA, but more register for GEMM
835  // need investigation
836  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
837  tile_dstr_.get_ps_ys_to_xs_adaptor(),
839 #endif
840 
841  BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
842  window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
843 
844  bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
845  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
846  bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
847 
848  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
849  // future load/store() calls (might allocate more registers)
850  using Traits = load_store_traits;
851  using SFC_Ys = typename Traits::SFC_Ys;
852 
853  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
854  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
855  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
856 
857  constexpr auto idx_diff_ys =
858  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
859 
860  constexpr auto idx_diff_ps_ys = container_concat(
861  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
862 
864  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
865 
866  pre_computed_coords_(iCoord) =
867  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
868  });
869  }
870 
872 
873  // this is the bottom tensor view
874  // [x0', x1', ...] ==> [offset]
876 
877  //
879 
880  // origin ([x0', x1', ...]) of window on bottom tensor
882 
883  // Tile tensor distribution, which contains:
884  // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
885  // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
887 
890 
891  // this contains:
892  // per-thread coordinate for window adaptor
893  // per-thread coordinate for bottom tensor
895 };
896 
897 // TODO: use strategy
898 template <typename TensorView_,
899  typename WindowLengths_,
900  typename StaticTileDistribution_,
901  typename StaticPageIndexArray_,
902  index_t HsGatherDim = 0,
903  index_t NumCoord = 1>
904 CK_TILE_DEVICE constexpr auto
906  const WindowLengths_& window_lengths,
907  const multi_index<TensorView_::get_num_of_dimension()>& origin,
908  const StaticTileDistribution_& tile_distribution,
909  const StaticPageIndexArray_& page_idx,
910  number<HsGatherDim> = {},
911  number<NumCoord> = {})
912 {
913  return tile_scatter_gather<remove_cvref_t<TensorView_>,
914  remove_cvref_t<WindowLengths_>,
915  remove_cvref_t<StaticTileDistribution_>,
916  remove_cvref_t<StaticPageIndexArray_>,
917  std::nullptr_t,
918  HsGatherDim,
919  NumCoord>{
920  tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
921 }
922 
923 template <typename TensorView,
924  typename WindowLengths,
925  typename StaticTileDistribution,
926  typename StaticPageIndexArray,
927  index_t HsGatherDim>
930  const multi_index<TensorView::get_num_of_dimension()>& origin,
931  const StaticTileDistribution& tile_distribution,
932  const StaticPageIndexArray& page_idx,
933  number<HsGatherDim> = {})
934 {
936  tile_window.get_window_lengths(),
937  origin,
938  tile_distribution,
939  page_idx,
940  number<HsGatherDim>{});
941 }
942 
943 template <typename TensorView,
944  typename WindowLengths,
945  typename StaticTileDistribution,
946  typename StaticPageIndexArray,
947  index_t HsGatherDim>
950  const StaticTileDistribution& tile_distribution,
951  const StaticPageIndexArray& page_idx,
952  number<HsGatherDim> = {})
953 {
955  tile_window.get_window_lengths(),
956  tile_window.get_window_origin(),
957  tile_distribution,
958  page_idx,
959  number<HsGatherDim>{});
960 }
961 
962 template <typename TensorView_,
963  typename WindowLengths_,
964  typename StaticTileDistribution_,
965  typename StaticPageIndexArray_,
966  typename StaticValidArray_,
967  index_t HsGatherDim = 0,
968  index_t NumCoord = 1>
969 CK_TILE_DEVICE constexpr auto
971  const WindowLengths_& window_lengths,
972  const multi_index<TensorView_::get_num_of_dimension()>& origin,
973  const StaticTileDistribution_& tile_distribution,
974  const StaticPageIndexArray_& page_idx,
975  const StaticValidArray_& valids,
976  number<HsGatherDim> = {},
977  number<NumCoord> = {})
978 {
979  return tile_scatter_gather<remove_cvref_t<TensorView_>,
980  remove_cvref_t<WindowLengths_>,
981  remove_cvref_t<StaticTileDistribution_>,
982  remove_cvref_t<StaticPageIndexArray_>,
983  remove_cvref_t<StaticValidArray_>,
984  HsGatherDim,
985  NumCoord>{
986  tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
987 }
988 
989 template <typename TensorView,
990  typename WindowLengths,
991  typename StaticTileDistribution,
992  typename StaticPageIndexArray,
993  typename StaticValidArray,
994  index_t HsGatherDim>
997  const multi_index<TensorView::get_num_of_dimension()>& origin,
998  const StaticTileDistribution& tile_distribution,
999  const StaticPageIndexArray& page_idx,
1000  const StaticValidArray& valids,
1001  number<HsGatherDim> = {})
1002 {
1003  return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
1004  tile_window.get_window_lengths(),
1005  origin,
1006  tile_distribution,
1007  page_idx,
1008  valids,
1009  number<HsGatherDim>{});
1010 }
1011 
1012 template <typename TensorView,
1013  typename WindowLengths,
1014  typename StaticTileDistribution,
1015  typename StaticPageIndexArray,
1016  typename StaticValidArray,
1017  index_t HsGatherDim>
1020  const StaticTileDistribution& tile_distribution,
1021  const StaticPageIndexArray& page_idx,
1022  const StaticValidArray& valids,
1023  number<HsGatherDim> = {})
1024 {
1025  return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
1026  tile_window.get_window_lengths(),
1027  tile_window.get_window_origin(),
1028  tile_distribution,
1029  page_idx,
1030  valids,
1031  number<HsGatherDim>{});
1032 }
1033 
1034 template <typename NewTensorView_,
1035  typename OldTensorView_,
1036  typename WindowLengths_,
1037  typename StaticTileDistribution_,
1038  typename StaticPageIndexArray_,
1039  typename StaticValidArray_,
1040  index_t HsGatherDim = 0,
1041  index_t NumCoord = 1>
1042 CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_& new_tensor_view,
1043  const tile_scatter_gather<OldTensorView_,
1044  WindowLengths_,
1045  StaticTileDistribution_,
1046  StaticPageIndexArray_,
1047  StaticValidArray_,
1048  HsGatherDim,
1049  NumCoord>& tile_window)
1050 {
1051  return make_tile_scatter_gather(new_tensor_view,
1052  tile_window.window_lengths_,
1053  tile_window.window_origin_,
1054  tile_window.tile_dstr_,
1055  tile_window.page_idx_,
1056  tile_window.valids_);
1057 }
1058 
1059 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_LDS_ADDR
Definition: config.hpp:62
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_ &new_tensor_view, const tile_scatter_gather< OldTensorView_, WindowLengths_, StaticTileDistribution_, StaticPageIndexArray_, StaticValidArray_, HsGatherDim, NumCoord > &tile_window)
Definition: tile_scatter_gather.hpp:1042
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:56
constant< b > bool_constant
Definition: integral_constant.hpp:43
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1126
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constant< v > number
Definition: integral_constant.hpp:37
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1066
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:21
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
typename std::remove_reference< T >::type remove_reference_t
Definition: type_traits.hpp:15
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_DEVICE auto make_tile_scatter_gather(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, const StaticPageIndexArray_ &page_idx, number< HsGatherDim >={}, number< NumCoord >={})
Definition: tile_scatter_gather.hpp:905
constexpr CK_TILE_HOST_DEVICE void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition: tensor_adaptor_coordinate.hpp:98
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:87
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
Definition: sequence.hpp:298
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:313
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Definition: numeric.hpp:81
Definition: space_filling_curve.hpp:20
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
Definition: debug.hpp:27
Definition: tile_distribution.hpp:70
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:124
Definition: tile_scatter_gather.hpp:81
static constexpr index_t PackedSize
Definition: tile_scatter_gather.hpp:104
static constexpr index_t NumAccess
Definition: tile_scatter_gather.hpp:143
decltype(get_space_filling_curve()) SFC_Ys
Definition: tile_scatter_gather.hpp:141
static constexpr index_t VectorDimY
Definition: tile_scatter_gather.hpp:106
static constexpr index_t ScalarPerVector
Definition: tile_scatter_gather.hpp:107
This class provides tile (windowed) view and access to the device memory.
Definition: tile_scatter_gather.hpp:41
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_scatter_gather.hpp:781
static constexpr index_t NumAccessPerCoord
Definition: tile_scatter_gather.hpp:149
static constexpr auto I1
Definition: tile_scatter_gather.hpp:59
constexpr CK_TILE_DEVICE tile_scatter_gather(const BottomTensorView &bottom_tensor_view, const WindowLengths &window_lengths, const BottomTensorIndex &window_origin, const TileDstr &tile_distribution, const PageIdxArray &page_idx, const ValidArray &valids)
Definition: tile_scatter_gather.hpp:153
BottomTensorIndex window_origin_
Definition: tile_scatter_gather.hpp:881
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:307
WindowLengths window_lengths_
Definition: tile_scatter_gather.hpp:878
constexpr CK_TILE_DEVICE auto get_tile_distribution() const
Definition: tile_scatter_gather.hpp:232
constexpr CK_TILE_DEVICE auto get_num_of_access() const
Definition: tile_scatter_gather.hpp:304
static constexpr index_t NDimBottomTensor
Definition: tile_scatter_gather.hpp:53
static constexpr CK_TILE_DEVICE auto get_window_adaptor_ys_safe_vector_length_strides()
Definition: tile_scatter_gather.hpp:265
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition: tile_scatter_gather.hpp:72
PageIdxArray page_idx_
Definition: tile_scatter_gather.hpp:888
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:409
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_scatter_gather.hpp:43
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition: tile_scatter_gather.hpp:810
array< tuple< WindowAdaptorCoord, BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_scatter_gather.hpp:894
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_scatter_gather.hpp:236
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_scatter_gather.hpp:44
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_scatter_gather.hpp:247
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:319
CK_TILE_DEVICE void store(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:688
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray &new_idx, const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:803
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition: tile_scatter_gather.hpp:48
TileDstr tile_dstr_
Definition: tile_scatter_gather.hpp:886
ValidArray valids_
Definition: tile_scatter_gather.hpp:889
static constexpr index_t NDimY
Definition: tile_scatter_gather.hpp:56
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_scatter_gather.hpp:50
static constexpr index_t NDimWindowAdaptorTop
Definition: tile_scatter_gather.hpp:52
static constexpr CK_TILE_DEVICE bool has_static_tile_distribution()
Definition: tile_scatter_gather.hpp:225
remove_cvref_t< StaticValidArray_ > ValidArray
Definition: tile_scatter_gather.hpp:46
static constexpr index_t NDimP
Definition: tile_scatter_gather.hpp:55
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_scatter_gather.hpp:42
constexpr CK_TILE_DEVICE tile_scatter_gather()=default
CK_TILE_DEVICE void update(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:606
remove_cvref_t< StaticPageIndexArray_ > PageIdxArray
Definition: tile_scatter_gather.hpp:45
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_scatter_gather.hpp:230
CK_TILE_HOST_DEVICE void init_raw()
Definition: tile_scatter_gather.hpp:871
static constexpr auto I0
Definition: tile_scatter_gather.hpp:58
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})) BottomTensorCoord
Definition: tile_scatter_gather.hpp:78
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_scatter_gather.hpp:234
typename TileDstr::PsYs2XsAdaptor WindowAdaptor
Definition: tile_scatter_gather.hpp:47
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})) WindowAdaptorCoord
Definition: tile_scatter_gather.hpp:75
constexpr CK_TILE_DEVICE void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition: tile_scatter_gather.hpp:239
BottomTensorView bottom_tensor_view_
Definition: tile_scatter_gather.hpp:875
CK_TILE_DEVICE void update_valids(const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:795
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_scatter_gather.hpp:505
array< index_t, NDimWindowAdaptorTop > AdaptorTopIndex
Definition: tile_scatter_gather.hpp:71
CK_TILE_DEVICE void update_page_idx(const PageIdxArray &new_idx)
Definition: tile_scatter_gather.hpp:793
static constexpr CK_TILE_DEVICE index_t get_num_of_dimension()
Definition: tile_scatter_gather.hpp:223
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
BottomTensorIndex window_origin_
Definition: tile_window_base.hpp:79
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
WindowLengths window_lengths_
Definition: tile_window_base.hpp:81
This class provides description of tile windowed view on the device memory.
Definition: tile_window.hpp:1195
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10