/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_window.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_window.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_window.hpp Source File
tile_window.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
19 
20 namespace ck_tile {
21 
// Tile window with a static (compile-time) per-thread tile distribution.
// NOTE(review): doxygen rendering — original lines 37-38 (the
// "struct tile_window_with_static_distribution : public <CRTP base>" opener)
// and line 47 ("using Base = ...") are missing from this dump; the two
// template-argument lists below belong to those declarations.
33 template <typename BottomTensorView_,
34  typename WindowLengths_,
35  typename StaticTileDistribution_,
36  index_t NumCoord>
39  tile_window_with_static_distribution<BottomTensorView_,
40  WindowLengths_,
41  StaticTileDistribution_,
42  NumCoord>,
43  BottomTensorView_,
44  WindowLengths_,
45  StaticTileDistribution_>
46 {
48  tile_window_with_static_distribution<BottomTensorView_,
49  WindowLengths_,
50  StaticTileDistribution_,
51  NumCoord>,
52  BottomTensorView_,
53  WindowLengths_,
54  StaticTileDistribution_>;
55 
// Convenience tuple indices used to pick the (adaptor, bottom) coordinate
// out of each pre-computed bundle.
56  static constexpr auto I0 = number<0>{};
57  static constexpr auto I1 = number<1>{};
// Only a single coordinate bundle is currently supported.
58  static_assert(NumCoord == 1);
59 
// Each bundle must own an integral number of accesses.
60  static_assert(Base::Traits::NumAccess % NumCoord == 0,
61  "wrong! # of access is not divisible by NumCoord");
62  static constexpr index_t NumAccessPerCoord = Base::Traits::NumAccess / NumCoord;
 65 
// Constructor: stores the window description in the base-class members and
// pre-computes NumCoord (window-adaptor coord, bottom-tensor coord) bundles.
// NOTE(review): doxygen rendering — the constructor's name line (original
// line 66) is missing; the parameter list below belongs to it. Original
// lines 79-81 (arguments of make_tensor_adaptor_coordinate) and 105 (the
// coordinate-move call) are also missing from this dump.
 67  const typename Base::BottomTensorView& bottom_tensor_view,
 68  const typename Base::WindowLengths& window_lengths,
 69  const typename Base::BottomTensorIndex& window_origin,
 70  const typename Base::TileDstr& tile_distribution)
 72  {
 73 
// Persist the window description on the base class.
 74  this->window_origin_ = window_origin;
 75  this->window_lengths_ = window_lengths;
 76  this->bottom_tensor_view_ = bottom_tensor_view;
 77  this->tile_dstr_ = tile_distribution;
// Thread's starting coordinate in the ps/ys -> xs adaptor space.
 78  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
 82 
// Thread's starting index in the bottom tensor = window origin + adaptor
// bottom index.
 83  typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
 84  window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
 85 
 86  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
 87  bottom_tensor_view.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
 88 
 89  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
 90  // future load/store() calls (might allocate more registers)
 91  using Traits = typename Base::Traits;
 92  using SFC_Ys = typename Traits::SFC_Ys;
 93 
 94  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
 95  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
 96  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
 97 
// Step from access 0 to this bundle's first access along the Y space-filling
// curve.
 98  constexpr auto idx_diff_ys =
 99  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
 100 
// Prepend zero P-dimension steps so the diff spans the full (P, Y) space.
 101  constexpr auto idx_diff_ps_ys = container_concat(
 102  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
 103  idx_diff_ys);
 104 
 106  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 107 
 108  pre_computed_coords_(iCoord) =
 109  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
 110  });
 111  }
112 
// Load the whole window into a freshly-created static distributed tensor and
// return it by value. Delegates to the load(dst_tensor, ...) overload.
// NOTE(review): doxygen rendering — the function's declaration line(s)
// (original lines 114-115) are missing from this dump.
 113  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
 116  {
 117  constexpr auto tile_dstr = typename Base::TileDstr{};
 118  auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
 119  load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
 120  return dst_tensor;
 121  }
