/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_scatter_gather.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_scatter_gather.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_scatter_gather.hpp Source File
tile_scatter_gather.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
18 
19 namespace ck_tile {
20 
32 template <typename BottomTensorView_,
33  typename WindowLengths_,
34  typename StaticTileDistribution_,
35  typename StaticPageIndexArray_,
36  typename StaticValidArray_,
37  index_t HsGatherDim = 0,
38  index_t NumCoord = 1,
39  index_t YsGatherDim = 0>
41 {
47  using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
48  using BottomTensorDesc = typename BottomTensorView::TensorDesc;
49 
51 
52  static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
53  static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
54 
55  static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
56  static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
57 
58  static constexpr auto I0 = number<0>{};
59  static constexpr auto I1 = number<1>{};
60  static_assert(NumCoord == 1);
61 
62  // TODO: check WindowLengths and StaticTileDistribution are consistent
63 
65  "wrong! lengths should be static");
66  static_assert(TileDstr::is_static(), "wrong!");
67 
68  static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
69  "wrong! inconsistent # of diemsnions");
70 
73 
76 
79 
81  {
82  private:
83  static constexpr auto get_vector_dim_y_scalar_per_vector()
84  {
85  const auto [ys_vector_lengths, ys_vector_strides] =
87 
88  index_t VectorDimY_ = 0;
89  index_t ScalarPerVector_ = 1;
90 
91  for(index_t i = 0; i < NDimY; ++i)
92  {
93  if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
94  {
95  ScalarPerVector_ = ys_vector_lengths[i];
96  VectorDimY_ = i;
97  }
98  }
99 
100  return make_tuple(VectorDimY_, ScalarPerVector_);
101  }
102 
103  public:
104  static constexpr index_t PackedSize =
106  static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
107  static constexpr index_t ScalarPerVector =
108  get_vector_dim_y_scalar_per_vector().template at<1>();
109 
110  // using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
111  // using vector_t = typename vector_type_t::type;
113 
114  private:
115  static constexpr auto scalars_per_access_ = [] {
116  constexpr auto scalars_per_access_arr = generate_array(
117  [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
118 
120  constexpr auto NDimY_ = NDimY;
121 
122  return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
123  }();
124 
125  static constexpr auto get_space_filling_curve()
126  {
127  constexpr auto tile_dstr = TileDstr{};
128 
129  constexpr auto thread_tensor_lengths_ys =
130  to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
131 
132  // FIXME: need logic to judge dim access order
133  using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
134 
135  return space_filling_curve<decltype(thread_tensor_lengths_ys),
136  DimAccessOrder,
137  decltype(scalars_per_access_)>{};
138  }
139 
140  public:
141  using SFC_Ys = decltype(get_space_filling_curve());
142 
143  static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
144 
145  static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
146  static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
147  };
148 
150 
151  CK_TILE_DEVICE constexpr tile_scatter_gather() = default;
152 
153  CK_TILE_DEVICE constexpr tile_scatter_gather(const BottomTensorView& bottom_tensor_view,
154  const WindowLengths& window_lengths,
155  const BottomTensorIndex& window_origin,
157  const PageIdxArray& page_idx,
158  const ValidArray& valids)
159  : bottom_tensor_view_{bottom_tensor_view},
160  window_lengths_{window_lengths},
161  window_origin_{window_origin},
163  page_idx_{page_idx},
164  valids_{valids},
166  {
167 #if 0 // debug
168  // TODO: this use more register for FA, but less register for GEMM
169  // need investigation
170  // only support warp-tile and block-tile
171  static_assert(NDimP == 1 or NDimP == 2, "wrong!");
172 
173  WindowAdaptorCoord window_adaptor_thread_coord_tmp;
174 
175  if constexpr(NDimP == 1)
176  {
177  window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
179  }
180  else if constexpr(NDimP == 2)
181  {
182  window_adaptor_thread_coord_tmp =
184  AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
185  }
186 #else
187  // TODO: this use less register for FA, but more register for GEMM
188  // need investigation
189  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
193 #endif
194 
195  BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
196  window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
197  bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
198  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
199  bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
200 
201  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
202  // future load/store() calls (might allocate more registers)
203  using Traits = load_store_traits;
204  using SFC_Ys = typename Traits::SFC_Ys;
205 
206  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
207  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
208  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
209 
210  constexpr auto idx_diff_ys =
211  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
212 
213  constexpr auto idx_diff_ps_ys = container_concat(
214  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
215 
217  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
218 
219  pre_computed_coords_(iCoord) =
220  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
221  });
222  }
223 
225 
227  {
228  return TileDstr::is_static();
229  }
230 
231  CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
232 
233  CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
234 
235  CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
236 
237  CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
238 
239  CK_TILE_DEVICE constexpr void
240  set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
241  {
242  bottom_tensor_view_.buf_.p_data_ = data;
243  }
244 
245  // move thread's window adaptor coordinate and bottom tensor coordinate
246  // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
247  template <typename ATopIndex>
249  WindowAdaptorCoord& window_adaptor_thread_coord,
250  BottomTensorCoord& bottom_tensor_thread_coord,
251  const ATopIndex& idx_diff_adaptor_top) const
252  {
253  array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
254 
255  move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
256  window_adaptor_thread_coord,
257  idx_diff_adaptor_top,
258  idx_diff_adaptor_bottom);
259 
260  move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
261  bottom_tensor_thread_coord,
262  idx_diff_adaptor_bottom);
263  }
264 
265  // return vector dimension among [y0, y1, ...]
267  {
268  // bottom tensor top dimension vector lengths and strides
269  const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
270  BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
271 
272  // window vector lengths/strides
273  const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
274  const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
275 
276  // window adaptor [p0, p1, ..., y0, y1, ...]
277  array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
278  -1};
279  array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
280  -1};
281 
282  constexpr auto window_adaptor_bottom_dims =
283  WindowAdaptor::get_bottom_dimension_hidden_ids();
284 
285  set_container_subset(window_adaptor_vector_lengths,
286  window_adaptor_bottom_dims,
287  window_adaptor_bottom_dim_vector_lengths);
288  set_container_subset(window_adaptor_vector_strides,
289  window_adaptor_bottom_dims,
290  window_adaptor_bottom_dim_vector_strides);
291 
292  const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
293  WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
294  window_adaptor_vector_lengths, window_adaptor_vector_strides);
295 
296  // [y0, y1, ...]
297  constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
299  1>::type{};
300 
301  return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
302  get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
303  }
304 
306 
307  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
310  {
311  constexpr auto tile_dstr = TileDstr{};
312  auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
313  load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
314  return dst_tensor;
315  }
316 
317  template <typename DistributedTensor,
318  index_t i_access_unsupport_ = -1,
319  bool oob_conditional_check = true>
320  CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
323  {
324  using Traits = load_store_traits;
325  using vector_t = typename Traits::vector_t;
326  using SFC_Ys = typename Traits::SFC_Ys;
327 
328  constexpr auto tile_dstr = TileDstr{};
329 
330  // loop over thread tensor space [y0, y1, ...]
331  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
333  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
334  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
335 
336  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
337  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
338 
339  // data index [y0, y1, ...]
340  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
341  constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
342  const auto page_offset = page_idx_[idx_gather];
343 
344  // read from bottom tensor
345  const vector_t vec_value = [&]() {
346  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
347  {
348  return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
349  bottom_tensor_thread_coord,
350  page_offset,
351  bool_constant<oob_conditional_check>{});
352  }
353  else
354  {
355  return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
356  bottom_tensor_thread_coord,
357  page_offset,
358  valids_[idx_gather],
359  bool_constant<oob_conditional_check>{});
360  }
361  }();
362 #if 1
363  // write into distributed tensor
364  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
365  constexpr auto idx_ys = generate_tuple(
366  [&](auto jj) {
367  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
368  : idx_ys_start[jj];
369  },
370  number<NDimY>{});
371 
372  constexpr index_t d =
373  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
374  Traits::PackedSize;
375 
376  dst_tensor.get_thread_buffer().template at<d>() =
377  vec_value.template get_as<DataType>()[j / Traits::PackedSize];
378  });
379 #else
380  constexpr index_t d =
381  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
382  static_assert(d % Traits::ScalarPerVector == 0);
383 
384  dst_tensor.get_thread_buffer().template get_as<vector_t>()(
385  number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
386 #endif
387  // move thread coordinate
388  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
389  {
390  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
391 
392  constexpr auto forward_step_scatter = generate_tuple(
393  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
394  number<NDimY>{});
395 
396  constexpr auto idx_diff_ps_ys = container_concat(
397  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
398  forward_step_scatter);
399 
401  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
402  }
403  });
404  });
405  }
406 
407  // TODO: currently async load only implemented in inline asm
408  template <typename LdsTileWindow_,
409  index_t i_access_unsupport_ = -1,
410  bool oob_conditional_check = true,
411  bool pre_nop = false>
412  CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
415  bool_constant<pre_nop> = {}) const
416  {
417  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
418  // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
419  using LdsDataType = typename LdsTileWindow::DataType;
420  // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
421 
422  // issues * warps * lanes
423  static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
424 
425  const index_t size_per_buf =
426  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
427  make_tuple(number<0>{}, number<0>{}, number<0>{})) *
428  sizeof(LdsDataType);
429 
430  const index_t size_per_wave =
431  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
432  make_tuple(number<0>{}, number<1>{}, number<0>{})) *
433  sizeof(LdsDataType) -
434  size_per_buf;
435 
436  const index_t size_per_issue =
437  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
438  make_tuple(number<1>{}, number<0>{}, number<0>{})) *
439  sizeof(LdsDataType) -
440  size_per_buf;
441 
442  const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
443  m0_set_with_memory(m0_init_value); // This should be wave independent
444 
445  using Traits = load_store_traits;
446 
447  // using vector_type_t = typename Traits::vector_type_t;
448  using vector_t = typename Traits::vector_t;
449  using SFC_Ys = typename Traits::SFC_Ys;
450 
451  LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
452 
453  // loop over thread tensor space [y0, y1, ...]
454  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
456  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
457  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
458 
459  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
460  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
461  constexpr auto pre_nop_ = [&]() {
462  if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
463  return bool_constant<true>{};
464  else
465  return bool_constant<false>{};
466  }();
467 
468  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
469  constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
470  const auto page_offset = page_idx_[idx_gather];
471 
472  // read from bottom tensor
473  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
474  {
475  get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
476  smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
477  }
478  else
479  {
480  get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
481  smem,
482  bottom_tensor_thread_coord,
483  page_offset,
484  valids_[idx_gather],
485  0,
486  pre_nop_);
487  }
488 
489  // move thread coordinate
490  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
491  {
492  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
493 
494  constexpr auto forward_step_scatter = generate_tuple(
495  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
496  number<NDimY>{});
497 
498  constexpr auto idx_diff_ps_ys = container_concat(
499  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
500  forward_step_scatter);
501 
503  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
504 
505  m0_inc_with_memory(size_per_issue);
506  }
507  });
508  });
509  }
510 
511  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
515  {
516  using Traits = load_store_traits;
517 
518  // using vector_type_t = typename Traits::vector_type_t;
519  using vector_t = typename Traits::vector_t;
520  using SFC_Ys = typename Traits::SFC_Ys;
521 
522  constexpr auto tile_dstr = TileDstr{};
523  // printf("off %d\n", page_idx_[I0]);
524  // loop over thread tensor space [y0, y1, ...]
525  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
526  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
527  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
528 
529  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
530  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
531 
532  // data index [y0, y1, ...]
533  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
534  constexpr auto idx_gather = idx_ys_start[number<0>{}];
535  const auto page_offset = page_idx_[idx_gather];
536 
537  // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
538  // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
539 
540  // read from distributed tensor
541  // vector_type_t vec;
542  vector_t vec_value;
543 
544  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
545  constexpr auto idx_ys = generate_tuple(
546  [&](auto jj) {
547  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
548  : idx_ys_start[jj];
549  },
550  number<NDimY>{});
551 
552  constexpr index_t d =
553  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
554  Traits::PackedSize;
555  // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j);
556  vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
557  dstr_tensor.get_thread_buffer().template at<d>();
558  });
559 
560  // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
561 
562  // write into bottom tensor
563  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
564  {
565  get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
566  bottom_tensor_thread_coord,
567  page_offset,
568  vec_value,
569  bool_constant<oob_conditional_check>{});
570  }
571  else
572  {
573  get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
574  bottom_tensor_thread_coord,
575  page_offset,
576  valids_[idx_gather],
577  vec_value,
578  bool_constant<oob_conditional_check>{});
579  }
580 
581  // printf("coord_offset:%d, scatter_offset:%d \n",
582  // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
583  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
584  {
585  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
586 
587  constexpr auto forward_step_scatter = generate_tuple(
588  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
589  number<NDimY>{});
590 
591  constexpr auto idx_diff_ps_ys = container_concat(
592  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
593  forward_step_scatter);
594 
596  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
597  }
598  });
599  });
600  }
601 
602  // move thread's botom tensor coordiante
603  // [x0', x1', ... ] ==> [offset]
604  // also move window-origin
606  {
607  window_origin_ += step;
608  BottomTensorIndex step_new = step;
609  step_new(HsGatherDim) = 0;
610  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
611  move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
612  pre_computed_coords_(iCoord)(I1),
613  step_new);
614  });
615  }
616 
617  CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
618 
619  CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
620  {
621  if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == false)
622  {
623  valids_ = new_valids;
624  }
625  }
626 
628  const ValidArray& new_valids)
629  {
630  update_page_idx(new_idx);
631  update_valids(new_valids);
632  }
633 
634  CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
635  {
636  window_origin_ = new_window_origin;
637 
638 #if 0 // debug
639  // TODO: this use more register for FA, but less register for GEMM
640  // need investigation
641  // only support warp-tile and block-tile
642  static_assert(NDimP == 1 or NDimP == 2, "wrong!");
643 
644  WindowAdaptorCoord window_adaptor_thread_coord_tmp;
645 
646  if constexpr(NDimP == 1)
647  {
648  window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
649  tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
650  }
651  else if constexpr(NDimP == 2)
652  {
653  window_adaptor_thread_coord_tmp =
654  make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
655  AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
656  }
657 #else
658  // TODO: this use less register for FA, but more register for GEMM
659  // need investigation
660  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
661  tile_dstr_.get_ps_ys_to_xs_adaptor(),
663 #endif
664 
665  BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
666  window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
667 
668  bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
669  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
670  bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
671 
672  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
673  // future load/store() calls (might allocate more registers)
674  using Traits = load_store_traits;
675  using SFC_Ys = typename Traits::SFC_Ys;
676 
677  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
678  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
679  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
680 
681  constexpr auto idx_diff_ys =
682  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
683 
684  constexpr auto idx_diff_ps_ys = container_concat(
685  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
686 
688  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
689 
690  pre_computed_coords_(iCoord) =
691  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
692  });
693  }
694 
696 
697  // this is the bottom tensor view
698  // [x0', x1', ...] ==> [offset]
700 
701  //
703 
704  // origin ([x0', x1', ...]) of window on bottom tensor
706 
707  // Tile tensor distribution, which contains:
708  // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
709  // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
711 
714 
715  // this contains:
716  // per-thread coordinate for window adaptor
717  // per-thread coordinate for bottom tensor
719 };
720 
721 // TODO: use strategy
722 template <typename TensorView_,
723  typename WindowLengths_,
724  typename StaticTileDistribution_,
725  typename StaticPageIndexArray_,
726  index_t HsGatherDim = 0,
727  index_t NumCoord = 1>
728 CK_TILE_DEVICE constexpr auto
730  const WindowLengths_& window_lengths,
731  const multi_index<TensorView_::get_num_of_dimension()>& origin,
732  const StaticTileDistribution_& tile_distribution,
733  const StaticPageIndexArray_& page_idx,
734  number<HsGatherDim> = {},
735  number<NumCoord> = {})
736 {
737  return tile_scatter_gather<remove_cvref_t<TensorView_>,
738  remove_cvref_t<WindowLengths_>,
739  remove_cvref_t<StaticTileDistribution_>,
740  remove_cvref_t<StaticPageIndexArray_>,
741  std::nullptr_t,
742  HsGatherDim,
743  NumCoord>{
744  tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
745 }
746 
747 template <typename TensorView,
748  typename WindowLengths,
749  typename StaticTileDistribution,
750  typename StaticPageIndexArray,
751  index_t HsGatherDim>
754  const multi_index<TensorView::get_num_of_dimension()>& origin,
755  const StaticTileDistribution& tile_distribution,
756  const StaticPageIndexArray& page_idx,
757  number<HsGatherDim> = {})
758 {
760  tile_window.get_window_lengths(),
761  origin,
762  tile_distribution,
763  page_idx,
764  number<HsGatherDim>{});
765 }
766 
767 template <typename TensorView,
768  typename WindowLengths,
769  typename StaticTileDistribution,
770  typename StaticPageIndexArray,
771  index_t HsGatherDim>
774  const StaticTileDistribution& tile_distribution,
775  const StaticPageIndexArray& page_idx,
776  number<HsGatherDim> = {})
777 {
779  tile_window.get_window_lengths(),
780  tile_window.get_window_origin(),
781  tile_distribution,
782  page_idx,
783  number<HsGatherDim>{});
784 }
785 
786 template <typename TensorView_,
787  typename WindowLengths_,
788  typename StaticTileDistribution_,
789  typename StaticPageIndexArray_,
790  typename StaticValidArray_,
791  index_t HsGatherDim = 0,
792  index_t NumCoord = 1>
793 CK_TILE_DEVICE constexpr auto
795  const WindowLengths_& window_lengths,
796  const multi_index<TensorView_::get_num_of_dimension()>& origin,
797  const StaticTileDistribution_& tile_distribution,
798  const StaticPageIndexArray_& page_idx,
799  const StaticValidArray_& valids,
800  number<HsGatherDim> = {},
801  number<NumCoord> = {})
802 {
803  return tile_scatter_gather<remove_cvref_t<TensorView_>,
804  remove_cvref_t<WindowLengths_>,
805  remove_cvref_t<StaticTileDistribution_>,
806  remove_cvref_t<StaticPageIndexArray_>,
807  remove_cvref_t<StaticValidArray_>,
808  HsGatherDim,
809  NumCoord>{
810  tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
811 }
812 
813 template <typename TensorView,
814  typename WindowLengths,
815  typename StaticTileDistribution,
816  typename StaticPageIndexArray,
817  typename StaticValidArray,
818  index_t HsGatherDim>
821  const multi_index<TensorView::get_num_of_dimension()>& origin,
822  const StaticTileDistribution& tile_distribution,
823  const StaticPageIndexArray& page_idx,
824  const StaticValidArray& valids,
825  number<HsGatherDim> = {})
826 {
828  tile_window.get_window_lengths(),
829  origin,
830  tile_distribution,
831  page_idx,
832  valids,
833  number<HsGatherDim>{});
834 }
835 
836 template <typename TensorView,
837  typename WindowLengths,
838  typename StaticTileDistribution,
839  typename StaticPageIndexArray,
840  typename StaticValidArray,
841  index_t HsGatherDim>
844  const StaticTileDistribution& tile_distribution,
845  const StaticPageIndexArray& page_idx,
846  const StaticValidArray& valids,
847  number<HsGatherDim> = {})
848 {
850  tile_window.get_window_lengths(),
851  tile_window.get_window_origin(),
852  tile_distribution,
853  page_idx,
854  valids,
855  number<HsGatherDim>{});
856 }
857 
858 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constant< b > bool_constant
Definition: integral_constant.hpp:43
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1112
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constant< v > number
Definition: integral_constant.hpp:37
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1052
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
typename std::remove_reference< T >::type remove_reference_t
Definition: type_traits.hpp:15
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_DEVICE auto make_tile_scatter_gather(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, const StaticPageIndexArray_ &page_idx, number< HsGatherDim >={}, number< NumCoord >={})
Definition: tile_scatter_gather.hpp:729
constexpr CK_TILE_HOST_DEVICE void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition: tensor_adaptor_coordinate.hpp:97
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:87
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
Definition: sequence.hpp:284
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:299
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Definition: numeric.hpp:81
Definition: space_filling_curve.hpp:20
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
Definition: debug.hpp:67
Definition: tile_distribution.hpp:72
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
Definition: tile_scatter_gather.hpp:81
static constexpr index_t PackedSize
Definition: tile_scatter_gather.hpp:104
static constexpr index_t NumAccess
Definition: tile_scatter_gather.hpp:143
decltype(get_space_filling_curve()) SFC_Ys
Definition: tile_scatter_gather.hpp:141
static constexpr index_t VectorDimY
Definition: tile_scatter_gather.hpp:106
static constexpr index_t ScalarPerVector
Definition: tile_scatter_gather.hpp:107
This class provides tile (windowed) view and access to the device memory.
Definition: tile_scatter_gather.hpp:41
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_scatter_gather.hpp:605
static constexpr index_t NumAccessPerCoord
Definition: tile_scatter_gather.hpp:149
static constexpr auto I1
Definition: tile_scatter_gather.hpp:59
constexpr CK_TILE_DEVICE tile_scatter_gather(const BottomTensorView &bottom_tensor_view, const WindowLengths &window_lengths, const BottomTensorIndex &window_origin, const TileDstr &tile_distribution, const PageIdxArray &page_idx, const ValidArray &valids)
Definition: tile_scatter_gather.hpp:153
BottomTensorIndex window_origin_
Definition: tile_scatter_gather.hpp:705
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:308
WindowLengths window_lengths_
Definition: tile_scatter_gather.hpp:702
constexpr CK_TILE_DEVICE auto get_tile_distribution() const
Definition: tile_scatter_gather.hpp:233
constexpr CK_TILE_DEVICE auto get_num_of_access() const
Definition: tile_scatter_gather.hpp:305
static constexpr index_t NDimBottomTensor
Definition: tile_scatter_gather.hpp:53
static constexpr CK_TILE_DEVICE auto get_window_adaptor_ys_safe_vector_length_strides()
Definition: tile_scatter_gather.hpp:266
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition: tile_scatter_gather.hpp:72
PageIdxArray page_idx_
Definition: tile_scatter_gather.hpp:712
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_scatter_gather.hpp:43
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition: tile_scatter_gather.hpp:634
array< tuple< WindowAdaptorCoord, BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_scatter_gather.hpp:718
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_scatter_gather.hpp:237
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_scatter_gather.hpp:44
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_scatter_gather.hpp:248
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:320
CK_TILE_DEVICE void store(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:512
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray &new_idx, const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:627
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition: tile_scatter_gather.hpp:48
TileDstr tile_dstr_
Definition: tile_scatter_gather.hpp:710
ValidArray valids_
Definition: tile_scatter_gather.hpp:713
static constexpr index_t NDimY
Definition: tile_scatter_gather.hpp:56
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_scatter_gather.hpp:50
static constexpr index_t NDimWindowAdaptorTop
Definition: tile_scatter_gather.hpp:52
static constexpr CK_TILE_DEVICE bool has_static_tile_distribution()
Definition: tile_scatter_gather.hpp:226
remove_cvref_t< StaticValidArray_ > ValidArray
Definition: tile_scatter_gather.hpp:46
static constexpr index_t NDimP
Definition: tile_scatter_gather.hpp:55
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_scatter_gather.hpp:42
constexpr CK_TILE_DEVICE tile_scatter_gather()=default
remove_cvref_t< StaticPageIndexArray_ > PageIdxArray
Definition: tile_scatter_gather.hpp:45
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_scatter_gather.hpp:231
CK_TILE_HOST_DEVICE void init_raw()
Definition: tile_scatter_gather.hpp:695
static constexpr auto I0
Definition: tile_scatter_gather.hpp:58
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})) BottomTensorCoord
Definition: tile_scatter_gather.hpp:78
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_scatter_gather.hpp:235
typename TileDstr::PsYs2XsAdaptor WindowAdaptor
Definition: tile_scatter_gather.hpp:47
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})) WindowAdaptorCoord
Definition: tile_scatter_gather.hpp:75
constexpr CK_TILE_DEVICE void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition: tile_scatter_gather.hpp:240
BottomTensorView bottom_tensor_view_
Definition: tile_scatter_gather.hpp:699
CK_TILE_DEVICE void update_valids(const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:619
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_scatter_gather.hpp:412
array< index_t, NDimWindowAdaptorTop > AdaptorTopIndex
Definition: tile_scatter_gather.hpp:71
CK_TILE_DEVICE void update_page_idx(const PageIdxArray &new_idx)
Definition: tile_scatter_gather.hpp:617
static constexpr CK_TILE_DEVICE index_t get_num_of_dimension()
Definition: tile_scatter_gather.hpp:224
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
This class provides description of tile windowed view on the device memory.
Definition: tile_window.hpp:873
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10