Composable Kernel: include/ck_tile/core/tensor/tile_window_linear.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck_tile {

#define WINDOW_DISPATCH_ISSUE()                                     \
    if constexpr(i_access < 0)                                      \
    {                                                               \
        static_for<0, NumAccess, 1>{}([&](auto ia) { issue(ia); }); \
    }                                                               \
    else                                                            \
    {                                                               \
        static_assert(i_access < NumAccess);                        \
        issue(number<i_access>{});                                  \
    }
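
// Dispatch sketch (illustrative note, not original commentary): the member functions
// below take an i_access template parameter defaulting to -1. Through the macro above,
// a negative value issues every access, while a concrete index issues exactly one:
//
//   auto t0 = window.load();            // i_access = -1 -> all NumAccess issues
//   auto t1 = window.load(number<3>{}); // only the 4th access is issued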

//
// This version of tile window will pre-cache offsets/flags based on need.
//
// LinearBottomDims_, e.g. seq<0, 1> for a 2d tensor: here the last dim is the linear
// dim, so the last dim can use an immediate offset for indexing, which saves registers.
// TODO: if using this struct, it is better to use load_raw()/store_raw(), which can
// control the immediate offset on the fly.
// The space-filling curve is non-snaked here!
// This struct inherits from tile_window_with_tile_dstr_base, which is an intermediary base
// class with the ultimate parent class being tile_window_base.
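//
// Illustrative sketch (an assumption, not from the original source): the two common
// LinearBottomDims choices for a 2d window, matching the defaults produced by
// default_linear_bottom_dims at the end of this file:
//
//   using GlobalLinearDims = sequence<0, 1>; // global: last dim linear -> immediates
//   using LdsLinearDims    = sequence<1, 1>; // lds: all dims linear -> one base register
//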
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename LinearBottomDims_>
struct tile_window_linear
    : public tile_window_with_tile_dstr_base<tile_window_linear<BottomTensorView_,
                                                                WindowLengths_,
                                                                StaticTileDistribution_,
                                                                LinearBottomDims_>,
                                             BottomTensorView_,
                                             WindowLengths_,
                                             StaticTileDistribution_>
{
    using Base = tile_window_with_tile_dstr_base<tile_window_linear<BottomTensorView_,
                                                                    WindowLengths_,
                                                                    StaticTileDistribution_,
                                                                    LinearBottomDims_>,
                                                 BottomTensorView_,
                                                 WindowLengths_,
                                                 StaticTileDistribution_>;

    using LinearBottomDims = remove_cvref_t<LinearBottomDims_>;

    static_assert(LinearBottomDims::size() == Base::BottomTensorView::get_num_of_dimension());

    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};

    struct traits
    {
        private:
        static constexpr auto get_num_non_linear_access()
        {
            constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
            using ys_to_rhs_major =
                typename decltype(typename Base::TileDstr{}
                                      .get_static_tile_distribution_encoding())::Ys2RHsMajor;

            constexpr auto non_linear = [&]() {
                index_t cnt = 1;
                static_for<0, Base::NDimY, 1>{}([&](auto i_dim_y) {
                    constexpr auto rhs_major    = ys_to_rhs_major{}[i_dim_y];
                    constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
                    if constexpr(LinearBottomDims{}[target_h_dim] == 0)
                    {
                        cnt *= sfc_access_lens[i_dim_y];
                    }
                });
                return cnt;
            }();

            return non_linear;
        }

        // example:
        // non_linear_access_map: sequence<0, 0, 0, 0, 1, 1, 1, 1> for 8 accesses, 2 registers
        // used in total
        //      -> histogram : sequence<4, 4>
        //      -> prefixsum : sequence<0, 4, 8>
        // non_linear_access_map: sequence<0, 1, 2, 3, 4, 5, 6, 7> for 8 accesses, 8 registers
        // used in total, will pre-cache 8
        //      -> histogram : sequence<1, 1, 1, 1, 1, 1, 1, 1>
        //      -> prefixsum : sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>
        // non_linear_access_map: sequence<0, 0, 1, 1, 2, 2, 3, 3> for 8 accesses, 4 registers
        // used in total, will pre-cache 4
        //      -> histogram : sequence<2, 2, 2, 2>
        //      -> prefixsum : sequence<0, 2, 4, 6, 8>
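        //
        // How the three sequences are used later (explanatory note inferred from the
        // code below, not original commentary): only one bottom-tensor coordinate is
        // cached per distinct map entry (NumAccess_NonLinear of them), access i looks
        // its coordinate up via AccessMap_NonLinear[i], and the prefix sum marks the
        // first access of each group, which is the one that saves the coordinate in
        // the constructor.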
        static constexpr auto get_non_linear_access_map()
        {
            constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
            using ys_to_rhs_major =
                typename decltype(typename Base::TileDstr{}
                                      .get_static_tile_distribution_encoding())::Ys2RHsMajor;
            constexpr auto non_linear_map = [&]() {
                array<index_t, Base::Traits::NumAccess> m_{0};
                index_t cumulative_len_            = 1;
                index_t cumulative_non_linear_len_ = 1;
                static_for<0, Base::NDimY, 1>{}([&](auto i_y) {
                    constexpr auto i_dim_y = number<Base::NDimY - i_y - 1>{}; // from right to left
                    constexpr auto rhs_major     = ys_to_rhs_major{}[i_dim_y];
                    constexpr auto target_h_dim  = number<rhs_major - 1>{}; // no r dim here!
                    constexpr auto is_linear_dim = LinearBottomDims{}[target_h_dim];

                    array<index_t, Base::Traits::NumAccess> current_m_{0};
                    constexpr auto current_len_ = sfc_access_lens[i_dim_y];

                    // copy cumulative length as current pattern
                    for(auto i_ = 0; i_ < cumulative_len_; i_++)
                    {
                        current_m_(i_) = m_[i_];
                    }
                    for(auto j_ = 0; j_ < current_len_; j_++)
                    {
                        auto j_offset_ = is_linear_dim ? 0 : j_ * cumulative_non_linear_len_;
                        for(auto i_ = 0; i_ < cumulative_len_; i_++)
                        {
                            m_(j_ * cumulative_len_ + i_) = current_m_[i_] + j_offset_;
                        }
                    }
                    cumulative_len_ *= current_len_;
                    if(!is_linear_dim)
                        cumulative_non_linear_len_ *= current_len_;
                });
                return m_;
            }();

            return TO_SEQUENCE(non_linear_map, Base::Traits::NumAccess);
        }

        static constexpr auto get_non_linear_access_histogram()
        {
            constexpr auto m_ = get_non_linear_access_map();

            constexpr auto r_ =
                typename arithmetic_sequence_gen<0, get_num_non_linear_access() + 1, 1>::type{};

            constexpr auto h_ = histogram_sorted_sequence(m_, r_);

            return h_;
        }

        static constexpr auto get_non_linear_access_histogram_prefix_sum()
        {
            constexpr auto h_            = get_non_linear_access_histogram();
            constexpr auto h_prefix_sum_ = prefix_sum_sequence(h_);
            return h_prefix_sum_;
        }

        public:
        static constexpr index_t NumAccess_NonLinear = get_num_non_linear_access();
        using AccessMap_NonLinear       = decltype(get_non_linear_access_map()); // sequence
        using AccessHistogram_NonLinear = decltype(get_non_linear_access_histogram());
        using AccessPrefixSum_NonLinear = decltype(get_non_linear_access_histogram_prefix_sum());
    };

    static constexpr index_t NumAccess           = Base::Traits::NumAccess;
    static constexpr index_t NumAccess_NonLinear = traits::NumAccess_NonLinear;
    using AccessMap_NonLinear                    = typename traits::AccessMap_NonLinear;
    using AccessHistogram_NonLinear              = typename traits::AccessHistogram_NonLinear;
    using AccessPrefixSum_NonLinear              = typename traits::AccessPrefixSum_NonLinear;

    CK_TILE_DEVICE constexpr tile_window_linear() = default;

    CK_TILE_DEVICE constexpr tile_window_linear(
        const typename Base::BottomTensorView& bottom_tensor_view,
        const typename Base::WindowLengths& window_lengths,
        const typename Base::BottomTensorIndex& window_origin,
        const typename Base::TileDstr& tile_distribution)
    {
        this->bottom_tensor_view_ = bottom_tensor_view;
        this->window_lengths_     = window_lengths;
        this->window_origin_      = window_origin;
        this->tile_dstr_          = tile_distribution;
        auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
            typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor(),
            container_concat(
                make_tuple(get_warp_id(), get_lane_id()),
                generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimY>{})));

        typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
            window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();

        auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
            this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);

        // pre-cache the coords/flags needed by future load/store() calls
        // (might allocate more registers)
        using SFC_Ys = typename Base::Traits::SFC_Ys;

        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
            constexpr auto need_save_non_linear_coord =
                bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};

            if constexpr(need_save_non_linear_coord)
            {
                cached_coords_(non_linear_id)                = bottom_tensor_thread_coord_tmp;
                cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp;
            }

            // TODO: need pad_tensor_view to check which dim needs to use a flag check.
            // The cached flag is independent from the non-linear coord,
            // but needs to be updated in move_tile, with the proper dims.
            cached_flags_(i_access) = coordinate_has_valid_offset_assuming_top_index_is_valid(
                this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp);

            if constexpr(i_access != (NumAccess - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number
                constexpr auto idx_diff_ps_ys = container_concat(
                    generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
                    idx_diff_ys);

                this->move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord_tmp,
                    bottom_tensor_thread_coord_tmp,
                    idx_diff_ps_ys);
            }
        });
    }

    template <index_t i_access>
    CK_TILE_DEVICE static constexpr auto get_bottom_linear_coordinate(number<i_access>)
    {
        using SFC_Ys          = typename Base::Traits::SFC_Ys;
        constexpr auto idx_ys = SFC_Ys::get_index(number<i_access>{});
        using ys_to_rhs_major =
            typename decltype(typename Base::TileDstr{}
                                  .get_static_tile_distribution_encoding())::Ys2RHsMajor;

        constexpr auto modified_idx_ys = generate_tuple(
            [&](auto i_dim_y) {
                constexpr auto rhs_major    = ys_to_rhs_major{}[i_dim_y];
                constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
                if constexpr(LinearBottomDims{}[target_h_dim] == 0)
                {
                    return number<0>{};
                }
                else
                {
                    return number<idx_ys[i_dim_y]>{};
                }
            },
            number<Base::NDimY>{});

        constexpr auto adaptor_ = typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor();
        constexpr auto idx_ =
            container_concat(make_tuple(number<0>{}, number<0>{}), modified_idx_ys);

        return adaptor_.calculate_bottom_index(idx_);
    }

    template <index_t i_access>
    CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>)
    {
        constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});
        constexpr auto is_pure_linear_tensor =
            reduce_on_sequence(LinearBottomDims{}, multiplies{}, number<1>{}) == 1;
        if constexpr(is_pure_linear_tensor)
        {
            // this case is usually an LDS window, where everything is known at compile time.
            // we directly use the BottomTensorView transform to compute the offset, in case of
            // padding
            auto bottom_tensor_coord = make_tensor_coordinate(
                typename Base::BottomTensorView{}.get_tensor_descriptor(), linear_coord);
            return bottom_tensor_coord.get_offset();
        }
        else
        {
            // this case is usually a global window, where the last dim can be linear.
            // we hack here, in that we use the original TileDstr to compute the linear offset
            // ... hoping that there is no extra padding between other dims, which makes sense
            // since that would introduce a runtime length (so we couldn't use a linear offset)
            constexpr index_t linear_offset = [&]() {
                constexpr auto x_idx_ = linear_coord;
                constexpr auto x_len_ = typename Base::TileDstr{}.get_lengths();
                static_assert(x_idx_.size() == x_len_.size());
                constexpr index_t x_dims_ = x_idx_.size();
                index_t cu_stride_        = 1;
                index_t cu_offset_        = 0;
                static_for<0, x_dims_, 1>{}([&](auto i_) {
                    auto r_i_ = number<x_dims_ - i_ - 1>{};
                    cu_offset_ += x_idx_[r_i_] * cu_stride_;
                    cu_stride_ *= x_len_[r_i_];
                });
                return cu_offset_;
            }();
            return linear_offset;
        }
    }
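
    // Worked example (illustrative, an assumption rather than original commentary):
    // for tile lengths x_len_ = [4, 8] and a linear coordinate x_idx_ = [2, 3], the
    // loop above folds right-to-left exactly like ordinary row-major addressing:
    //
    //   i_ = 0: cu_offset_ = 3 * 1 = 3,      cu_stride_ = 8
    //   i_ = 1: cu_offset_ = 3 + 2 * 8 = 19, cu_stride_ = 32
    //
    // so this access gets the compile-time immediate offset 19.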

    template <index_t i_access = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
    {
        using vector_t = typename Base::Traits::vector_t;
        using SFC_Ys   = typename Base::Traits::SFC_Ys;

        constexpr auto tile_dstr = typename Base::TileDstr{};

        auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);

        auto issue = [&](auto i_access_) {
            constexpr auto IAccess = number<i_access_>{};

            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            constexpr auto linear_offset = get_bottom_linear_offset(IAccess);

            // read from bottom tensor
            const vector_t vec_value =
                this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                    bottom_tensor_thread_coord,
                    linear_offset,
                    bottom_tensor_flag,
                    bool_constant<oob_conditional_check>{});

            // data index [y0, y1, ...]
            constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
            // write into distributed tensor
            static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j)
                                                              : idx_diff_ys[jj];
                    },
                    number<Base::NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
                                      Base::Traits::PackedSize;

                dst_tensor.get_thread_buffer().template at<d>() =
                    vec_value
                        .template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
            });
        };

        WINDOW_DISPATCH_ISSUE();

        return dst_tensor;
    }

    template <typename DstTile, index_t i_access = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE auto load(DstTile& dst_tensor,
                             number<i_access>                     = {},
                             bool_constant<oob_conditional_check> = {}) const
    {
        using vector_t = typename Base::Traits::vector_t;
        using SFC_Ys   = typename Base::Traits::SFC_Ys;

        constexpr auto tile_dstr = typename Base::TileDstr{};

        // auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);

        auto issue = [&](auto i_access_) {
            constexpr auto IAccess = number<i_access_>{};

            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            constexpr auto linear_offset = get_bottom_linear_offset(IAccess);

            // read from bottom tensor
            const vector_t vec_value =
                this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                    bottom_tensor_thread_coord,
                    linear_offset,
                    bottom_tensor_flag,
                    bool_constant<oob_conditional_check>{});
            // data index [y0, y1, ...]
            constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
            // write into distributed tensor
            static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j)
                                                              : idx_diff_ys[jj];
                    },
                    number<Base::NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
                                      Base::Traits::PackedSize;

                dst_tensor.get_thread_buffer().template at<d>() =
                    vec_value
                        .template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
            });
        };

        WINDOW_DISPATCH_ISSUE();

        return dst_tensor;
    }

    template <typename DstTile,
              index_t i_access           = -1,
              bool oob_conditional_check = true,
              bool pre_nop               = false>
    CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
                                 number<i_access> = {}, // negative means loop over all num_access
                                 bool_constant<oob_conditional_check> = {},
                                 bool_constant<pre_nop>               = {}) const
    {
        using vector_t = typename Base::Traits::vector_t;
        using SFC_Ys   = typename Base::Traits::SFC_Ys;
        static constexpr index_t YElementSize =
            typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
        static_assert(YElementSize % (Base::Traits::PackedSize * Base::Traits::ScalarPerVector) ==
                      0);
        using vectorized_tbuf =
            array<vector_t,
                  YElementSize / (Base::Traits::PackedSize * Base::Traits::ScalarPerVector)>;

        constexpr auto tile_dstr = typename Base::TileDstr{};

        auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());

        auto issue = [&](auto i_access_) {
            constexpr auto IAccess  = number<i_access_>{};
            constexpr auto pre_nop_ = [&]() {
                if constexpr(pre_nop && i_access_ == 0 &&
                             Base::BottomTensorView::buffer_view::get_address_space() ==
                                 address_space_enum::global)
                    return bool_constant<true>{};
                else
                    return bool_constant<false>{};
            }();

            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
            constexpr index_t d =
                tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
                Base::Traits::PackedSize;
            static_assert(d % Base::Traits::ScalarPerVector == 0);

            this->get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
                dst_vec_tbuf.template at<d / Base::Traits::ScalarPerVector>(),
                bottom_tensor_thread_coord,
                linear_offset,
                bottom_tensor_flag,
                bool_constant<oob_conditional_check>{},
                pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
    CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
            asm volatile(""); // this one started from rocm-6.2, but with the same symptom,
                              // so reuse this flag
#endif
        };

        WINDOW_DISPATCH_ISSUE();
    }

    // TODO: currently async load is only implemented in inline asm
    template <typename LdsTileWindow_,
              index_t i_access           = -1,
              bool oob_conditional_check = true,
              bool pre_nop               = false>
    CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
                                       number<i_access>                     = {},
                                       bool_constant<oob_conditional_check> = {},
                                       bool_constant<pre_nop>               = {}) const
    {
        using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
        using LdsDataType   = typename LdsTileWindow::DataType;

        // currently we only support the case where every dim is a non-linear dim.
        // it would actually not be performant to have a linear (e.g. fast-changing) dim here
        static_assert(NumAccess_NonLinear == NumAccess);
        static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
                      address_space_enum::global);

        // issues * warps * lanes
        static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded

        const index_t size_per_buf =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<0>{}, number<0>{})) *
            sizeof(LdsDataType);

        const index_t size_per_wave =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<1>{}, number<0>{})) *
                sizeof(LdsDataType) -
            size_per_buf;

        const index_t size_per_issue =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<1>{}, number<0>{}, number<0>{})) *
                sizeof(LdsDataType) -
            size_per_buf;

        const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
        m0_set_with_memory(m0_init_value); // this should be wave independent

        using vector_t = typename Base::Traits::vector_t;

        LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;

        // loop over thread tensor space [y0, y1, ...]
        auto issue = [&](auto i_access_) {
            constexpr auto IAccess  = number<i_access_>{};
            constexpr auto pre_nop_ = [&]() {
                if constexpr(pre_nop && i_access_ == 0)
                    return bool_constant<true>{};
                else
                    return bool_constant<false>{};
            }();

            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            auto bottom_tensor_flag         = cached_flags_[IAccess]; // get this flag anyway

            // read from bottom tensor
            this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
                smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);

            // move thread coordinate
            if constexpr(i_access_ != (NumAccess - 1))
            {
                m0_inc_with_memory(size_per_issue);
            }
        };

        WINDOW_DISPATCH_ISSUE();
    }
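
    // Note on the m0 bookkeeping above (explanatory sketch inferred from the code, not
    // original commentary): with the 3d LDS descriptor indexed as (issue, warp, lane)
    // and offset function o(...), the byte sizes reduce to
    //
    //   size_per_buf   = o(0, 0, 0) * sizeof(LdsDataType)
    //   size_per_wave  = (o(0, 1, 0) - o(0, 0, 0)) * sizeof(LdsDataType)
    //   size_per_issue = (o(1, 0, 0) - o(0, 0, 0)) * sizeof(LdsDataType)
    //
    // so each wave starts at its own LDS base, and m0 advances by one issue stride
    // between consecutive accesses.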

    template <typename LdsTileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
                                   number<i_access>                     = {},
                                   bool_constant<oob_conditional_check> = {}) const
    {
        using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
        using LdsDataType   = typename LdsTileWindow::DataType;
        using vector_t      = typename traits::vector_t;

        static_assert(NumAccess_NonLinear == NumAccess, "Unsupported configuration");
        static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
                          address_space_enum::global,
                      "Requires global memory");

        // precompute invariant values outside the lambda
        const auto window_origin       = lds_tile.get_window_origin();
        const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
        const auto& tensor_descriptor  = bottom_tensor_view.get_tensor_descriptor();
        auto smem_base_ptr             = bottom_tensor_view.get_buffer_view().p_data_;

        auto issue = [&](auto i_access_) {
            constexpr auto IAccess       = number<i_access_>{};
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};

            // use precomputed values
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            auto window_adaptor_coord       = cached_window_adaptor_coords_[non_linear_id];
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            auto lds_bottom_tensor_thread_idx =
                window_origin + window_adaptor_coord.get_bottom_index();
            const auto lds_coord =
                make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);

            CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();

            // read from bottom tensor
            this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                smem,
                bottom_tensor_thread_coord,
                0,
                bottom_tensor_flag,
                bool_constant<oob_conditional_check>{});
        };

        WINDOW_DISPATCH_ISSUE();
    }

    template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE auto load_transpose() const
    {
        constexpr auto tile_dstr = typename Base::TileDstr{};
        auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
        this->template load_transpose_linear<Policy>(
            dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
        return dst_tensor;
    }

    template <typename Policy,
              typename DistributedTensor,
              index_t i_access           = -1,
              bool oob_conditional_check = true>
    CK_TILE_DEVICE auto load_transpose_linear(DistributedTensor& dst_tensor,
                                              number<i_access>                     = {},
                                              bool_constant<oob_conditional_check> = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr = typename Base::TileDstr{};

        constexpr auto group_func = Policy::group_func;

        auto issue = [&](auto i_access_) {
            constexpr auto IAccess          = number<i_access_>{};
            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

            // read from bottom tensor
            const vector_t vec_value =
                this->get_bottom_tensor_view().template get_transpose_vectorized_elements<vector_t>(
                    bottom_tensor_thread_coord, 0);
            // write into distributed tensor
            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                    },
                    number<Base::NDimY>{});

                constexpr index_t linear_distributed_index =
                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
                dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
                    vec_value.template get_as<typename Base::DataType>()[j];
            });
        };
        WINDOW_DISPATCH_ISSUE();
    }

    template <index_t i_access = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE void store(const static_distributed_tensor<typename Base::DataType,
                                                              typename Base::TileDstr>& dstr_tensor,
                              number<i_access>                     = {},
                              bool_constant<oob_conditional_check> = {}) const
    {

        using vector_t = typename Base::Traits::vector_t;
        using SFC_Ys   = typename Base::Traits::SFC_Ys;

        constexpr auto tile_dstr = typename Base::TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        auto issue = [&](auto i_access_) {
            constexpr auto IAccess          = number<i_access_>{};
            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];
            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

            // read from distributed tensor
            vector_t vec_value;

            static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                              : idx_ys_start[jj];
                    },
                    number<Base::NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
                                      Base::Traits::PackedSize;

                vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            this->get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
                bottom_tensor_thread_coord,
                linear_offset,
                bottom_tensor_flag,
                vec_value,
                bool_constant<oob_conditional_check>{});
        };

        WINDOW_DISPATCH_ISSUE();
    }

    template <index_t i_access = -1>
    CK_TILE_DEVICE void
    store_raw(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
                  dstr_tensor,
              number<i_access> = {}) const
    {
        using vector_t = typename Base::Traits::vector_t;
        using SFC_Ys   = typename Base::Traits::SFC_Ys;

        constexpr auto tile_dstr                  = typename Base::TileDstr{};
        static constexpr bool oob_conditional_check = true;

        // loop over thread tensor space [y0, y1, ...]
        auto issue = [&](auto i_access_) {
            constexpr auto IAccess          = number<i_access_>{};
            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

            // read from distributed tensor
            vector_t vec_value;
            static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                              : idx_ys_start[jj];
                    },
                    number<Base::NDimY>{});
                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
                                      Base::Traits::PackedSize;
                vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            this->get_bottom_tensor_view()
                .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
                    bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value);
        };

        WINDOW_DISPATCH_ISSUE();
    }

    template <index_t i_access = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE void
    update(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
               dstr_tensor,
           number<i_access>                     = {},
           bool_constant<oob_conditional_check> = {}) const
    {

        using vector_t = typename Base::Traits::vector_t;
        using SFC_Ys   = typename Base::Traits::SFC_Ys;

        constexpr auto tile_dstr = typename Base::TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        auto issue = [&](auto i_access_) {
            constexpr auto IAccess          = number<i_access_>{};
            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

            // read from distributed tensor
            vector_t vec_value;

            static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                              : idx_ys_start[jj];
                    },
                    number<Base::NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
                                      Base::Traits::PackedSize;

                vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            this->get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
                bottom_tensor_thread_coord,
                linear_offset,
                bottom_tensor_flag,
                vec_value,
                bool_constant<oob_conditional_check>{});
        };

        WINDOW_DISPATCH_ISSUE();
    }

    template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
    CK_TILE_DEVICE void
    update_raw(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
                   dstr_tensor,
               number<i_access>                     = {},
               bool_constant<oob_conditional_check> = {},
               bool_constant<pre_nop>               = {}) const
    {

        using vector_t = typename Base::Traits::vector_t;
        using SFC_Ys   = typename Base::Traits::SFC_Ys;

        constexpr auto tile_dstr = typename Base::TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        auto issue = [&](auto i_access_) {
            constexpr auto IAccess          = number<i_access_>{};
            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

            // read from distributed tensor
            vector_t vec_value;

            static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                              : idx_ys_start[jj];
                    },
                    number<Base::NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
                                      Base::Traits::PackedSize;

                vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            this->get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
                bottom_tensor_thread_coord,
                linear_offset,
                bottom_tensor_flag,
                vec_value,
                bool_constant<oob_conditional_check>{},
                bool_constant<pre_nop>{});
        };

        WINDOW_DISPATCH_ISSUE();
    }
    // *_extended() functions act like virtual functions with a default implementation
    // existing in the base class
    CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex& step)
    {
        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto IAccess       = number<i_access>{};
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
            constexpr auto need_update_non_linear_coord =
                bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};

            if constexpr(need_update_non_linear_coord)
            {
                move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
                                       cached_coords_(non_linear_id),
                                       step);
            }

            // move the current coord with linear_coords
            auto tmp_coords             = cached_coords_[non_linear_id];
            constexpr auto linear_coord = get_bottom_linear_coordinate(IAccess);
            move_tensor_coordinate(
                this->bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord);

            cached_flags_(i_access) = coordinate_has_valid_offset_assuming_top_index_is_valid(
                this->bottom_tensor_view_.get_tensor_descriptor(), tmp_coords);
        });
    }

    CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&)
    {
        auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
            typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor(),
            container_concat(
                make_tuple(get_warp_id(), get_lane_id()),
                generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimY>{})));

        typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
            this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();

        auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
            this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);

        // pre-cache the coords needed by future load/store() calls
        // (might allocate more registers)
        using SFC_Ys = typename Base::Traits::SFC_Ys;

        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
            constexpr auto need_save_non_linear_coord =
                bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};

            if constexpr(need_save_non_linear_coord)
            {
                cached_coords_(non_linear_id)                = bottom_tensor_thread_coord_tmp;
                cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp;
            }

            if constexpr(i_access != (NumAccess - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number
                constexpr auto idx_diff_ps_ys = container_concat(
                    generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
                    idx_diff_ys);

                this->move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord_tmp,
                    bottom_tensor_thread_coord_tmp,
                    idx_diff_ps_ys);
            }
        });
    }

    // this contains:
    array<typename Base::BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
    array<typename Base::WindowAdaptorCoord, traits::NumAccess_NonLinear>
        cached_window_adaptor_coords_;
    array<bool, Base::Traits::NumAccess> cached_flags_;
};

#undef WINDOW_DISPATCH_ISSUE

namespace impl {
template <address_space_enum, index_t len_>
struct default_linear_bottom_dims_impl
{
    using type = typename uniform_sequence_gen<len_, 0>::type;
};

template <index_t len_>
struct default_linear_bottom_dims_impl<address_space_enum::global, len_>
{
    // global defaults to seq<0, 0, ..., 1>
    using type = typename sequence_merge<typename uniform_sequence_gen<len_ - 1, 0>::type,
                                         sequence<1>>::type;
};

template <index_t len_>
struct default_linear_bottom_dims_impl<address_space_enum::lds, len_>
{
    // lds defaults to seq<1, 1, ..., 1>
    using type = typename uniform_sequence_gen<len_, 1>::type;
};
} // namespace impl
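
// Concrete defaults (illustrative note derived from the specializations above): for a
// 2d tensor view, a global-memory window gets sequence<0, 1> (only the last dim is
// linear) and an LDS window gets sequence<1, 1> (every dim is linear); any other
// address space falls back to all zeros, i.e. no linear dims.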

template <typename TensorView_>
using default_linear_bottom_dims =
    typename impl::default_linear_bottom_dims_impl<TensorView_::buffer_view::get_address_space(),
                                                   TensorView_::get_num_of_dimension()>::type;

// if using this API, a tile_window_linear will be created.
// this structure has the chance to use immediate values, saving registers.
// LinearBottomDims_ needs to be passed in properly to control which dim is linear,
// so as to generate a constexpr offset as the linear_offset for this dim
// (and finally pass it to the immediate offset of the buffer/lds instruction)
//
// Note: there is no internal check for which dim is OK to use a linear offset;
// users must make sure of this by themselves
//
// e.g.
// 2d global matrix, set LinearBottomDims_=seq<0, 1>; the last dim will generate an
// immediate offset if each thread has multiple issues along the last dim
//
// 2d LDS buffer, set LinearBottomDims_=seq<1, 1>; then only one vgpr is used as the
// offset, everything else just uses immediate offsets.
//
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename LinearBottomDims_ = default_linear_bottom_dims<TensorView_>>
CK_TILE_DEVICE constexpr auto
make_tile_window_linear(const TensorView_& tensor_view,
                        const WindowLengths_& window_lengths,
                        const multi_index<TensorView_::get_num_of_dimension()>& origin,
                        const StaticTileDistribution_& tile_distribution,
                        LinearBottomDims_ = {})
{
    static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension());
    return tile_window_linear<remove_cvref_t<TensorView_>,
                              remove_cvref_t<WindowLengths_>,
                              remove_cvref_t<StaticTileDistribution_>,
                              remove_cvref_t<LinearBottomDims_>>{
        tensor_view, window_lengths, origin, tile_distribution};
}
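
// Usage sketch (an assumption for illustration; `view`, `dstr`, and the 128x32 window
// shape are hypothetical names, not from this file):
//
//   auto win = make_tile_window_linear(view,                                  // global tensor_view
//                                      make_tuple(number<128>{}, number<32>{}),
//                                      {0, 0},                                // window origin
//                                      dstr,                                  // static tile distribution
//                                      sequence<0, 1>{});                     // last dim linear
//   auto tile = win.load(); // linear-dim offsets fold into immediate offsets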

template <
    typename TileWindow_,
    typename StaticTileDistribution_,
    typename LinearBottomDims_ = default_linear_bottom_dims<typename TileWindow_::BottomTensorView>>
CK_TILE_DEVICE constexpr auto
make_tile_window_linear(const TileWindow_& tile_window,
                        const StaticTileDistribution_& tile_distribution,
                        LinearBottomDims_ = {})
{
    return make_tile_window_linear(tile_window.get_bottom_tensor_view(),
                                   tile_window.get_window_lengths(),
                                   tile_window.get_window_origin(),
                                   tile_distribution,
                                   LinearBottomDims_{});
}

// this version must not be called in a constexpr context
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename LinearBottomDims_ = default_linear_bottom_dims<TensorView_>>
CK_TILE_DEVICE auto
make_tile_window_linear_raw(const TensorView_& tensor_view,
                            const WindowLengths_& window_lengths,
                            const multi_index<TensorView_::get_num_of_dimension()>& origin,
                            const StaticTileDistribution_& tile_distribution,
                            LinearBottomDims_ = {})
{
    static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension());
    auto w = tile_window_linear<remove_cvref_t<TensorView_>,
                                remove_cvref_t<WindowLengths_>,
                                remove_cvref_t<StaticTileDistribution_>,
                                remove_cvref_t<LinearBottomDims_>>{
        tensor_view, window_lengths, origin, tile_distribution};
    w.init_raw();
    return w;
}

template <
    typename TileWindow_,
    typename StaticTileDistribution_,
    typename LinearBottomDims_ = default_linear_bottom_dims<typename TileWindow_::BottomTensorView>>
CK_TILE_DEVICE constexpr auto
make_tile_window_linear_raw(const TileWindow_& tile_window,
                            const StaticTileDistribution_& tile_distribution,
                            LinearBottomDims_ = {})
{
    return make_tile_window_linear_raw(tile_window.get_bottom_tensor_view(),
                                       tile_window.get_window_lengths(),
                                       tile_window.get_window_origin(),
                                       tile_distribution,
                                       LinearBottomDims_{});
}

template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename LinearBottomDims_>
CK_TILE_DEVICE void move_tile_window(
    tile_window_linear<TensorView_, WindowLengths_, StaticTileDistribution_, LinearBottomDims_>&
        window,
    const typename tile_window_linear<TensorView_,
                                      WindowLengths_,
                                      StaticTileDistribution_,
                                      LinearBottomDims_>::BottomTensorIndex& step)
{
    window.move(step);
}

/// @brief Type trait to determine if a type is a linear tile window.
template <typename T>
struct is_tile_window_linear : public std::false_type
{
};

/// @brief Specialization: tile_window_linear is a linear tile window.
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename LinearBottomDims_>
struct is_tile_window_linear<tile_window_linear<BottomTensorView_,
                                                WindowLengths_,
                                                StaticTileDistribution_,
                                                LinearBottomDims_>> : std::true_type
{
};

/// @brief Helper variable template to check if a type is a linear tile window.
template <typename T>
inline constexpr bool is_tile_window_linear_v = is_tile_window_linear<T>::value;

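// Compile-time usage sketch (illustrative assumption, not from this file): the trait
// lets generic code branch on the window flavor, e.g.
//
//   template <typename Window>
//   CK_TILE_DEVICE void consume(const Window& w)
//   {
//       if constexpr(is_tile_window_linear_v<Window>)
//       {
//           // linear windows can fold per-access offsets into immediates
//       }
//   }
//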
} // namespace ck_tile