122 
// Load from a tuple of tile windows, combining the per-window values with the
// given elementwise functor, and return the result in a new distributed
// tensor. Delegates to the load(dst_tensor, tile_window, elementwise, ...)
// overload.
// NOTE(review): doxygen rendering — original lines 139-140 (trailing default
// parameters of this signature) are missing from this dump.
 133  template <typename TileWindow_,
 134  typename ElementWise_,
 135  index_t i_access_unsupport_ = -1,
 136  bool oob_conditional_check = true>
 137  CK_TILE_DEVICE auto load(const TileWindow_& tile_window,
 138  ElementWise_ elementwise,
 141  {
 142  constexpr auto tile_dstr = typename Base::TileDstr{};
 143  auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
 144  load(dst_tensor,
 145  tile_window,
 146  elementwise,
 147  number<i_access_unsupport_>{},
 148  bool_constant<oob_conditional_check>{});
 149  return dst_tensor;
 150  }
151 
// Load from every window in the `tile_window` tuple, apply `elementwise` to
// combine the per-window vector lanes, and write the result into
// `dst_tensor`'s thread buffer. Iterates the per-thread Y space along the
// space-filling curve SFC_Ys.
// NOTE(review): doxygen rendering — original lines 160-161 (trailing default
// parameters), 172, 209 (the apply/unpack call that invokes the lambda over
// idx_vec_value), and 226 (presumably the
// move_window_adaptor_and_bottom_tensor_thread_coordinate call) are missing
// from this dump.
 152  template <typename DistributedTensor,
 153  typename TileWindow_,
 154  typename ElementWise_,
 155  index_t i_access_unsupport_ = -1,
 156  bool oob_conditional_check = true>
 157  CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
 158  const TileWindow_& tile_window,
 159  ElementWise_ elementwise,
 162  {
 163 
 164  using Traits = typename Base::Traits;
 165  using vector_t = typename Traits::vector_t;
 166  using SFC_Ys = typename Traits::SFC_Ys;
 167 
 168  constexpr auto tile_dstr = typename Base::TileDstr{};
 169  constexpr auto sizeOfTuple = TileWindow_::size();
 170  // loop over thread tensor space [y0, y1, ...]
 171  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// Start from the coordinates pre-computed by window 0's constructor.
 173  auto window_adaptor_thread_coord =
 174  tile_window[number<0>{}].pre_computed_coords_[iCoord][I0];
 175  auto bottom_tensor_thread_coord =
 176  tile_window[number<0>{}].pre_computed_coords_[iCoord][I1];
 177 
 178  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
 179  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
 180 
 181  // data index [y0, y1, ...]
 182  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
 183 
 184  // read from bottom tensor
// One vectorized read per window in the tuple, all at the same coordinate.
 185  const auto idx_vec_value = generate_tuple(
 186  [&](auto jj) {
 187  return tile_window[number<jj>{}]
 188  .get_bottom_tensor_view()
 189  .template get_vectorized_elements<vector_t>(
 190  bottom_tensor_thread_coord,
 191  0,
 192  bool_constant<oob_conditional_check>{});
 193  },
 194  number<sizeOfTuple>{});
 195 
 196  // write into distributed tensor
 197  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
// Advance only the vectorized Y dimension by the lane index j.
 198  constexpr auto idx_ys = generate_tuple(
 199  [&](auto jj) {
 200  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
 201  : idx_ys_start[jj];
 202  },
 203  number<Base::NDimY>{});
 204 
 205  constexpr index_t d =
 206  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
 207  Traits::PackedSize;
 208 
// Unpack the per-window vectors and feed lane j of each to `elementwise`,
// with the destination element first.
 210  [&](auto&&... t) {
 211  elementwise(dst_tensor.get_thread_buffer().template at<d>(),
 212  t.template get_as<
 213  typename Base::DataType>()[j / Traits::PackedSize]...);
 214  },
 215  idx_vec_value);
 216  });
 217  // move thread coordinate
 218  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
 219  {
 220  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
 221 
 222  constexpr auto idx_diff_ps_ys = container_concat(
 223  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
 224  idx_diff_ys);
 225 
 227  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 228  }
 229  });
 230  });
 231  }
232 
// Load the window contents into an existing distributed tensor: one
// vectorized read per access along the SFC, then scatter the vector lanes
// into the thread buffer at offsets computed from the ys->d descriptor.
// NOTE(review): doxygen rendering — original lines 237-238 (trailing default
// parameters), 248, and 288 (presumably the
// move_window_adaptor_and_bottom_tensor_thread_coordinate call) are missing
// from this dump.
 233  template <typename DistributedTensor,
 234  index_t i_access_unsupport_ = -1,
 235  bool oob_conditional_check = true>
 236  CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
 239  {
 240  using Traits = typename Base::Traits;
 241  using vector_t = typename Traits::vector_t;
 242  using SFC_Ys = typename Traits::SFC_Ys;
 243 
 244  constexpr auto tile_dstr = typename Base::TileDstr{};
 245 
 246  // loop over thread tensor space [y0, y1, ...]
 247  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// Start from the constructor's pre-computed coordinate bundle.
 249  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
 250  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
 251 
 252  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
 253  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
 254 
 255  // data index [y0, y1, ...]
 256  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
 257 
 258  // read from bottom tensor
 259  const vector_t vec_value =
 260  this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
 261  bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
 262  // write into distributed tensor
 263  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
// Advance only the vectorized Y dimension by lane index j.
 264  constexpr auto idx_ys = generate_tuple(
 265  [&](auto jj) {
 266  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
 267  : idx_ys_start[jj];
 268  },
 269  number<Base::NDimY>{});
 270 
 271  constexpr index_t d =
 272  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
 273  Traits::PackedSize;
 274 
 275  dst_tensor.get_thread_buffer().template at<d>() =
 276  vec_value
 277  .template get_as<typename Base::DataType>()[j / Traits::PackedSize];
 278  });
 279  // move thread coordinate
 280  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
 281  {
 282  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
 283 
 284  constexpr auto idx_diff_ps_ys = container_concat(
 285  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
 286  idx_diff_ys);
 287 
 289  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 290  }
 291  });
 292  });
 293  }
