/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_scatter_gather.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_scatter_gather.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_scatter_gather.hpp Source File
tile_scatter_gather.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
18 
19 namespace ck_tile {
20 
32 template <typename BottomTensorView_,
33  typename WindowLengths_,
34  typename StaticTileDistribution_,
35  typename StaticPageIndexArray_,
36  typename StaticValidArray_,
37  index_t HsGatherDim = 0,
38  index_t NumCoord = 1,
39  index_t YsGatherDim = 0>
41 {
47  using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
48  using BottomTensorDesc = typename BottomTensorView::TensorDesc;
49 
51 
52  static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
53  static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
54 
55  static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
56  static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
57 
58  static constexpr auto I0 = number<0>{};
59  static constexpr auto I1 = number<1>{};
60  static_assert(NumCoord == 1);
61 
62  // TODO: check WindowLengths and StaticTileDistribution are consistent
63 
65  "wrong! lengths should be static");
66  static_assert(TileDstr::is_static(), "wrong!");
67 
68  static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
69  "wrong! inconsistent # of diemsnions");
70 
73 
76 
79 
81  {
82  private:
83  static constexpr auto get_vector_dim_y_scalar_per_vector()
84  {
85  const auto [ys_vector_lengths, ys_vector_strides] =
87 
88  index_t VectorDimY_ = 0;
89  index_t ScalarPerVector_ = 1;
90 
91  for(index_t i = 0; i < NDimY; ++i)
92  {
93  if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
94  {
95  ScalarPerVector_ = ys_vector_lengths[i];
96  VectorDimY_ = i;
97  }
98  }
99 
100  return make_tuple(VectorDimY_, ScalarPerVector_);
101  }
102 
103  public:
104  static constexpr index_t PackedSize =
106  static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
107  static constexpr index_t ScalarPerVector =
108  get_vector_dim_y_scalar_per_vector().template at<1>();
109 
110  // using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
111  // using vector_t = typename vector_type_t::type;
113 
114  private:
115  static constexpr auto scalars_per_access_ = [] {
116  constexpr auto scalars_per_access_arr = generate_array(
117  [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
118 
120  constexpr auto NDimY_ = NDimY;
121 
122  return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
123  }();
124 
125  static constexpr auto get_space_filling_curve()
126  {
127  constexpr auto tile_dstr = TileDstr{};
128 
129  constexpr auto thread_tensor_lengths_ys =
130  to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
131 
132  // FIXME: need logic to judge dim access order
133  using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
134 
135  return space_filling_curve<decltype(thread_tensor_lengths_ys),
136  DimAccessOrder,
137  decltype(scalars_per_access_)>{};
138  }
139 
140  public:
141  using SFC_Ys = decltype(get_space_filling_curve());
142 
143  static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
144 
145  static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
146  static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
147  };
148 
150 
151  CK_TILE_DEVICE constexpr tile_scatter_gather() = default;
152 
153  CK_TILE_DEVICE constexpr tile_scatter_gather(const BottomTensorView& bottom_tensor_view,
154  const WindowLengths& window_lengths,
155  const BottomTensorIndex& window_origin,
157  const PageIdxArray& page_idx,
158  const ValidArray& valids)
159  : bottom_tensor_view_{bottom_tensor_view},
160  window_lengths_{window_lengths},
161  window_origin_{window_origin},
163  page_idx_{page_idx},
164  valids_{valids},
166  {
167 #if 0 // debug
168  // TODO: this use more register for FA, but less register for GEMM
169  // need investigation
170  // only support warp-tile and block-tile
171  static_assert(NDimP == 1 or NDimP == 2, "wrong!");
172 
173  WindowAdaptorCoord window_adaptor_thread_coord_tmp;
174 
175  if constexpr(NDimP == 1)
176  {
177  window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
179  }
180  else if constexpr(NDimP == 2)
181  {
182  window_adaptor_thread_coord_tmp =
184  AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
185  }
186 #else
187  // TODO: this use less register for FA, but more register for GEMM
188  // need investigation
189  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
192 #endif
193 
194  BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
195  window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
196  bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
197  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
198  bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
199 
200  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
201  // future load/store() calls (might allocate more registers)
202  using Traits = load_store_traits;
203  using SFC_Ys = typename Traits::SFC_Ys;
204 
205  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
206  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
207  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
208 
209  constexpr auto idx_diff_ys =
210  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
211 
212  constexpr auto idx_diff_ps_ys = container_concat(
213  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
214 
216  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
217 
218  pre_computed_coords_(iCoord) =
219  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
220  });
221  if constexpr(BottomTensorView::buffer_view::get_address_space() ==
222  address_space_enum::global)
223  {
224  auto partition_index = get_partition_index(tile_distribution);
225 
226  auto use_lane_id_0 = partition_index;
227  use_lane_id_0[1] = 0;
228  const auto window_adaptor_thread_coord_tmp_warp = make_tensor_adaptor_coordinate(
230  container_concat(use_lane_id_0, array<index_t, NDimY>{0}));
231 
232  BottomTensorIndex bottom_tensor_thread_origin_idx_tmp_warp =
233  window_origin + window_adaptor_thread_coord_tmp_warp.get_bottom_index();
234  bottom_tensor_thread_origin_idx_tmp_warp(HsGatherDim) = 0;
235  const auto bottom_tensor_thread_coord_tmp_warp =
236  make_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
237  bottom_tensor_thread_origin_idx_tmp_warp);
238 
239  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
240  // future load/store() calls (might allocate more registers)
241  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
242  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp_warp;
243  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp_warp;
244 
245  constexpr auto idx_diff_ys =
246  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
247 
248  constexpr auto idx_diff_ps_ys = container_concat(
249  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
250  idx_diff_ys);
251 
253  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
254 
255  pre_computed_warp_coords_(iCoord) =
256  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
257  });
258  }
259  }
260 
262 
264  {
265  return TileDstr::is_static();
266  }
267 
// Accessor: returns (by value) the per-dimension lengths of the tile window.
268  CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
269 
// Accessor: returns (by value) the static tile distribution associated with this window.
270  CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
271 
// Accessor: returns a copy of the bottom (underlying) tensor view this window reads/writes.
272  CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
273 
// Accessor: returns (by value) the window origin index into the bottom tensor.
274  CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
275 
// Rebind the raw data pointer of the cached bottom tensor view, leaving all
// descriptor/coordinate state untouched. NOTE(review): this does not refresh
// pre_computed_coords_; presumably safe only when the new buffer shares the
// original layout — confirm with callers. Caller retains ownership of `data`.
276  CK_TILE_DEVICE constexpr void
277  set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
278  {
279  bottom_tensor_view_.buf_.p_data_ = data;
280  }
281 
282  // move thread's window adaptor coordinate and bottom tensor coordinate
283  // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
284  template <typename ATopIndex>
286  WindowAdaptorCoord& window_adaptor_thread_coord,
287  BottomTensorCoord& bottom_tensor_thread_coord,
288  const ATopIndex& idx_diff_adaptor_top) const
289  {
290  array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
291 
292  move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
293  window_adaptor_thread_coord,
294  idx_diff_adaptor_top,
295  idx_diff_adaptor_bottom);
296 
297  move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
298  bottom_tensor_thread_coord,
299  idx_diff_adaptor_bottom);
300  }
301 
302  // return vector dimension among [y0, y1, ...]
304  {
305  // bottom tensor top dimension vector lengths and strides
306  const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
307  BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
308 
309  // window vector lengths/strides
310  const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
311  const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
312 
313  // window adaptor [p0, p1, ..., y0, y1, ...]
314  array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
315  -1};
316  array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
317  -1};
318 
319  constexpr auto window_adaptor_bottom_dims =
320  WindowAdaptor::get_bottom_dimension_hidden_ids();
321 
322  set_container_subset(window_adaptor_vector_lengths,
323  window_adaptor_bottom_dims,
324  window_adaptor_bottom_dim_vector_lengths);
325  set_container_subset(window_adaptor_vector_strides,
326  window_adaptor_bottom_dims,
327  window_adaptor_bottom_dim_vector_strides);
328 
329  const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
330  WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
331  window_adaptor_vector_lengths, window_adaptor_vector_strides);
332 
333  // [y0, y1, ...]
334  constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
336  1>::type{};
337 
338  return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
339  get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
340  }
341 
343 
344  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
347  {
348  constexpr auto tile_dstr = TileDstr{};
349  auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
350  load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
351  return dst_tensor;
352  }
353 
354  template <typename DistributedTensor,
355  index_t i_access_unsupport_ = -1,
356  bool oob_conditional_check = true>
357  CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
360  {
361  using Traits = load_store_traits;
362  using vector_t = typename Traits::vector_t;
363  using SFC_Ys = typename Traits::SFC_Ys;
364 
365  constexpr auto tile_dstr = TileDstr{};
366 
367  // loop over thread tensor space [y0, y1, ...]
368  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
370  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
371  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
372 
373  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
374  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
375 
376  // data index [y0, y1, ...]
377  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
378  constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
379  const auto page_offset = page_idx_[idx_gather];
380 
381  // read from bottom tensor
382  const vector_t vec_value = [&]() {
383  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
384  {
385  return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
386  bottom_tensor_thread_coord,
387  page_offset,
388  bool_constant<oob_conditional_check>{});
389  }
390  else
391  {
392  return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
393  bottom_tensor_thread_coord,
394  page_offset,
395  valids_[idx_gather],
396  bool_constant<oob_conditional_check>{});
397  }
398  }();
399 #if 1
400  // write into distributed tensor
401  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
402  constexpr auto idx_ys = generate_tuple(
403  [&](auto jj) {
404  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
405  : idx_ys_start[jj];
406  },
407  number<NDimY>{});
408 
409  constexpr index_t d =
410  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
411  Traits::PackedSize;
412 
413  dst_tensor.get_thread_buffer().template at<d>() =
414  vec_value.template get_as<DataType>()[j / Traits::PackedSize];
415  });
416 #else
417  constexpr index_t d =
418  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
419  static_assert(d % Traits::ScalarPerVector == 0);
420 
421  dst_tensor.get_thread_buffer().template get_as<vector_t>()(
422  number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
423 #endif
424  // move thread coordinate
425  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
426  {
427  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
428 
429  constexpr auto forward_step_scatter = generate_tuple(
430  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
431  number<NDimY>{});
432 
433  constexpr auto idx_diff_ps_ys = container_concat(
434  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
435  forward_step_scatter);
436 
438  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
439  }
440  });
441  });
442  }
443 
444  template <typename LdsTileWindow_,
445  index_t i_access_unsupport_ = -1,
446  bool oob_conditional_check = true>
447  CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
450  {
451  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
452  using LdsDataType = typename LdsTileWindow::DataType;
453  using Traits = load_store_traits;
454  using vector_t = typename Traits::vector_t;
455  using SFC_Ys = typename Traits::SFC_Ys;
456 
457  constexpr auto tile_dstr = TileDstr{};
458 
459  // Precompute invariant values outside loops
460  const auto window_origin = lds_tile.get_window_origin();
461  const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
462  const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
463  auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
464 
465  // loop over thread tensor space [y0, y1, ...]
466  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
468  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
469  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
470 
471  auto lds_window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
472  auto lds_bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
473 
474  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
475  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
476 
477  // Use precomputed window origin
478  auto lds_bottom_tensor_thread_idx =
479  window_origin + lds_window_adaptor_thread_coord.get_bottom_index();
480  // Use precomputed tensor descriptor
481  const auto lds_coord =
482  make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
483  // Calculate SMEM address using base pointer
484  CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
485 
486  // data index [y0, y1, ...]
487  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
488  constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
489  const auto page_offset = page_idx_[idx_gather];
490 
491  // merge page_offset into bottom_coord
492  auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
493  mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;
494 
495  // read from bottom tensor
496  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
497  this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
498  smem,
499  mixed_bottom_thread_coord,
500  number<0>{},
501  bool_constant<oob_conditional_check>{});
502  else
503  this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
504  smem,
505  mixed_bottom_thread_coord,
506  number<0>{},
507  valids_[idx_gather],
508  bool_constant<oob_conditional_check>{});
509 
510  // move thread coordinate
511  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
512  {
513  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
514 
515  constexpr auto forward_step_scatter = generate_tuple(
516  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
517  number<NDimY>{});
518 
519  constexpr auto idx_diff_ps_ys = container_concat(
520  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
521  forward_step_scatter);
522  // lds_diff doesn't need to mask the difference of the gather-dim.
523  constexpr auto lds_idx_diff_ps_ys = container_concat(
524  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
525  idx_diff_ys);
526 
528  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
530  lds_window_adaptor_thread_coord,
531  lds_bottom_tensor_thread_coord,
532  lds_idx_diff_ps_ys);
533  }
534  });
535  });
536  }
537 
538  // TODO: currently async load only implemented in inline asm
539  template <typename LdsTileWindow_,
540  index_t i_access_unsupport_ = -1,
541  bool oob_conditional_check = true,
542  bool pre_nop = false>
543  CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
546  bool_constant<pre_nop> = {}) const
547  {
548  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
549  // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
550  using LdsDataType = typename LdsTileWindow::DataType;
551  // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
552 
553  // issues * warps * lanes
554  static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
555 
556  const index_t size_per_buf =
557  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
558  make_tuple(number<0>{}, number<0>{}, number<0>{})) *
559  sizeof(LdsDataType);
560 
561  const index_t size_per_wave =
562  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
563  make_tuple(number<0>{}, number<1>{}, number<0>{})) *
564  sizeof(LdsDataType) -
565  size_per_buf;
566 
567  const index_t size_per_issue =
568  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
569  make_tuple(number<1>{}, number<0>{}, number<0>{})) *
570  sizeof(LdsDataType) -
571  size_per_buf;
572 
573  const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
575  amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
576 
577  using Traits = load_store_traits;
578 
579  // using vector_type_t = typename Traits::vector_type_t;
580  using vector_t = typename Traits::vector_t;
581  using SFC_Ys = typename Traits::SFC_Ys;
582 
583  LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
584 
585  // loop over thread tensor space [y0, y1, ...]
586  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
588  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
589  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
590 
591  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
592  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
593  constexpr auto pre_nop_ = [&]() {
594  if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
595  return bool_constant<true>{};
596  else
597  return bool_constant<false>{};
598  }();
599 
600  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
601  constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
602  const auto page_offset = page_idx_[idx_gather];
603 
604  // read from bottom tensor
605  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
606  {
607  get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
608  smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
609  }
610  else
611  {
612  get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
613  smem,
614  bottom_tensor_thread_coord,
615  page_offset,
616  valids_[idx_gather],
617  0,
618  pre_nop_);
619  }
620 
621  // move thread coordinate
622  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
623  {
624  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
625 
626  constexpr auto forward_step_scatter = generate_tuple(
627  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
628  number<NDimY>{});
629 
630  constexpr auto idx_diff_ps_ys = container_concat(
631  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
632  forward_step_scatter);
633 
635  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
636 
637  m0_inc_with_memory(size_per_issue);
638  }
639  });
640  });
641  }
642 
643  // TODO: fix with swizzle
644  template <typename LdsTileWindow_,
645  index_t i_access_unsupport_ = -1,
646  bool oob_conditional_check = true,
647  bool static_move_ys = false,
648  typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>>>>
650  LdsTileWindow_&& lds_tile,
653  bool_constant<static_move_ys> = {}) const
654  {
655  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
656  using LdsDataType = typename LdsTileWindow::DataType;
657 
658  using Traits = load_store_traits;
659 
660  using vector_t = typename Traits::vector_t;
661  using SFC_Ys = typename Traits::SFC_Ys;
662 
663  // Precompute invariant values outside loops
664  const auto window_origin = lds_tile.get_window_origin();
665  const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
666  const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
667  auto lds_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
668 
669  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
670  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
671  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
672 
673  auto window_adaptor_warp_coord = pre_computed_warp_coords_[iCoord][I0];
674  auto bottom_tensor_warp_coord = pre_computed_warp_coords_[iCoord][I1];
675 
676  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
677  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
678 
679  constexpr auto idx_ys_offset = [&]() {
680  constexpr auto idx_off_ys = SFC_Ys::get_step_between(number<0>{}, iAccess);
681  constexpr auto adapter_ys_offset = make_tensor_adaptor_coordinate(
682  StaticTileDistribution_{}.get_ps_ys_to_xs_adaptor(),
683  container_concat(array<index_t, NDimP>{0},
684  to_array<index_t, idx_off_ys.size()>(idx_off_ys)));
685  return adapter_ys_offset.get_bottom_index();
686  }();
687  const auto lds_ys_offset = [&]() {
688  if constexpr(static_move_ys)
689  {
690  const auto coord_ys_offset =
691  make_tensor_coordinate(tensor_descriptor, idx_ys_offset);
692  return coord_ys_offset.get_offset();
693  }
694  else
695  return 0;
696  }();
697 
698  // Use precomputed window origin & tensor descriptor
699  auto lds_bottom_tensor_thread_idx =
700  window_origin + window_adaptor_warp_coord.get_bottom_index();
701  const auto lds_coord =
702  make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
703 
704  // Calculate SMEM address using base pointer
705  CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr +
706  lds_coord.get_offset() / Traits::PackedSize +
707  lds_ys_offset / Traits::PackedSize;
708 
709  const auto dram_ys_offset = [&]() {
710  if constexpr(static_move_ys)
711  {
712  const auto coord_ys_offset = make_tensor_coordinate(
713  this->get_bottom_tensor_view().get_tensor_descriptor(), idx_ys_offset);
714  return coord_ys_offset.get_offset();
715  }
716  else
717  return 0;
718  }();
719 
720  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
721  constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
722  const auto page_offset = page_idx_[idx_gather];
723 
724  auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
725  mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;
726 
727  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
728  {
729  this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
730  smem,
731  mixed_bottom_thread_coord,
732  offset + dram_ys_offset,
733  bool_constant<oob_conditional_check>{});
734  }
735  else
736  {
737  this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
738  smem,
739  mixed_bottom_thread_coord,
740  offset + dram_ys_offset,
741  valids_[idx_gather],
742  bool_constant<oob_conditional_check>{});
743  }
744 
745  // Move thread coordinate if not last access
746  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
747  {
748  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
749 
750  constexpr auto forward_step_scatter = generate_tuple(
751  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
752  number<NDimY>{});
753 
754  constexpr auto idx_diff_ps_ys = container_concat(
755  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
756  forward_step_scatter);
757 
758  if constexpr(!static_move_ys)
760  window_adaptor_thread_coord,
761  bottom_tensor_thread_coord,
762  idx_diff_ps_ys);
763 
764  if constexpr(!static_move_ys)
766  window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
767  }
768  });
769  });
770  }
771 
772  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
776  {
777  using Traits = load_store_traits;
778 
779  // using vector_type_t = typename Traits::vector_type_t;
780  using vector_t = typename Traits::vector_t;
781  using SFC_Ys = typename Traits::SFC_Ys;
782 
783  constexpr auto tile_dstr = TileDstr{};
784 
785  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
786  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
787  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
788 
789  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
790  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
791 
792  // data index [y0, y1, ...]
793  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
794  constexpr auto idx_gather = idx_ys_start[number<0>{}];
795  const auto page_offset = page_idx_[idx_gather];
796 
797  // read from distributed tensor
798  vector_t vec_value;
799 
800  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
801  constexpr auto idx_ys = generate_tuple(
802  [&](auto jj) {
803  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
804  : idx_ys_start[jj];
805  },
806  number<NDimY>{});
807 
808  constexpr index_t d =
809  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
810  Traits::PackedSize;
811 
812  vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
813  dstr_tensor.get_thread_buffer().template at<d>();
814  });
815 
816  // write into bottom tensor
817  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
818  {
819  get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
820  bottom_tensor_thread_coord,
821  page_offset,
822  vec_value,
823  bool_constant<oob_conditional_check>{});
824  }
825  else
826  {
827  get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
828  bottom_tensor_thread_coord,
829  page_offset,
830  valids_[idx_gather],
831  vec_value,
832  bool_constant<oob_conditional_check>{});
833  }
834 
835  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
836  {
837  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
838 
839  constexpr auto forward_step_scatter = generate_tuple(
840  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
841  number<NDimY>{});
842 
843  constexpr auto idx_diff_ps_ys = container_concat(
844  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
845  forward_step_scatter);
846 
848  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
849  }
850  });
851  });
852  }
853 
854  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
858  {
859  using Traits = load_store_traits;
860 
861  // using vector_type_t = typename Traits::vector_type_t;
862  using vector_t = typename Traits::vector_t;
863  using SFC_Ys = typename Traits::SFC_Ys;
864 
865  constexpr auto tile_dstr = TileDstr{};
866  // printf("off %d\n", page_idx_[I0]);
867  // loop over thread tensor space [y0, y1, ...]
868  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
869  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
870  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
871 
872  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
873  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
874 
875  // data index [y0, y1, ...]
876  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
877  constexpr auto idx_gather = idx_ys_start[number<0>{}];
878  const auto page_offset = page_idx_[idx_gather];
879 
880  // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
881  // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
882 
883  // read from distributed tensor
884  // vector_type_t vec;
885  vector_t vec_value;
886 
887  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
888  constexpr auto idx_ys = generate_tuple(
889  [&](auto jj) {
890  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
891  : idx_ys_start[jj];
892  },
893  number<NDimY>{});
894 
895  constexpr index_t d =
896  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
897  Traits::PackedSize;
898  // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j);
899  vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
900  dstr_tensor.get_thread_buffer().template at<d>();
901  });
902 
903  // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
904 
905  // write into bottom tensor
906  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
907  {
908  get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
909  bottom_tensor_thread_coord,
910  page_offset,
911  vec_value,
912  bool_constant<oob_conditional_check>{});
913  }
914  else
915  {
916  get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
917  bottom_tensor_thread_coord,
918  page_offset,
919  valids_[idx_gather],
920  vec_value,
921  bool_constant<oob_conditional_check>{});
922  }
923 
924  // printf("coord_offset:%d, scatter_offset:%d \n",
925  // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
926  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
927  {
928  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
929 
930  constexpr auto forward_step_scatter = generate_tuple(
931  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
932  number<NDimY>{});
933 
934  constexpr auto idx_diff_ps_ys = container_concat(
935  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
936  forward_step_scatter);
937 
939  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
940  }
941  });
942  });
943  }
944 
945  // move thread's botom tensor coordiante
946  // [x0', x1', ... ] ==> [offset]
947  // also move window-origin
949  {
950  window_origin_ += step;
951  BottomTensorIndex step_new = step;
952  step_new(HsGatherDim) = 0;
953  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
954  move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
955  pre_computed_coords_(iCoord)(I1),
956  step_new);
957  });
958  if constexpr(BottomTensorView::buffer_view::get_address_space() ==
959  address_space_enum::global)
960  {
961  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
962  move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
963  pre_computed_warp_coords_(iCoord)(I1),
964  step_new);
965  });
966  }
967  }
968 
969  CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
970 
971  CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
972  {
973  if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == false)
974  {
975  valids_ = new_valids;
976  }
977  }
978 
980  const ValidArray& new_valids)
981  {
982  update_page_idx(new_idx);
983  update_valids(new_valids);
984  }
985 
986  CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
987  {
988  window_origin_ = new_window_origin;
989 
990 #if 0 // debug
991  // TODO: this use more register for FA, but less register for GEMM
992  // need investigation
993  // only support warp-tile and block-tile
994  static_assert(NDimP == 1 or NDimP == 2, "wrong!");
995 
996  WindowAdaptorCoord window_adaptor_thread_coord_tmp;
997 
998  if constexpr(NDimP == 1)
999  {
1000  window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
1001  tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
1002  }
1003  else if constexpr(NDimP == 2)
1004  {
1005  window_adaptor_thread_coord_tmp =
1006  make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
1007  AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
1008  }
1009 #else
1010  // TODO: this use less register for FA, but more register for GEMM
1011  // need investigation
1012  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
1013  tile_dstr_.get_ps_ys_to_xs_adaptor(),
1015 #endif
1016 
1017  BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
1018  window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
1019 
1020  bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
1021  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
1022  bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
1023 
1024  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
1025  // future load/store() calls (might allocate more registers)
1026  using Traits = load_store_traits;
1027  using SFC_Ys = typename Traits::SFC_Ys;
1028 
1029  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
1030  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
1031  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
1032 
1033  constexpr auto idx_diff_ys =
1034  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
1035 
1036  constexpr auto idx_diff_ps_ys = container_concat(
1037  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
1038 
1040  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
1041 
1042  pre_computed_coords_(iCoord) =
1043  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
1044  });
1045  }
1046 
1048 
1049  // this is the bottom tensor view
1050  // [x0', x1', ...] ==> [offset]
1052 
1053  //
1055 
1056  // origin ([x0', x1', ...]) of window on bottom tensor
1058 
1059  // Tile tensor distribution, which contains:
1060  // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
1061  // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
1063 
1066 
1067  // this contains:
1068  // per-thread coordinate for window adaptor
1069  // per-thread coordinate for bottom tensor
1071  std::conditional_t<BottomTensorView::buffer_view::get_address_space() ==
1072  address_space_enum::global,
1074  std::byte>
1076 };
1077 
1078 // TODO: use strategy
1079 template <typename TensorView_,
1080  typename WindowLengths_,
1081  typename StaticTileDistribution_,
1082  typename StaticPageIndexArray_,
1083  index_t HsGatherDim = 0,
1084  index_t NumCoord = 1>
1085 CK_TILE_DEVICE constexpr auto
1087  const WindowLengths_& window_lengths,
1088  const multi_index<TensorView_::get_num_of_dimension()>& origin,
1089  const StaticTileDistribution_& tile_distribution,
1090  const StaticPageIndexArray_& page_idx, // perbytes
1091  number<HsGatherDim> = {},
1092  number<NumCoord> = {})
1093 {
1094  return tile_scatter_gather<remove_cvref_t<TensorView_>,
1095  remove_cvref_t<WindowLengths_>,
1096  remove_cvref_t<StaticTileDistribution_>,
1097  remove_cvref_t<StaticPageIndexArray_>,
1098  std::nullptr_t,
1099  HsGatherDim,
1100  NumCoord>{
1101  tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
1102 }
1103 
1104 template <typename TensorView,
1105  typename WindowLengths,
1106  typename StaticTileDistribution,
1107  typename StaticPageIndexArray,
1108  index_t HsGatherDim>
1111  const multi_index<TensorView::get_num_of_dimension()>& origin,
1112  const StaticTileDistribution& tile_distribution,
1113  const StaticPageIndexArray& page_idx,
1114  number<HsGatherDim> = {})
1115 {
1116  return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
1117  tile_window.get_window_lengths(),
1118  origin,
1119  tile_distribution,
1120  page_idx,
1121  number<HsGatherDim>{});
1122 }
1123 
1124 template <typename TensorView,
1125  typename WindowLengths,
1126  typename StaticTileDistribution,
1127  typename StaticPageIndexArray,
1128  index_t HsGatherDim>
1131  const StaticTileDistribution& tile_distribution,
1132  const StaticPageIndexArray& page_idx,
1133  number<HsGatherDim> = {})
1134 {
1135  return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
1136  tile_window.get_window_lengths(),
1137  tile_window.get_window_origin(),
1138  tile_distribution,
1139  page_idx,
1140  number<HsGatherDim>{});
1141 }
1142 
1143 template <typename TensorView_,
1144  typename WindowLengths_,
1145  typename StaticTileDistribution_,
1146  typename StaticPageIndexArray_,
1147  typename StaticValidArray_,
1148  index_t HsGatherDim = 0,
1149  index_t NumCoord = 1>
1150 CK_TILE_DEVICE constexpr auto
1152  const WindowLengths_& window_lengths,
1153  const multi_index<TensorView_::get_num_of_dimension()>& origin,
1154  const StaticTileDistribution_& tile_distribution,
1155  const StaticPageIndexArray_& page_idx,
1156  const StaticValidArray_& valids,
1157  number<HsGatherDim> = {},
1158  number<NumCoord> = {})
1159 {
1160  return tile_scatter_gather<remove_cvref_t<TensorView_>,
1161  remove_cvref_t<WindowLengths_>,
1162  remove_cvref_t<StaticTileDistribution_>,
1163  remove_cvref_t<StaticPageIndexArray_>,
1164  remove_cvref_t<StaticValidArray_>,
1165  HsGatherDim,
1166  NumCoord>{
1167  tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
1168 }
1169 
1170 template <typename TensorView,
1171  typename WindowLengths,
1172  typename StaticTileDistribution,
1173  typename StaticPageIndexArray,
1174  typename StaticValidArray,
1175  index_t HsGatherDim>
1178  const multi_index<TensorView::get_num_of_dimension()>& origin,
1179  const StaticTileDistribution& tile_distribution,
1180  const StaticPageIndexArray& page_idx,
1181  const StaticValidArray& valids,
1182  number<HsGatherDim> = {})
1183 {
1184  return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
1185  tile_window.get_window_lengths(),
1186  origin,
1187  tile_distribution,
1188  page_idx,
1189  valids,
1190  number<HsGatherDim>{});
1191 }
1192 
1193 template <typename TensorView,
1194  typename WindowLengths,
1195  typename StaticTileDistribution,
1196  typename StaticPageIndexArray,
1197  typename StaticValidArray,
1198  index_t HsGatherDim>
1201  const StaticTileDistribution& tile_distribution,
1202  const StaticPageIndexArray& page_idx,
1203  const StaticValidArray& valids,
1204  number<HsGatherDim> = {})
1205 {
1206  return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
1207  tile_window.get_window_lengths(),
1208  tile_window.get_window_origin(),
1209  tile_distribution,
1210  page_idx,
1211  valids,
1212  number<HsGatherDim>{});
1213 }
1214 
1215 template <typename NewTensorView_,
1216  typename OldTensorView_,
1217  typename WindowLengths_,
1218  typename StaticTileDistribution_,
1219  typename StaticPageIndexArray_,
1220  typename StaticValidArray_,
1221  index_t HsGatherDim = 0,
1222  index_t NumCoord = 1>
1223 CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_& new_tensor_view,
1224  const tile_scatter_gather<OldTensorView_,
1225  WindowLengths_,
1226  StaticTileDistribution_,
1227  StaticPageIndexArray_,
1228  StaticValidArray_,
1229  HsGatherDim,
1230  NumCoord>& tile_window)
1231 {
1232  return make_tile_scatter_gather(new_tensor_view,
1233  tile_window.window_lengths_,
1234  tile_window.window_origin_,
1235  tile_window.tile_dstr_,
1236  tile_window.page_idx_,
1237  tile_window.valids_);
1238 }
1239 
1240 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_LDS_ADDR
Definition: config.hpp:62
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto to_array(const std::vector< X > &x)
Definition: array.hpp:286
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_ &new_tensor_view, const tile_scatter_gather< OldTensorView_, WindowLengths_, StaticTileDistribution_, StaticPageIndexArray_, StaticValidArray_, HsGatherDim, NumCoord > &tile_window)
Definition: tile_scatter_gather.hpp:1223
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:56
constant< b > bool_constant
Definition: integral_constant.hpp:43
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1126
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constant< v > number
Definition: integral_constant.hpp:37
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1066
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:21
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
typename std::remove_reference< T >::type remove_reference_t
Definition: type_traits.hpp:15
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_DEVICE auto make_tile_scatter_gather(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, const StaticPageIndexArray_ &page_idx, number< HsGatherDim >={}, number< NumCoord >={})
Definition: tile_scatter_gather.hpp:1086
constexpr CK_TILE_HOST_DEVICE void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition: tensor_adaptor_coordinate.hpp:98
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:87
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: sequence.hpp:298
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:313
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Definition: numeric.hpp:81
Definition: coordinate_transform.hpp:1392
Definition: space_filling_curve.hpp:20
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
Definition: debug.hpp:27
Definition: tile_distribution.hpp:70
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:124
Definition: tile_scatter_gather.hpp:81
static constexpr index_t PackedSize
Definition: tile_scatter_gather.hpp:104
static constexpr index_t NumAccess
Definition: tile_scatter_gather.hpp:143
decltype(get_space_filling_curve()) SFC_Ys
Definition: tile_scatter_gather.hpp:141
static constexpr index_t VectorDimY
Definition: tile_scatter_gather.hpp:106
static constexpr index_t ScalarPerVector
Definition: tile_scatter_gather.hpp:107
This class provides tile (windowed) view and access to the device memory.
Definition: tile_scatter_gather.hpp:41
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_scatter_gather.hpp:948
static constexpr index_t NumAccessPerCoord
Definition: tile_scatter_gather.hpp:149
static constexpr auto I1
Definition: tile_scatter_gather.hpp:59
constexpr CK_TILE_DEVICE tile_scatter_gather(const BottomTensorView &bottom_tensor_view, const WindowLengths &window_lengths, const BottomTensorIndex &window_origin, const TileDstr &tile_distribution, const PageIdxArray &page_idx, const ValidArray &valids)
Definition: tile_scatter_gather.hpp:153
BottomTensorIndex window_origin_
Definition: tile_scatter_gather.hpp:1057
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:345
WindowLengths window_lengths_
Definition: tile_scatter_gather.hpp:1054
constexpr CK_TILE_DEVICE auto get_tile_distribution() const
Definition: tile_scatter_gather.hpp:270
constexpr CK_TILE_DEVICE auto get_num_of_access() const
Definition: tile_scatter_gather.hpp:342
static constexpr index_t NDimBottomTensor
Definition: tile_scatter_gather.hpp:53
static constexpr CK_TILE_DEVICE auto get_window_adaptor_ys_safe_vector_length_strides()
Definition: tile_scatter_gather.hpp:303
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition: tile_scatter_gather.hpp:72
PageIdxArray page_idx_
Definition: tile_scatter_gather.hpp:1064
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:447
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_scatter_gather.hpp:43
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition: tile_scatter_gather.hpp:986
array< tuple< WindowAdaptorCoord, BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_scatter_gather.hpp:1070
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_scatter_gather.hpp:274
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_scatter_gather.hpp:44
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_scatter_gather.hpp:285
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:357
CK_TILE_DEVICE void store(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:855
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray &new_idx, const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:979
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition: tile_scatter_gather.hpp:48
TileDstr tile_dstr_
Definition: tile_scatter_gather.hpp:1062
ValidArray valids_
Definition: tile_scatter_gather.hpp:1065
static constexpr index_t NDimY
Definition: tile_scatter_gather.hpp:56
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_scatter_gather.hpp:50
static constexpr index_t NDimWindowAdaptorTop
Definition: tile_scatter_gather.hpp:52
static constexpr CK_TILE_DEVICE bool has_static_tile_distribution()
Definition: tile_scatter_gather.hpp:263
remove_cvref_t< StaticValidArray_ > ValidArray
Definition: tile_scatter_gather.hpp:46
static constexpr index_t NDimP
Definition: tile_scatter_gather.hpp:55
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_scatter_gather.hpp:42
constexpr CK_TILE_DEVICE tile_scatter_gather()=default
CK_TILE_DEVICE void update(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:773
remove_cvref_t< StaticPageIndexArray_ > PageIdxArray
Definition: tile_scatter_gather.hpp:45
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_scatter_gather.hpp:268
CK_TILE_HOST_DEVICE void init_raw()
Definition: tile_scatter_gather.hpp:1047
static constexpr auto I0
Definition: tile_scatter_gather.hpp:58
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})) BottomTensorCoord
Definition: tile_scatter_gather.hpp:78
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_scatter_gather.hpp:272
std::conditional_t< BottomTensorView::buffer_view::get_address_space()==address_space_enum::global, array< tuple< WindowAdaptorCoord, BottomTensorCoord >, NumCoord >, std::byte > pre_computed_warp_coords_
Definition: tile_scatter_gather.hpp:1075
typename TileDstr::PsYs2XsAdaptor WindowAdaptor
Definition: tile_scatter_gather.hpp:47
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})) WindowAdaptorCoord
Definition: tile_scatter_gather.hpp:75
constexpr CK_TILE_DEVICE void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition: tile_scatter_gather.hpp:277
BottomTensorView bottom_tensor_view_
Definition: tile_scatter_gather.hpp:1051
CK_TILE_DEVICE void async_load_with_offset(index_t offset, LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< static_move_ys >={}) const
Definition: tile_scatter_gather.hpp:649
CK_TILE_DEVICE void update_valids(const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:971
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_scatter_gather.hpp:543
array< index_t, NDimWindowAdaptorTop > AdaptorTopIndex
Definition: tile_scatter_gather.hpp:71
CK_TILE_DEVICE void update_page_idx(const PageIdxArray &new_idx)
Definition: tile_scatter_gather.hpp:969
static constexpr CK_TILE_DEVICE index_t get_num_of_dimension()
Definition: tile_scatter_gather.hpp:261
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
BottomTensorIndex window_origin_
Definition: tile_window_base.hpp:79
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
WindowLengths window_lengths_
Definition: tile_window_base.hpp:81
This class provides description of tile windowed view on the device memory.
Definition: tile_window.hpp:1195
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10