/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_window.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_window.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_window.hpp Source File
tile_window.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
19 
20 namespace ck_tile {
21 
// Tile window whose tile distribution is known at compile time.
// NOTE(review): this is a doxygen text dump; source lines are elided wherever the
// embedded line numbers jump (37-38 and 47 here). The missing lines presumably
// carry the `struct tile_window_with_static_distribution : public tile_window_base<...>`
// declaration and the `using Base = tile_window_base<...>` introducer -- confirm
// against the real header.
33 template <typename BottomTensorView_,
34  typename WindowLengths_,
35  typename StaticTileDistribution_,
36  index_t NumCoord>
39  tile_window_with_static_distribution<BottomTensorView_,
40  WindowLengths_,
41  StaticTileDistribution_,
42  NumCoord>,
43  BottomTensorView_,
44  WindowLengths_,
45  StaticTileDistribution_>
46 {
48  tile_window_with_static_distribution<BottomTensorView_,
49  WindowLengths_,
50  StaticTileDistribution_,
51  NumCoord>,
52  BottomTensorView_,
53  WindowLengths_,
54  StaticTileDistribution_>;
55 
// Convenience compile-time indices used to pick the (adaptor, bottom) coordinate
// out of each pre-computed bundle.
56  static constexpr auto I0 = number<0>{};
57  static constexpr auto I1 = number<1>{};
// Only a single pre-computed coordinate bundle is currently supported.
58  static_assert(NumCoord == 1);
59 
// Accesses along the space-filling curve are split evenly across the bundles.
60  static_assert(Base::Traits::NumAccess % NumCoord == 0,
61  "wrong! # of access is not divisible by NumCoord");
62  static constexpr index_t NumAccessPerCoord = Base::Traits::NumAccess / NumCoord;
63 
65 
// Constructor: stores the window state and pre-computes NumCoord
// (WindowAdaptorCoord, BottomTensorCoord) bundles so later load/store calls can
// start from a ready-made per-thread coordinate.
// NOTE(review): doxygen dump -- the constructor name line (embedded line 66), the
// adaptor-coordinate arguments (79-81) and the coordinate-move call head (105)
// are elided; confirm against the real header.
67  const typename Base::BottomTensorView& bottom_tensor_view,
68  const typename Base::WindowLengths& window_lengths,
69  const typename Base::BottomTensorIndex& window_origin,
70  const typename Base::TileDstr& tile_distribution)
72  {
73 
// Copy window state into the base-class members.
74  this->window_origin_ = window_origin;
75  this->window_lengths_ = window_lengths;
76  this->bottom_tensor_view_ = bottom_tensor_view;
77  this->tile_dstr_ = tile_distribution;
// Per-thread coordinate in the window adaptor (ps-ys -> xs) space.
78  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
82 
// Translate the adaptor's bottom index into an index of the bottom tensor.
83  typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
84  window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
85 
86  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
87  bottom_tensor_view.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
88 
89  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
90  // future load/store() calls (might allocate more registers)
91  using Traits = typename Base::Traits;
92  using SFC_Ys = typename Traits::SFC_Ys;
93 
94  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
95  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
96  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
97 
// Step from access 0 to this bundle's first access along the space-filling curve.
98  constexpr auto idx_diff_ys =
99  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
100 
// P-dimension components of the step are zero; only Y components move.
101  constexpr auto idx_diff_ps_ys = container_concat(
102  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
103  idx_diff_ys);
104 
106  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
107 
108  pre_computed_coords_(iCoord) =
109  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
110  });
111  }
112 
// Convenience load: allocates a static-distributed tensor for this window's
// distribution, fills it via load(dst, ...) and returns it by value.
// NOTE(review): doxygen dump -- the signature line(s) (embedded 114-115) are
// elided; presumably `CK_TILE_DEVICE auto load(...) const`.
113  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
116  {
117  constexpr auto tile_dstr = typename Base::TileDstr{};
118  auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
119  load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
120  return dst_tensor;
121  }
122 
// Loads the window's data from the bottom tensor into dst_tensor's per-thread
// buffer, one vector_t at a time, walking the Y-space along the space-filling
// curve SFC_Ys and starting from the pre-computed coordinate bundles.
// NOTE(review): doxygen dump -- the defaulted tag parameters (embedded 127-128)
// and the coordinate-move call heads (138, 178) are elided.
123  template <typename DistributedTensor,
124  index_t i_access_unsupport_ = -1,
125  bool oob_conditional_check = true>
126  CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
129  {
130  using Traits = typename Base::Traits;
131  using vector_t = typename Traits::vector_t;
132  using SFC_Ys = typename Traits::SFC_Ys;
133 
134  constexpr auto tile_dstr = typename Base::TileDstr{};
135 
136  // loop over thread tensor space [y0, y1, ...]
137  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// Start each bundle from its pre-computed (adaptor, bottom) coordinates.
139  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
140  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
141 
142  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
143  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
144 
145  // data index [y0, y1, ...]
146  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
147 
148  // read from bottom tensor
149  const vector_t vec_value =
150  this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
151  bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
152  // write into distributed tensor
// Scatter each (possibly packed) scalar of the vector to its linearized
// D-offset in the thread buffer; only the VectorDimY index varies with j.
153  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
154  constexpr auto idx_ys = generate_tuple(
155  [&](auto jj) {
156  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
157  : idx_ys_start[jj];
158  },
159  number<Base::NDimY>{});
160 
161  constexpr index_t d =
162  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
163  Traits::PackedSize;
164 
165  dst_tensor.get_thread_buffer().template at<d>() =
166  vec_value
167  .template get_as<typename Base::DataType>()[j / Traits::PackedSize];
168  });
169  // move thread coordinate
// Advance to the next access of the curve, except after the last one.
170  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
171  {
172  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
173 
174  constexpr auto idx_diff_ps_ys = container_concat(
175  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
176  idx_diff_ys);
177 
179  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
180  }
181  });
182  });
183  }
184 
// Raw variant of load(): writes whole vector_t values directly into the
// destination thread buffer (reinterpreted as an array of vectors), bypassing
// the per-scalar scatter of load(). pre_nop is forwarded only for the very
// first access of the first bundle.
// NOTE(review): doxygen dump -- defaulted tag parameters (embedded 190-191) and
// the coordinate-move call head (249) are elided.
185  template <typename DstTile,
186  index_t i_access_unsupport_ = -1,
187  bool oob_conditional_check = true,
188  bool pre_nop = false>
189  CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
192  bool_constant<pre_nop> = {}) const
193  {
194  using Traits = typename Base::Traits;
195  using vector_t = typename Traits::vector_t;
196  using SFC_Ys = typename Traits::SFC_Ys;
// The thread buffer must hold an integral number of vectors.
197  static constexpr index_t YElementSize =
198  typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
199  static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0);
200  using vectorized_tbuf =
201  array<vector_t, YElementSize / (Traits::PackedSize * Traits::ScalarPerVector)>;
202 
203  constexpr auto tile_dstr = typename Base::TileDstr{};
204 
// View the destination thread buffer as an array of vector_t.
205  auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
206 
207  // loop over thread tensor space [y0, y1, ...]
208  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
210  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
211  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
212 
213  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
214  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// Only the very first access of the whole window honors pre_nop.
215  constexpr auto pre_nop_ = [&]() {
216  if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
217  return bool_constant<true>{};
218  else
219  return bool_constant<false>{};
220  }();
221 
222  // data index [y0, y1, ...]
223  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
224  constexpr index_t d =
225  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
226  Traits::PackedSize;
227  static_assert(d % Traits::ScalarPerVector == 0);
228 
// Read one vector straight into its slot of the vectorized buffer.
229  this->get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
230  dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
231  bottom_tensor_thread_coord,
232  0 ,
233  bool_constant<oob_conditional_check>{},
234  pre_nop_);
235 #if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
236  CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
// Empty asm acts as a compiler scheduling barrier for the scratch-memory
// workaround; see the flag definitions for details.
237  asm volatile(
238  ""); // this is starting from rocm-6.2, but same symptom, reuse this flag
239 #endif
240  // move thread coordinate
241  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
242  {
243  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
244 
245  constexpr auto idx_diff_ps_ys = container_concat(
246  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
247  idx_diff_ys);
248 
250  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
251  }
252  });
253  });
254  }
255 
// Asynchronously copies the window's data from the bottom tensor into an LDS
// tile window, using raw (inline-asm based) async accesses and the hardware m0
// register to step the LDS destination between issues.
// NOTE(review): doxygen dump -- defaulted tag parameters (embedded 262-263), the
// m0 set call head (294) and the coordinate-move call head (332) are elided.
256  // TODO: currently async load only implemented in inline asm
257  template <typename LdsTileWindow_,
258  index_t i_access_unsupport_ = -1,
259  bool oob_conditional_check = true,
260  bool pre_nop = false>
261  CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
264  bool_constant<pre_nop> = {}) const
265  {
266  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
267  // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
268  using LdsDataType = typename LdsTileWindow::DataType;
269  // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
270 
271  // issues * warps * lanes
272  static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
273 
// Byte offset of the buffer start inside the LDS descriptor.
274  const index_t size_per_buf =
275  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
276  make_tuple(number<0>{}, number<0>{}, number<0>{})) *
277  sizeof(LdsDataType);
278 
// Byte stride between consecutive waves (dimension 1).
279  const index_t size_per_wave =
280  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
281  make_tuple(number<0>{}, number<1>{}, number<0>{})) *
282  sizeof(LdsDataType) -
283  size_per_buf;
284 
// Byte stride between consecutive issues (dimension 0).
285  const index_t size_per_issue =
286  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
287  make_tuple(number<1>{}, number<0>{}, number<0>{})) *
288  sizeof(LdsDataType) -
289  size_per_buf;
290 
291  // Use VALU so the compiler can optimize redundant/repeated computations
292  const index_t m0_init_value =
293  size_per_buf + size_per_wave * get_warp_id(/*ReturnSgpr=*/bool_constant<false>{});
295  __builtin_amdgcn_readfirstlane(m0_init_value)); // This should be wave independent
296 
297  using Traits = typename Base::Traits;
298 
299  using vector_t = typename Traits::vector_t;
300  using SFC_Ys = typename Traits::SFC_Ys;
301 
// Raw LDS base pointer of the destination tile.
302  LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
303 
304  // loop over thread tensor space [y0, y1, ...]
305  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
307  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
308  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
309 
310  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
311  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// Only the first access of the whole window honors pre_nop.
312  constexpr auto pre_nop_ = [&]() {
313  if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
314  return bool_constant<true>{};
315  else
316  return bool_constant<false>{};
317  }();
318 
319  // read from bottom tensor
320  this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
321  smem, bottom_tensor_thread_coord, 0, pre_nop_);
322 
323  // move thread coordinate
324  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
325  {
326  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
327 
328  constexpr auto idx_diff_ps_ys = container_concat(
329  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
330  idx_diff_ys);
331 
333  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
334 
// Bump m0 so the next async issue lands at the next LDS slot.
335  m0_inc_with_memory(size_per_issue);
336  }
337  });
338  });
339  }
340 
// Non-raw async copy into an LDS tile window: computes the LDS destination
// address per access from the LDS window's own descriptor (instead of using m0
// like async_load_raw) and issues async vectorized accesses.
// NOTE(review): doxygen dump -- defaulted tag parameters (embedded 345-346) and
// the coordinate-move call head (394) are elided.
341  template <typename LdsTileWindow_,
342  index_t i_access_unsupport_ = -1,
343  bool oob_conditional_check = true>
344  CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
347  {
348  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
349  using LdsDataType = typename LdsTileWindow::DataType;
350  using Traits = typename Base::Traits;
351 
352  using vector_t = typename Traits::vector_t;
353  using SFC_Ys = typename Traits::SFC_Ys;
354 
355  // Precompute invariant values outside loops
356  const auto window_origin = lds_tile.get_window_origin();
357  const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
358  const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
359  auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
360 
361  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
362  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
363  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
364 
365  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
366  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
367 
368  // Use precomputed window origin
369  auto lds_bottom_tensor_thread_idx =
370  window_origin + window_adaptor_thread_coord.get_bottom_index();
371 
372  // Use precomputed tensor descriptor
373  const auto lds_coord =
374  make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
375 
376  // Calculate SMEM address using base pointer
377  CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
378 
// Async access between the bottom tensor and this LDS address (the "get"
// name suggests a read from the bottom tensor into LDS -- confirm against
// async_get_vectorized_elements in the tensor-view header).
379  // Write into bottom tensor
380  this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
381  smem,
382  bottom_tensor_thread_coord,
383  number<0>{},
384  bool_constant<oob_conditional_check>{});
385 
386  // Move thread coordinate if not last access
387  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
388  {
389  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
390  constexpr auto idx_diff_ps_ys = container_concat(
391  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
392  idx_diff_ys);
393 
395  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
396  }
397  });
398  });
399  }
400 
// Convenience transpose-load: allocates a static-distributed tensor, fills it
// via load_transpose<Policy>(dst, ...) and returns it by value.
// NOTE(review): doxygen dump -- the signature line (embedded 402) and the tail
// of the delegated call (407) are elided.
401  template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
403  {
404  constexpr auto tile_dstr = typename Base::TileDstr{};
405  auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
406  this->template load_transpose<Policy>(
408  return dst_tensor;
409  }
410 
// Transposing load: reads vectors with get_transpose_vectorized_elements and
// regroups each element's Y-index through Policy::group_func before computing
// its destination offset in the thread buffer. Note: unlike load(), the inner
// scalar loop strides by 1 and indexing is not divided by PackedSize.
// NOTE(review): doxygen dump -- defaulted tag parameters (embedded 416-417) and
// the coordinate-move call head (470) are elided.
411  template <typename Policy,
412  typename DistributedTensor,
413  index_t i_access_unsupport_ = -1,
414  bool oob_conditional_check = true>
415  CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor,
418  {
419  using Traits = typename Base::Traits;
420  using vector_t = typename Traits::vector_t;
421  using SFC_Ys = typename Traits::SFC_Ys;
422 
423  constexpr auto tile_dstr = typename Base::TileDstr{};
424 
// Policy supplies the compile-time Y-index regrouping function.
425  constexpr auto group_func = Policy::group_func;
426 
427  // loop over thread tensor space [y0, y1, ...]
428  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
430  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
431  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
432 
433  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
434  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
435 
436  // data index [y0, y1, ...]
437  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
438 
439  // read from bottom tensor
440  const vector_t vec_value =
441  this->get_bottom_tensor_view()
442  .template get_transpose_vectorized_elements<vector_t>(
443  bottom_tensor_thread_coord, 0);
444  // write into distributed tensor
445  static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
446  constexpr auto orig_idx_ys = generate_tuple(
447  [&](auto jj) {
448  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
449  : idx_ys_start[jj];
450  },
451  number<Base::NDimY>{});
452 
// Regroup the Y indices before linearizing to a D offset.
453  constexpr auto grouped_idx_ys = group_func(orig_idx_ys);
454 
455  constexpr index_t linear_distributed_index =
456  tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys);
457 
458  dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
459  vec_value.template get_as<typename Base::DataType>()[j];
460  });
461  // move thread coordinate
462  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
463  {
464  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
465 
466  constexpr auto idx_diff_ps_ys = container_concat(
467  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
468  idx_diff_ys);
469 
471  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
472  }
473  });
474  });
475  }
476 
// Stores a static-distributed tensor into the bottom tensor: gathers each
// vector_t from the thread buffer, then writes it with set_vectorized_elements.
// Mirror image of load().
// NOTE(review): doxygen dump -- the `store(...)` signature head (embedded 478),
// defaulted tag parameters (480-481) and the coordinate-move call head (539)
// are elided.
477  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
479  typename Base::TileDstr>& dstr_tensor,
482  {
483  using Traits = typename Base::Traits;
484 
485  using vector_t = typename Traits::vector_t;
486  using SFC_Ys = typename Traits::SFC_Ys;
487 
488  constexpr auto tile_dstr = typename Base::TileDstr{};
489 
490  // loop over thread tensor space [y0, y1, ...]
491  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
492  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
493  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
494 
495  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
496  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
497 
498  // data index [y0, y1, ...]
499  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
500 
501  // read from distributed tensor
502  // vector_type_t vec;
503  vector_t vec_value;
504 
// Gather the vector's scalars from their linearized D offsets.
505  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
506  constexpr auto idx_ys = generate_tuple(
507  [&](auto jj) {
508  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
509  : idx_ys_start[jj];
510  },
511  number<Base::NDimY>{});
512 
513  constexpr index_t d =
514  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
515  Traits::PackedSize;
516 
517  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
518  dstr_tensor.get_thread_buffer().template at<d>();
519  });
520 
521  // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
522 
523  // write into bottom tensor
524  this->get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
525  bottom_tensor_thread_coord,
526  0,
527  vec_value,
528  bool_constant<oob_conditional_check>{});
529 
530  // move thread coordinate
531  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
532  {
533  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
534 
535  constexpr auto idx_diff_ps_ys = container_concat(
536  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
537  idx_diff_ys);
538 
540  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
541  }
542  });
543  });
544  }
545 
// Raw store: same gather as store(), but writes with
// set_vectorized_elements_raw. OOB checking is hard-wired on here.
// NOTE(review): doxygen dump -- the `store_raw(...)` signature head (embedded
// 548) and the coordinate-move call head (602) are elided.
546  template <index_t i_access_unsupport_ = -1>
547  CK_TILE_DEVICE void
549  dstr_tensor,
550  number<i_access_unsupport_> = {}) const
551  {
552  using Traits = typename Base::Traits;
553 
554  using vector_t = typename Traits::vector_t;
555  using SFC_Ys = typename Traits::SFC_Ys;
556 
557  constexpr auto tile_dstr = typename Base::TileDstr{};
// Raw path always checks out-of-bounds accesses.
558  static constexpr bool oob_conditional_check = true;
559 
560  // loop over thread tensor space [y0, y1, ...]
561  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
563  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
564  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
565 
566  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
567  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
568 
569  // data index [y0, y1, ...]
570  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
571 
572  // read from distributed tensor
573  vector_t vec_value;
574  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
575  constexpr auto idx_ys = generate_tuple(
576  [&](auto jj) {
577  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
578  : idx_ys_start[jj];
579  },
580  number<Base::NDimY>{});
581  constexpr index_t d =
582  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
583  Traits::PackedSize;
584  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
585  dstr_tensor.get_thread_buffer().template at<d>();
586  });
587 
588  // write into bottom tensor
589  this->get_bottom_tensor_view()
590  .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
591  bottom_tensor_thread_coord, 0, vec_value);
592 
593  // move thread coordinate
594  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
595  {
596  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
597 
598  constexpr auto idx_diff_ps_ys = container_concat(
599  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
600  idx_diff_ys);
601 
603  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
604  }
605  });
606  });
607  }
608 
// Read-modify-write store: gathers vectors from the distributed tensor and
// applies them to the bottom tensor via update_vectorized_elements (the exact
// combine op lives in the tensor view, not here).
// NOTE(review): doxygen dump -- the `update(...)` signature head (embedded 611),
// defaulted tag parameters (613-614) and the coordinate-move call head (670)
// are elided.
609  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
610  CK_TILE_DEVICE void
612  dstr_tensor,
615  {
616  using Traits = typename Base::Traits;
617 
618  using vector_t = typename Traits::vector_t;
619  using SFC_Ys = typename Traits::SFC_Ys;
620 
621  constexpr auto tile_dstr = typename Base::TileDstr{};
622 
623  // loop over thread tensor space [y0, y1, ...]
624  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
626  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
627  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
628 
629  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
630  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
631 
632  // data index [y0, y1, ...]
633  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
634 
635  // read from distributed tensor
636  vector_t vec_value;
637 
638  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
639  constexpr auto idx_ys = generate_tuple(
640  [&](auto jj) {
641  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
642  : idx_ys_start[jj];
643  },
644  number<Base::NDimY>{});
645 
646  constexpr index_t d =
647  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
648  Traits::PackedSize;
649 
650  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
651  dstr_tensor.get_thread_buffer().template at<d>();
652  });
653 
654  // write into bottom tensor
655  this->get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
656  bottom_tensor_thread_coord,
657  0,
658  vec_value,
659  bool_constant<oob_conditional_check>{});
660 
661  // move thread coordinate
662  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
663  {
664  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
665 
666  constexpr auto idx_diff_ps_ys = container_concat(
667  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
668  idx_diff_ys);
669 
671  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
672  }
673  });
674  });
675  }
676 
// Raw read-modify-write store: like update(), but uses
// update_vectorized_elements_raw and forwards a pre_nop flag. Note pre_nop has
// no default here, unlike the other *_raw members.
// NOTE(review): doxygen dump -- the `update_raw(...)` signature head (embedded
// 679), defaulted tag parameters (681-682) and the coordinate-move call head
// (740) are elided.
677  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
678  CK_TILE_DEVICE void
680  dstr_tensor,
683  bool_constant<pre_nop> = {}) const
684  {
685  using Traits = typename Base::Traits;
686 
687  using vector_t = typename Traits::vector_t;
688  using SFC_Ys = typename Traits::SFC_Ys;
689 
690  constexpr auto tile_dstr = typename Base::TileDstr{};
691 
692  // loop over thread tensor space [y0, y1, ...]
693  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
695  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
696  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
697 
698  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
699  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
700 
701  // data index [y0, y1, ...]
702  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
703 
704  // read from distributed tensor
705  vector_t vec_value;
706 
707  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
708  constexpr auto idx_ys = generate_tuple(
709  [&](auto jj) {
710  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
711  : idx_ys_start[jj];
712  },
713  number<Base::NDimY>{});
714 
715  constexpr index_t d =
716  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
717  Traits::PackedSize;
718 
719  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
720  dstr_tensor.get_thread_buffer().template at<d>();
721  });
722 
723  // write into bottom tensor
724  this->get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
725  bottom_tensor_thread_coord,
726  0,
727  vec_value,
728  bool_constant<oob_conditional_check>{},
729  bool_constant<pre_nop>{});
730 
731  // move thread coordinate
732  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
733  {
734  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
735 
736  constexpr auto idx_diff_ps_ys = container_concat(
737  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
738  idx_diff_ys);
739 
741  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
742  }
743  });
744  });
745  }
746 
// Moves the window by `step`: advances only the pre-computed bottom-tensor
// coordinate of each bundle (the adaptor coordinate is origin-independent).
// NOTE(review): doxygen dump -- the `move(...)` signature line (embedded 748)
// is elided; `step` is presumably a BottomTensorIndex.
747  // Custom move behavior
749  {
750  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
751  move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
752  pre_computed_coords_(iCoord)(I1),
753  step);
754  });
755  }
756 
// Rebuilds all pre-computed coordinate bundles from scratch using the current
// window origin (same logic as the constructor's tail).
// NOTE(review): doxygen dump -- the member's signature line (embedded 756-757)
// is elided, so its name/parameters are unknown here (likely a set-window-origin
// style member); also missing: adaptor-coordinate arguments (763-764) and the
// coordinate-move call head (788). Confirm against the real header.
758  {
759  // TODO: this use less register for FA, but more register for GEMM
760  // need investigation
761  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
762  this->tile_dstr_.get_ps_ys_to_xs_adaptor(),
765 
766  typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
767  this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
768 
769  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
770  this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
771 
772  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
773  // future load/store() calls (might allocate more registers)
774  using Traits = typename Base::Traits;
775  using SFC_Ys = typename Traits::SFC_Ys;
776 
777  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
778  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
779  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
780 
// Step from access 0 to this bundle's first access on the space-filling curve.
781  constexpr auto idx_diff_ys =
782  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
783 
784  constexpr auto idx_diff_ps_ys = container_concat(
785  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
786  idx_diff_ys);
787 
789  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
790 
791  pre_computed_coords_(iCoord) =
792  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
793  });
794  }
795 
796  // this contains:
797  // per-thread coordinate for window adaptor
798  // per-thread coordinate for bottom tensor
801 };
802 
// Factory: builds a tile window with a static tile distribution over
// `tensor_view`, anchored at `origin`. NumCoord selects how many pre-computed
// coordinate bundles the window keeps (currently must be 1, per the class's
// static_assert).
803 // TODO: use strategy
804 template <typename TensorView_,
805  typename WindowLengths_,
806  typename StaticTileDistribution_,
807  index_t NumCoord = 1>
808 CK_TILE_DEVICE constexpr auto
809 make_tile_window(const TensorView_& tensor_view,
810  const WindowLengths_& window_lengths,
811  const multi_index<TensorView_::get_num_of_dimension()>& origin,
812  const StaticTileDistribution_& tile_distribution,
813  number<NumCoord> = {})
814 {
815  return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
816  remove_cvref_t<WindowLengths_>,
817  remove_cvref_t<StaticTileDistribution_>,
818  NumCoord>{
819  tensor_view, window_lengths, origin, tile_distribution};
820 }
821 
// Raw-path factory: same as make_tile_window but additionally calls init_raw()
// on the window, which is why it cannot be constexpr.
// NOTE(review): doxygen dump -- the `make_tile_window_raw(const TensorView_&
// tensor_view,` line (embedded 828) is elided.
822 // this version can't be called in a constexpr context
823 template <typename TensorView_,
824  typename WindowLengths_,
825  typename StaticTileDistribution_,
826  index_t NumCoord = 1>
827 CK_TILE_DEVICE auto
829  const WindowLengths_& window_lengths,
830  const multi_index<TensorView_::get_num_of_dimension()>& origin,
831  const StaticTileDistribution_& tile_distribution,
832  number<NumCoord> = {})
833 {
834  auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
835  remove_cvref_t<WindowLengths_>,
836  remove_cvref_t<StaticTileDistribution_>,
837  NumCoord>{
838  tensor_view, window_lengths, origin, tile_distribution};
839  w.init_raw();
840  return w;
841 }
842 
// Free-function wrapper: shifts a static-distribution tile window by `step`
// (delegates to the member move()).
// NOTE(review): doxygen dump -- the `move_tile_window(tile_window_with_...`
// signature head (embedded 847-848) is elided.
843 template <typename TensorView_,
844  typename WindowLengths_,
845  typename StaticTileDistribution_,
846  index_t NumCoord>
849  WindowLengths_,
850  StaticTileDistribution_,
851  NumCoord>& window,
852  const typename tile_window_with_static_distribution<TensorView_,
853  WindowLengths_,
854  StaticTileDistribution_,
855  NumCoord>::BottomTensorIndex& step)
856 {
857  window.move(step);
858 }
859 
// Tile window that only fixes the window lengths at compile time (no tile
// distribution attached); a distribution can be added later via the
// make_tile_window overloads below.
// NOTE(review): doxygen dump -- the `struct tile_window_with_static_lengths`
// line (embedded 869), part of the `using Base` alias (875), and the
// constructor head (879/881) are elided.
868 template <typename BottomTensorView_, typename WindowLengths_>
870  : public tile_window_base<tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>,
871  BottomTensorView_,
872  WindowLengths_>
873 {
874  using Base =
876  BottomTensorView_,
877  WindowLengths_>;
878 
880 
882  const typename Base::BottomTensorView& bottom_tensor_view,
883  const typename Base::WindowLengths& window_lengths,
884  const typename Base::BottomTensorIndex& window_origin)
885  {
// Constructor body: store window state into the base-class members.
886  this->window_origin_ = window_origin;
887  this->window_lengths_ = window_lengths;
888  this->bottom_tensor_view_ = bottom_tensor_view;
889  }
890 };
891 
// Factory for a distribution-less window with static lengths.
// NOTE(review): doxygen dump -- the static_assert head (embedded 898) and the
// `return tile_window_with_static_lengths<...>{` lines (901-902) are elided.
892 template <typename TensorView_, typename WindowLengths_>
893 CK_TILE_DEVICE constexpr auto
894 make_tile_window(const TensorView_& tensor_view,
895  const WindowLengths_& window_lengths,
896  const multi_index<TensorView_::get_num_of_dimension()>& origin)
897 {
899  "wrong! lengths should be static");
900 
903  tensor_view, window_lengths, origin};
904 }
905 
// Clones an existing static-lengths window, replacing only its origin.
// NOTE(review): doxygen dump -- the signature head (embedded 909) and the
// `return tile_window_with_static_lengths<...>{` line (912) are elided.
906 // duplicate tile window and replace its origin
907 template <typename TensorView, typename WindowLengths>
908 CK_TILE_DEVICE constexpr auto
910  const multi_index<TensorView::get_num_of_dimension()>& origin)
911 {
913  tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), origin};
914 }
915 
// Attaches a static tile distribution to an existing static-lengths window,
// also replacing its origin; forwards to the distribution-based factory.
// NOTE(review): doxygen dump -- the signature head (embedded 918) and the last
// argument line(s) (925, presumably `tile_distribution);`) are elided.
916 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
917 CK_TILE_DEVICE constexpr auto
919  const multi_index<TensorView::get_num_of_dimension()>& origin,
920  const StaticTileDistribution& tile_distribution)
921 {
922  return make_tile_window(tile_window.get_bottom_tensor_view(),
923  tile_window.get_window_lengths(),
924  origin,
926 }
927 
// Attaches a static tile distribution to an existing static-lengths window,
// keeping the window's current origin.
// NOTE(review): doxygen dump -- the signature head (embedded 930) and the last
// argument line (936, presumably `tile_distribution);`) are elided.
928 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
929 CK_TILE_DEVICE constexpr auto
931  const StaticTileDistribution& tile_distribution)
932 {
933  return make_tile_window(tile_window.get_bottom_tensor_view(),
934  tile_window.get_window_lengths(),
935  tile_window.get_window_origin(),
937 }
938 
// Raw-path variant of the distribution-attaching factory above: builds the
// distributed window, then calls init_raw() before returning it.
// NOTE(review): doxygen dump -- the signature head (embedded 941) and the last
// argument line (947, presumably `tile_distribution);`) are elided.
939 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
940 CK_TILE_DEVICE constexpr auto
942  const StaticTileDistribution& tile_distribution)
943 {
944  auto w = make_tile_window(tile_window.get_bottom_tensor_view(),
945  tile_window.get_window_lengths(),
946  tile_window.get_window_origin(),
948  w.init_raw();
949  return w;
950 }
951 
// Shifts a static-lengths tile window by `step` (delegates to the member
// move()).
// NOTE(review): doxygen dump -- the signature head (embedded 953-955) is
// elided; `step` is presumably the window's BottomTensorIndex.
952 template <typename TensorView_, typename WindowLengths_>
956  step)
957 {
958  window.move(step);
959 }
960 
// Type-trait machinery for the two window kinds. Per the doxygen footer,
// `is_tile_window_with_static_distribution` / `is_tile_window_with_static_lengths`
// default to false and specialize to true for the matching window types, with
// `_v` helper variable templates.
// NOTE(review): doxygen dump -- the primary-template struct lines (embedded
// 969, 1012), the specialization heads (985, 1023) and the `_v` variable
// definitions (1001-1002, 1036-1037) are elided; confirm against the real header.
968 template <typename T>
970 {
971 };
972 
981 template <typename BottomTensorView_,
982  typename WindowLengths_,
983  typename StaticTileDistribution_,
984  index_t NumCoord>
986  tile_window_with_static_distribution<BottomTensorView_,
987  WindowLengths_,
988  StaticTileDistribution_,
989  NumCoord>> : std::true_type
990 {
991 };
992 
1000 template <typename T>
1003 
1011 template <typename T>
1013 {
1014 };
1015 
1022 template <typename BottomTensorView_, typename WindowLengths_>
1024  tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>> : std::true_type
1025 {
1026 };
1027 
1035 template <typename T>
1038 
1039 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr bool is_tile_window_with_static_distribution_v
Helper variable template to check if a type is a tile window with static distribution.
Definition: tile_window.hpp:1001
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
CK_TILE_DEVICE auto make_tile_window_raw(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, number< NumCoord >={})
Definition: tile_window.hpp:828
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constant< b > bool_constant
Definition: integral_constant.hpp:43
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr bool is_tile_window_with_static_lengths_v
Helper variable template to check if a type is a tile window with static lengths.
Definition: tile_window.hpp:1036
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
bool_constant< false > false_type
Definition: integral_constant.hpp:63
bool_constant< true > true_type
Definition: integral_constant.hpp:62
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Type trait to determine if a type is a tile window with static distribution.
Definition: tile_window.hpp:970
Type trait to determine if a type is a tile window with static lengths.
Definition: tile_window.hpp:1013
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
Definition: tile_distribution.hpp:72
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
This class provides description of tile windowed view on the device memory.
Definition: tile_window_base.hpp:31
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_window_base.hpp:67
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
This class provides tile (windowed) view and access to the device memory.
Definition: tile_window.hpp:46
CK_TILE_DEVICE void store_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}) const
Definition: tile_window.hpp:548
CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex &step)
Definition: tile_window.hpp:748
CK_TILE_DEVICE auto load_transpose() const
Definition: tile_window.hpp:402
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex &)
Definition: tile_window.hpp:757
array< tuple< typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_window.hpp:800
CK_TILE_DEVICE void update_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:679
constexpr CK_TILE_DEVICE tile_window_with_static_distribution()=default
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:344
CK_TILE_DEVICE auto load_transpose(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:415
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:126
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:114
static constexpr auto I0
Definition: tile_window.hpp:56
CK_TILE_DEVICE void load_raw(DstTile &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:189
static constexpr auto I1
Definition: tile_window.hpp:57
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:261
CK_TILE_DEVICE void update(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:611
constexpr CK_TILE_DEVICE tile_window_with_static_distribution(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin, const typename Base::TileDstr &tile_distribution)
Definition: tile_window.hpp:66
static constexpr index_t NumAccessPerCoord
Definition: tile_window.hpp:62
CK_TILE_DEVICE void store(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:478
This class provides description of tile windowed view on the device memory.
Definition: tile_window.hpp:873
constexpr CK_TILE_DEVICE tile_window_with_static_lengths()=default
constexpr CK_TILE_DEVICE tile_window_with_static_lengths(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin)
Definition: tile_window.hpp:881
Definition: tile_window_base.hpp:94
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_window_base.hpp:129