294 
// "Raw" load: reinterpret the destination thread buffer as an array of
// vector_t and issue whole-vector reads directly into it, bypassing the
// per-lane scatter done by load(). `pre_nop` (when set) is forwarded only for
// the very first access.
// NOTE(review): doxygen rendering — original lines 300-301 (trailing default
// parameters), 319, and 359 (presumably the coordinate-move call) are
// missing from this dump.
 295  template <typename DstTile,
 296  index_t i_access_unsupport_ = -1,
 297  bool oob_conditional_check = true,
 298  bool pre_nop = false>
 299  CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
 302  bool_constant<pre_nop> = {}) const
 303  {
 304  using Traits = typename Base::Traits;
 305  using vector_t = typename Traits::vector_t;
 306  using SFC_Ys = typename Traits::SFC_Ys;
 307  static constexpr index_t YElementSize =
 308  typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
 309  static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0);
 310  using vectorized_tbuf =
 311  array<vector_t, YElementSize / (Traits::PackedSize * Traits::ScalarPerVector)>;
 312 
 313  constexpr auto tile_dstr = typename Base::TileDstr{};
 314 
// Reinterpret the scalar thread buffer as whole vectors so each access can
// land with a single vectorized read.
 315  auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
 316 
 317  // loop over thread tensor space [y0, y1, ...]
 318  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
 320  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
 321  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
 322 
 323  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
 324  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// pre_nop is applied only to the first access of the first bundle.
 325  constexpr auto pre_nop_ = [&]() {
 326  if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
 327  return bool_constant<true>{};
 328  else
 329  return bool_constant<false>{};
 330  }();
 331 
 332  // data index [y0, y1, ...]
 333  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
 334  constexpr index_t d =
 335  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
 336  Traits::PackedSize;
// Each access must start on a vector boundary of the thread buffer.
 337  static_assert(d % Traits::ScalarPerVector == 0);
 338 
 339  this->get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
 340  dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
 341  bottom_tensor_thread_coord,
 342  0 ,
 343  bool_constant<oob_conditional_check>{},
 344  pre_nop_);
 345 #if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
 346  CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
// Empty asm acts as a compiler scheduling barrier for the ROCm scratch-memory
// workaround.
 347  asm volatile(
 348  ""); // this is starting from rocm-6.2, but same sympton, reuse this flag
 349 #endif
 350  // move thread coordinate
 351  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
 352  {
 353  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
 354 
 355  constexpr auto idx_diff_ps_ys = container_concat(
 356  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
 357  idx_diff_ys);
 358 
 360  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 361  }
 362  });
 363  });
 364  }
365 
 366  // TODO: currently async load only implemented in inline asm
// Asynchronous raw load into an LDS tile window: computes per-buffer /
// per-wave / per-issue byte strides from the LDS descriptor, programs the m0
// base offset once, then issues async vectorized reads, bumping m0 by
// size_per_issue between accesses.
// NOTE(review): doxygen rendering — original lines 372-373 (trailing default
// parameters), 404 (the call that consumes amd_wave_read_first_lane /
// presumably sets m0), 416, and 442 (the coordinate-move call) are missing
// from this dump.
 367  template <typename LdsTileWindow_,
 368  index_t i_access_unsupport_ = -1,
 369  bool oob_conditional_check = true,
 370  bool pre_nop = false>
 371  CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
 374  bool_constant<pre_nop> = {}) const
 375  {
 376  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
 377  // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
 378  using LdsDataType = typename LdsTileWindow::DataType;
 379  // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
 380 
 381  // issues * warps * lanes
 382  static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
 383 
// Byte offset of element (0,0,0) — base of the LDS buffer.
 384  const index_t size_per_buf =
 385  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
 386  make_tuple(number<0>{}, number<0>{}, number<0>{})) *
 387  sizeof(LdsDataType);
 388 
// Byte stride between consecutive waves (dim 1 of the 3-D LDS layout).
 389  const index_t size_per_wave =
 390  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
 391  make_tuple(number<0>{}, number<1>{}, number<0>{})) *
 392  sizeof(LdsDataType) -
 393  size_per_buf;
 394 
// Byte stride between consecutive issues (dim 0 of the 3-D LDS layout).
 395  const index_t size_per_issue =
 396  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
 397  make_tuple(number<1>{}, number<0>{}, number<0>{})) *
 398  sizeof(LdsDataType) -
 399  size_per_buf;
 400 
 401  // Use VALU so the compiler can optimize redundant/repeated computations
 402  const index_t m0_init_value =
 403  size_per_buf + size_per_wave * get_warp_id(/*ReturnSgpr=*/bool_constant<false>{});
 405  amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
 406 
 407  using Traits = typename Base::Traits;
 408 
 409  using vector_t = typename Traits::vector_t;
 410  using SFC_Ys = typename Traits::SFC_Ys;
 411 
 412  LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
 413 
 414  // loop over thread tensor space [y0, y1, ...]
 415  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
 417  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
 418  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
 419 
 420  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
 421  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// pre_nop applies only to the very first access of the first bundle.
 422  constexpr auto pre_nop_ = [&]() {
 423  if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
 424  return bool_constant<true>{};
 425  else
 426  return bool_constant<false>{};
 427  }();
 428 
 429  // read from bottom tensor
 430  this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
 431  smem, bottom_tensor_thread_coord, 0, pre_nop_);
 432 
 433  // move thread coordinate
 434  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
 435  {
 436  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
 437 
 438  constexpr auto idx_diff_ps_ys = container_concat(
 439  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
 440  idx_diff_ys);
 441 
 443  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 444 
// Advance the m0 LDS write offset for the next issue.
 445  m0_inc_with_memory(size_per_issue);
 446  }
 447  });
 448  });
 449  }
450 
// Asynchronous load into an LDS tile window (non-raw variant): for each
// access, recomputes the LDS destination address from the LDS window's own
// adaptor coordinate instead of using the m0 register.
// NOTE(review): doxygen rendering — original lines 455-456 (trailing default
// parameters) and 504 (presumably the coordinate-move call) are missing from
// this dump.
 451  template <typename LdsTileWindow_,
 452  index_t i_access_unsupport_ = -1,
 453  bool oob_conditional_check = true>
 454  CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
 457  {
 458  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
 459  using LdsDataType = typename LdsTileWindow::DataType;
 460  using Traits = typename Base::Traits;
 461 
 462  using vector_t = typename Traits::vector_t;
 463  using SFC_Ys = typename Traits::SFC_Ys;
 464 
 465  // Precompute invariant values outside loops
 466  const auto window_origin = lds_tile.get_window_origin();
 467  const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
 468  const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
 469  auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
 470 
 471  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
 472  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
 473  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
 474 
 475  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
 476  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
 477 
 478  // Use precomputed window origin
 479  auto lds_bottom_tensor_thread_idx =
 480  window_origin + window_adaptor_thread_coord.get_bottom_index();
 481 
 482  // Use precomputed tensor descriptor
 483  const auto lds_coord =
 484  make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
 485 
 486  // Calculate SMEM address using base pointer
 487  CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
 488 
 489  // Write into bottom tensor
 490  this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
 491  smem,
 492  bottom_tensor_thread_coord,
 493  number<0>{},
 494  bool_constant<oob_conditional_check>{});
 495 
 496  // Move thread coordinate if not last access
 497  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
 498  {
 499  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
 500  constexpr auto idx_diff_ps_ys = container_concat(
 501  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
 502  idx_diff_ys);
 503 
 505  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 506  }
 507  });
 508  });
 509  }
510 
// Transposing load that returns a new distributed tensor; delegates to
// load_transpose(dst_tensor, ...).
// NOTE(review): doxygen rendering — original lines 512 (the function's
// declaration line) and 517 (the remaining arguments of the delegated call)
// are missing from this dump.
 511  template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
 513  {
 514  constexpr auto tile_dstr = typename Base::TileDstr{};
 515  auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
 516  this->template load_transpose<Policy>(
 518  return dst_tensor;
 519  }
520 
// Transposing load into an existing distributed tensor: reads transpose
// vectors from the bottom tensor, then maps each lane's Y index through
// Policy::group_func before computing the thread-buffer offset.
// NOTE(review): doxygen rendering — original lines 526-527 (trailing default
// parameters), 539, and 580 (presumably the coordinate-move call) are
// missing from this dump.
 521  template <typename Policy,
 522  typename DistributedTensor,
 523  index_t i_access_unsupport_ = -1,
 524  bool oob_conditional_check = true>
 525  CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor,
 528  {
 529  using Traits = typename Base::Traits;
 530  using vector_t = typename Traits::vector_t;
 531  using SFC_Ys = typename Traits::SFC_Ys;
 532 
 533  constexpr auto tile_dstr = typename Base::TileDstr{};
 534 
// Policy-supplied compile-time remapping of Y indices (the "transpose").
 535  constexpr auto group_func = Policy::group_func;
 536 
 537  // loop over thread tensor space [y0, y1, ...]
 538  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
 540  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
 541  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
 542 
 543  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
 544  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
 545 
 546  // data index [y0, y1, ...]
 547  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
 548 
 549  // read from bottom tensor
 550  const vector_t vec_value =
 551  this->get_bottom_tensor_view()
 552  .template get_transpose_vectorized_elements<vector_t>(
 553  bottom_tensor_thread_coord, 0);
 554  // write into distributed tensor
 555  static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
 556  constexpr auto orig_idx_ys = generate_tuple(
 557  [&](auto jj) {
 558  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
 559  : idx_ys_start[jj];
 560  },
 561  number<Base::NDimY>{});
 562 
// Remap the Y index through the policy before the ys->d lookup.
 563  constexpr auto grouped_idx_ys = group_func(orig_idx_ys);
 564 
 565  constexpr index_t linear_distributed_index =
 566  tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys);
 567 
 568  dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
 569  vec_value.template get_as<typename Base::DataType>()[j];
 570  });
 571  // move thread coordinate
 572  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
 573  {
 574  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
 575 
 576  constexpr auto idx_diff_ps_ys = container_concat(
 577  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
 578  idx_diff_ys);
 579 
 581  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 582  }
 583  });
 584  });
 585  }
586 
// Store a distributed tensor into the window: gather vector lanes from the
// thread buffer, then write each vector to the bottom tensor with optional
// out-of-bounds checking.
// NOTE(review): doxygen rendering — original line 588 (the function's
// declaration line, of which only the second parameter line survives below)
// and 649 (presumably the coordinate-move call) are missing from this dump.
 587  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
 589  typename Base::TileDstr>& dstr_tensor,
 592  {
 593  using Traits = typename Base::Traits;
 594 
 595  using vector_t = typename Traits::vector_t;
 596  using SFC_Ys = typename Traits::SFC_Ys;
 597 
 598  constexpr auto tile_dstr = typename Base::TileDstr{};
 599 
 600  // loop over thread tensor space [y0, y1, ...]
 601  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
 602  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
 603  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
 604 
 605  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
 606  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
 607 
 608  // data index [y0, y1, ...]
 609  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
 610 
 611  // read from distributed tensor
 612  // vector_type_t vec;
 613  vector_t vec_value;
 614 
// Pack ScalarPerVector lanes from the thread buffer into one vector.
 615  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
 616  constexpr auto idx_ys = generate_tuple(
 617  [&](auto jj) {
 618  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
 619  : idx_ys_start[jj];
 620  },
 621  number<Base::NDimY>{});
 622 
 623  constexpr index_t d =
 624  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
 625  Traits::PackedSize;
 626 
 627  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
 628  dstr_tensor.get_thread_buffer().template at<d>();
 629  });
 630 
 631  // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
 632 
 633  // write into bottom tensor
 634  this->get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
 635  bottom_tensor_thread_coord,
 636  0,
 637  vec_value,
 638  bool_constant<oob_conditional_check>{});
 639 
 640  // move thread coordinate
 641  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
 642  {
 643  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
 644 
 645  constexpr auto idx_diff_ps_ys = container_concat(
 646  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
 647  idx_diff_ys);
 648 
 650  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 651  }
 652  });
 653  });
 654  }
655 
// Raw store: same gather-then-write pattern as store(), but uses the raw
// (inline-asm) vectorized write path; OOB checking is hard-enabled here.
// NOTE(review): doxygen rendering — original line 658 (the "store_raw(...)"
// declaration line; only the parameter continuation survives below) and 712
// (presumably the coordinate-move call) are missing from this dump.
 656  template <index_t i_access_unsupport_ = -1>
 657  CK_TILE_DEVICE void
 659  dstr_tensor,
 660  number<i_access_unsupport_> = {}) const
 661  {
 662  using Traits = typename Base::Traits;
 663 
 664  using vector_t = typename Traits::vector_t;
 665  using SFC_Ys = typename Traits::SFC_Ys;
 666 
 667  constexpr auto tile_dstr = typename Base::TileDstr{};
// Raw path always performs the out-of-bounds check.
 668  static constexpr bool oob_conditional_check = true;
 669 
 670  // loop over thread tensor space [y0, y1, ...]
 671  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
 673  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
 674  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
 675 
 676  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
 677  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
 678 
 679  // data index [y0, y1, ...]
 680  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
 681 
 682  // read from distributed tensor
 683  vector_t vec_value;
 684  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
 685  constexpr auto idx_ys = generate_tuple(
 686  [&](auto jj) {
 687  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
 688  : idx_ys_start[jj];
 689  },
 690  number<Base::NDimY>{});
 691  constexpr index_t d =
 692  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
 693  Traits::PackedSize;
 694  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
 695  dstr_tensor.get_thread_buffer().template at<d>();
 696  });
 697 
 698  // write into bottom tensor
 699  this->get_bottom_tensor_view()
 700  .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
 701  bottom_tensor_thread_coord, 0, vec_value);
 702 
 703  // move thread coordinate
 704  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
 705  {
 706  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
 707 
 708  constexpr auto idx_diff_ps_ys = container_concat(
 709  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
 710  idx_diff_ys);
 711 
 713  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 714  }
 715  });
 716  });
 717  }
718 
// Read-modify-write ("update") of the window from a distributed tensor:
// gathers vector lanes from the thread buffer and applies them to the bottom
// tensor through update_vectorized_elements.
// NOTE(review): doxygen rendering — original lines 721 (the "update(...)"
// declaration line), 723-724 (trailing default parameters) and 780
// (presumably the coordinate-move call) are missing from this dump.
 719  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
 720  CK_TILE_DEVICE void
 722  dstr_tensor,
 725  {
 726  using Traits = typename Base::Traits;
 727 
 728  using vector_t = typename Traits::vector_t;
 729  using SFC_Ys = typename Traits::SFC_Ys;
 730 
 731  constexpr auto tile_dstr = typename Base::TileDstr{};
 732 
 733  // loop over thread tensor space [y0, y1, ...]
 734  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
 736  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
 737  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
 738 
 739  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
 740  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
 741 
 742  // data index [y0, y1, ...]
 743  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
 744 
 745  // read from distributed tensor
 746  vector_t vec_value;
 747 
 748  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
 749  constexpr auto idx_ys = generate_tuple(
 750  [&](auto jj) {
 751  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
 752  : idx_ys_start[jj];
 753  },
 754  number<Base::NDimY>{});
 755 
 756  constexpr index_t d =
 757  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
 758  Traits::PackedSize;
 759 
 760  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
 761  dstr_tensor.get_thread_buffer().template at<d>();
 762  });
 763 
 764  // write into bottom tensor
 765  this->get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
 766  bottom_tensor_thread_coord,
 767  0,
 768  vec_value,
 769  bool_constant<oob_conditional_check>{});
 770 
 771  // move thread coordinate
 772  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
 773  {
 774  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
 775 
 776  constexpr auto idx_diff_ps_ys = container_concat(
 777  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
 778  idx_diff_ys);
 779 
 781  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 782  }
 783  });
 784  });
 785  }
786 
// Raw variant of update(): same gather-then-apply pattern, but uses the raw
// (inline-asm) update path and forwards the pre_nop flag.
// NOTE(review): doxygen rendering — original lines 789 (the
// "update_raw(...)" declaration line), 791-792 (trailing default
// parameters) and 850 (presumably the coordinate-move call) are missing from
// this dump.
 787  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
 788  CK_TILE_DEVICE void
 790  dstr_tensor,
 793  bool_constant<pre_nop> = {}) const
 794  {
 795  using Traits = typename Base::Traits;
 796 
 797  using vector_t = typename Traits::vector_t;
 798  using SFC_Ys = typename Traits::SFC_Ys;
 799 
 800  constexpr auto tile_dstr = typename Base::TileDstr{};
 801 
 802  // loop over thread tensor space [y0, y1, ...]
 803  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
 805  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
 806  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
 807 
 808  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
 809  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
 810 
 811  // data index [y0, y1, ...]
 812  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
 813 
 814  // read from distributed tensor
 815  vector_t vec_value;
 816 
 817  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
 818  constexpr auto idx_ys = generate_tuple(
 819  [&](auto jj) {
 820  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
 821  : idx_ys_start[jj];
 822  },
 823  number<Base::NDimY>{});
 824 
 825  constexpr index_t d =
 826  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
 827  Traits::PackedSize;
 828 
 829  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
 830  dstr_tensor.get_thread_buffer().template at<d>();
 831  });
 832 
 833  // write into bottom tensor
 834  this->get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
 835  bottom_tensor_thread_coord,
 836  0,
 837  vec_value,
 838  bool_constant<oob_conditional_check>{},
 839  bool_constant<pre_nop>{});
 840 
 841  // move thread coordinate
 842  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
 843  {
 844  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
 845 
 846  constexpr auto idx_diff_ps_ys = container_concat(
 847  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
 848  idx_diff_ys);
 849 
 851  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 852  }
 853  });
 854  });
 855  }
856 
 857  // Custom move behavior
// Shift every pre-computed bottom-tensor coordinate by `step`; the window
// adaptor coordinates (I0) are left untouched.
// NOTE(review): doxygen rendering — original line 858 (the
// "CK_TILE_DEVICE void move(...step)" declaration line) is missing from this
// dump.
 859  {
 860  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
 861  move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
 862  pre_computed_coords_(iCoord)(I1),
 863  step);
 864  });
 865  }
866 
// Rebuild all pre-computed coordinate bundles from the currently stored
// window origin — same algorithm as the constructor.
// NOTE(review): doxygen rendering — the function's declaration line
// (original line 867, name unknown from this dump; likely a
// set-origin/re-init member), lines 873-874 (remaining
// make_tensor_adaptor_coordinate arguments) and 898 (the coordinate-move
// call) are missing.
 868  {
 869  // TODO: this use less register for FA, but more register for GEMM
 870  // need investigation
 871  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
 872  this->tile_dstr_.get_ps_ys_to_xs_adaptor(),
 875 
 876  typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
 877  this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
 878 
 879  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
 880  this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
 881 
 882  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
 883  // future load/store() calls (might allocate more registers)
 884  using Traits = typename Base::Traits;
 885  using SFC_Ys = typename Traits::SFC_Ys;
 886 
 887  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
 888  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
 889  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
 890 
 891  constexpr auto idx_diff_ys =
 892  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
 893 
 894  constexpr auto idx_diff_ps_ys = container_concat(
 895  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
 896  idx_diff_ys);
 897 
 899  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 900 
 901  pre_computed_coords_(iCoord) =
 902  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
 903  });
 904  }
905 
 906  // this contains:
 907  // per-thread coordinate for window adaptor
 908  // per-thread coordinate for bottom tensor
// NOTE(review): doxygen rendering — the declaration of the
// `pre_computed_coords_` member (original lines 909-910) is missing from
// this dump; the comments above describe it.
 911 };
912 
913 // TODO: use strategy
914 template <typename TensorView_,
915  typename WindowLengths_,
916  typename StaticTileDistribution_,
917  index_t NumCoord = 1>
918 CK_TILE_DEVICE constexpr auto
919 make_tile_window(const TensorView_& tensor_view,
920  const WindowLengths_& window_lengths,
921  const multi_index<TensorView_::get_num_of_dimension()>& origin,
922  const StaticTileDistribution_& tile_distribution,
923  number<NumCoord> = {})
924 {
925  return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
926  remove_cvref_t<WindowLengths_>,
927  remove_cvref_t<StaticTileDistribution_>,
928  NumCoord>{
929  tensor_view, window_lengths, origin, tile_distribution};
930 }
931 
 932 // this version can't be called in a constexpr context
// Same as make_tile_window(), but additionally calls init_raw() on the
// constructed window (hence not constexpr).
// NOTE(review): doxygen rendering — original line 938 (the
// "make_tile_window_raw(const TensorView_& tensor_view," line) is missing
// from this dump.
 933 template <typename TensorView_,
 934  typename WindowLengths_,
 935  typename StaticTileDistribution_,
 936  index_t NumCoord = 1>
 937 CK_TILE_DEVICE auto
 939  const WindowLengths_& window_lengths,
 940  const multi_index<TensorView_::get_num_of_dimension()>& origin,
 941  const StaticTileDistribution_& tile_distribution,
 942  number<NumCoord> = {})
 943 {
 944  auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
 945  remove_cvref_t<WindowLengths_>,
 946  remove_cvref_t<StaticTileDistribution_>,
 947  NumCoord>{
 948  tensor_view, window_lengths, origin, tile_distribution};
 949  w.init_raw();
 950  return w;
 951 }
952 
// Free-function helper: shift a single static-distribution tile window by
// `step` (delegates to the member move()).
// NOTE(review): doxygen rendering — original lines 957-958 (the
// "CK_TILE_DEVICE void move_tile_window(" line and the first window
// parameter line) are missing from this dump.
 953 template <typename TensorView_,
 954  typename WindowLengths_,
 955  typename StaticTileDistribution_,
 956  index_t NumCoord>
 959  WindowLengths_,
 960  StaticTileDistribution_,
 961  NumCoord>& window,
 962  const typename tile_window_with_static_distribution<TensorView_,
 963  WindowLengths_,
 964  StaticTileDistribution_,
 965  NumCoord>::BottomTensorIndex& step)
 966 {
 967  window.move(step);
 968 }
969 
970 template <typename TensorView_,
971  typename WindowLengths_,
972  typename StaticTileDistribution_,
973  index_t NumCoord>
976  WindowLengths_,
977  StaticTileDistribution_,
978  NumCoord>>& window,
979  const typename tile_window_with_static_distribution<TensorView_,
980  WindowLengths_,
981  StaticTileDistribution_,
982  NumCoord>::BottomTensorIndex& step)
983 {
984  using T = tuple<tile_window_with_static_distribution<TensorView_,
985  WindowLengths_,
986  StaticTileDistribution_,
987  NumCoord>>;
988 
989  static constexpr auto N = T::size();
990  static_for<0, N, 1>{}([&](auto Is) { window[number<Is>{}].move(step); });
991 }
992 
// Generic overload: moves every tile window in a statically-sized container
// (the body indexes `window[number<Is>{}]` and calls .move on each element).
// NOTE(review): the enable_if condition on the original rendered line 996 was
// lost during HTML extraction and cannot be reconstructed from this view --
// consult the upstream header before editing this constraint.
// NOTE(review): `step` is taken by non-const lvalue reference although it is
// only read; presumably it could be const& -- confirm upstream before changing.
 993 template <typename TileWindowWithStaticDistributionType,
 994  typename StepType,
 995  typename std::enable_if_t<
 997 CK_TILE_DEVICE void move_tile_window(TileWindowWithStaticDistributionType& window, StepType& step)
 998 {
 999  static constexpr auto N = TileWindowWithStaticDistributionType::size();
 1000  static_for<0, N, 1>{}([&](auto Is) { window[number<Is>{}].move(step); });
 1001 }
1002 
1011 template <typename BottomTensorView_, typename WindowLengths_>
1013  : public tile_window_base<tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>,
1014  BottomTensorView_,
1015  WindowLengths_>
1016 {
1017  using Base =
1019  BottomTensorView_,
1020  WindowLengths_>;
1021 
1023 
1025  const typename Base::BottomTensorView& bottom_tensor_view,
1026  const typename Base::WindowLengths& window_lengths,
1027  const typename Base::BottomTensorIndex& window_origin)
1028  {
1029  this->window_origin_ = window_origin;
1030  this->window_lengths_ = window_lengths;
1031  this->bottom_tensor_view_ = bottom_tensor_view;
1032  }
1033 
1047  template <typename DataType>
1049  index_t end_i,
1050  index_t start_j,
1051  index_t end_j,
1052  const char* label = "") const
1053  {
1054  const auto& tensor_view = this->get_bottom_tensor_view();
1055  const auto window_origin = this->get_window_origin();
1056 
1057  printf("%s Window Range [%d:%d, %d:%d] (origin: %d, %d):\n",
1058  label,
1059  start_i,
1060  end_i - 1,
1061  start_j,
1062  end_j - 1,
1063  window_origin[0],
1064  window_origin[1]);
1065 
1066  for(index_t i = start_i; i < end_i; i++)
1067  {
1068  for(index_t j = start_j; j < end_j; j++)
1069  {
1070  // Create coordinate for this element relative to window origin
1071  auto coord =
1073  make_tuple(window_origin[0] + i, window_origin[1] + j));
1074 
1075  // Get the element using thread buffer type directly
1076  using ThreadBuf = thread_buffer<DataType, 2>;
1077  auto buf = tensor_view.template get_vectorized_elements<ThreadBuf>(coord, 0);
1078  auto value = buf.at(number<0>{}); // Extract first element from thread buffer
1079  printf(" %s[%d,%d] = %f", label, i, j, static_cast<float>(value));
1080  }
1081  printf("\n");
1082  }
1083  printf("\n");
1084  }
1085 };
1086 
1087 template <typename TensorView_, typename WindowLengths_>
1088 CK_TILE_DEVICE constexpr auto
1089 make_tile_window(const TensorView_& tensor_view,
1090  const WindowLengths_& window_lengths,
1091  const multi_index<TensorView_::get_num_of_dimension()>& origin)
1092 {
1094  "wrong! lengths should be static");
1095 
1098  tensor_view, window_lengths, origin};
1099 }
1100 
1101 // duplicate tile window and replace its origin
1102 template <typename TensorView, typename WindowLengths>
1103 CK_TILE_DEVICE constexpr auto
1105  const multi_index<TensorView::get_num_of_dimension()>& origin)
1106 {
1108  tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), origin};
1109 }
1110 
1111 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
1112 CK_TILE_DEVICE constexpr auto
1114  const multi_index<TensorView::get_num_of_dimension()>& origin,
1115  const StaticTileDistribution& tile_distribution)
1116 {
1117  return make_tile_window(tile_window.get_bottom_tensor_view(),
1118  tile_window.get_window_lengths(),
1119  origin,
1121 }
1122 
1123 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
1124 CK_TILE_DEVICE constexpr auto
1126  const StaticTileDistribution& tile_distribution)
1127 {
1128  return make_tile_window(tile_window.get_bottom_tensor_view(),
1129  tile_window.get_window_lengths(),
1130  tile_window.get_window_origin(),
1132 }
1133 
1134 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
1135 CK_TILE_DEVICE constexpr auto
1137  const StaticTileDistribution& tile_distribution)
1138 {
1139  auto w = make_tile_window(tile_window.get_bottom_tensor_view(),
1140  tile_window.get_window_lengths(),
1141  tile_window.get_window_origin(),
1143  w.init_raw();
1144  return w;
1145 }
1146 
1147 template <typename TensorView_, typename WindowLengths_>
1151  step)
1152 {
1153  window.move(step);
1154 }
1155 
1163 template <typename T>
1165 {
1166 };
1167 
1176 template <typename BottomTensorView_,
1177  typename WindowLengths_,
1178  typename StaticTileDistribution_,
1179  index_t NumCoord>
1181  tile_window_with_static_distribution<BottomTensorView_,
1182  WindowLengths_,
1183  StaticTileDistribution_,
1184  NumCoord>> : std::true_type
1185 {
1186 };
1187 
1195 template <typename T>
1198 
1206 template <typename T>
1208 {
1209 };
1210 
1217 template <typename BottomTensorView_, typename WindowLengths_>
1219  tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>> : std::true_type
1220 {
1221 };
1222 
1230 template <typename T>
1233 
1234 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr decltype(auto) apply(F &&f, Tuple &&t)
Definition: tuple.hpp:526
constexpr bool is_tile_window_with_static_distribution_v
Helper variable template to check if a type is a tile window with static distribution.
Definition: tile_window.hpp:1196
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:2834
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
CK_TILE_DEVICE auto make_tile_window_raw(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, number< NumCoord >={})
Definition: tile_window.hpp:938
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constant< b > bool_constant
Definition: integral_constant.hpp:43
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr bool is_tile_window_with_static_lengths_v
Helper variable template to check if a type is a tile window with static lengths.
Definition: tile_window.hpp:1231
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
bool_constant< false > false_type
Definition: integral_constant.hpp:63
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
bool_constant< true > true_type
Definition: integral_constant.hpp:62
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Type trait to determine if a type is a tile window with static distribution.
Definition: tile_window.hpp:1165
Type trait to determine if a type is a tile window with static lengths.
Definition: tile_window.hpp:1208
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
Definition: functional.hpp:81
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
constexpr CK_TILE_HOST_DEVICE auto & get_tensor_descriptor() const
Definition: tensor_view.hpp:61
Definition: debug.hpp:67
Definition: tile_distribution.hpp:72
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
This class provides description of tile windowed view on the device memory.
Definition: tile_window_base.hpp:31
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_window_base.hpp:67
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
This class provides tile (windowed) view and access to the device memory.
Definition: tile_window.hpp:46
CK_TILE_DEVICE void store_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}) const
Definition: tile_window.hpp:658
CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex &step)
Definition: tile_window.hpp:858
CK_TILE_DEVICE auto load_transpose() const
Definition: tile_window.hpp:512
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex &)
Definition: tile_window.hpp:867
array< tuple< typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_window.hpp:910
CK_TILE_DEVICE void update_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:789
constexpr CK_TILE_DEVICE tile_window_with_static_distribution()=default
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:454
CK_TILE_DEVICE auto load_transpose(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:525
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, const TileWindow_ &tile_window, ElementWise_ elementwise, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:157
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:236
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:114
static constexpr auto I0
Definition: tile_window.hpp:56
CK_TILE_DEVICE auto load(const TileWindow_ &tile_window, ElementWise_ elementwise, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Load tile with elementwise function.
Definition: tile_window.hpp:137
CK_TILE_DEVICE void load_raw(DstTile &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:299
static constexpr auto I1
Definition: tile_window.hpp:57
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:371
CK_TILE_DEVICE void update(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:721
constexpr CK_TILE_DEVICE tile_window_with_static_distribution(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin, const typename Base::TileDstr &tile_distribution)
Definition: tile_window.hpp:66
static constexpr index_t NumAccessPerCoord
Definition: tile_window.hpp:62
CK_TILE_DEVICE void store(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:588
This class provides description of tile windowed view on the device memory.
Definition: tile_window.hpp:1016
constexpr CK_TILE_DEVICE tile_window_with_static_lengths()=default
CK_TILE_DEVICE void print_tile_window_range(index_t start_i, index_t end_i, index_t start_j, index_t end_j, const char *label="") const
Definition: tile_window.hpp:1048
constexpr CK_TILE_DEVICE tile_window_with_static_lengths(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin)
Definition: tile_window.hpp:1024
Definition: tile_window_base.hpp:94
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_window_base.hpp:129
Definition: tuple.hpp:192