33 template <
typename BottomTensorView_,
34 typename WindowLengths_,
35 typename StaticTileDistribution_,
39 tile_window_with_static_distribution<BottomTensorView_,
41 StaticTileDistribution_,
45 StaticTileDistribution_>
50 StaticTileDistribution_,
54 StaticTileDistribution_>;
58 static_assert(NumCoord == 1);
60 static_assert(Base::Traits::NumAccess % NumCoord == 0,
61 "wrong! # of access is not divisible by NumCoord");
84 window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
87 bottom_tensor_view.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
91 using Traits =
typename Base::Traits;
92 using SFC_Ys =
typename Traits::SFC_Ys;
95 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
96 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
98 constexpr
auto idx_diff_ys =
106 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
109 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
113 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
118 auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
119 load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
123 template <
typename DistributedTensor,
124 index_t i_access_unsupport_ = -1,
125 bool oob_conditional_check =
true>
130 using Traits =
typename Base::Traits;
131 using vector_t =
typename Traits::vector_t;
132 using SFC_Ys =
typename Traits::SFC_Ys;
137 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
142 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
143 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
146 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
149 const vector_t vec_value =
151 bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
153 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
156 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
159 number<Base::NDimY>{});
162 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
165 dst_tensor.get_thread_buffer().template at<d>() =
167 .template get_as<typename Base::DataType>()[j / Traits::PackedSize];
172 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
175 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
179 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
185 template <
typename DstTile,
186 index_t i_access_unsupport_ = -1,
187 bool oob_conditional_check =
true,
188 bool pre_nop =
false>
192 bool_constant<pre_nop> = {})
const
194 using Traits =
typename Base::Traits;
195 using vector_t =
typename Traits::vector_t;
196 using SFC_Ys =
typename Traits::SFC_Ys;
197 static constexpr
index_t YElementSize =
198 typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
199 static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0);
200 using vectorized_tbuf =
201 array<vector_t, YElementSize / (Traits::PackedSize * Traits::ScalarPerVector)>;
205 auto& dst_vec_tbuf =
reinterpret_cast<vectorized_tbuf&
>(dst_tensor.get_thread_buffer());
208 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
213 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
214 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
215 constexpr
auto pre_nop_ = [&]() {
216 if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
219 return bool_constant<false>{};
223 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
225 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
227 static_assert(d % Traits::ScalarPerVector == 0);
230 dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
231 bottom_tensor_thread_coord,
233 bool_constant<oob_conditional_check>{},
235 #if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
236 CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
243 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
246 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
250 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
257 template <
typename LdsTileWindow_,
258 index_t i_access_unsupport_ = -1,
259 bool oob_conditional_check =
true,
260 bool pre_nop =
false>
264 bool_constant<pre_nop> = {})
const
266 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
268 using LdsDataType =
typename LdsTileWindow::DataType;
272 static_assert(LdsTileWindow::get_num_of_dimension() == 3);
275 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
276 make_tuple(number<0>{}, number<0>{}, number<0>{})) *
280 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
281 make_tuple(number<0>{}, number<1>{}, number<0>{})) *
282 sizeof(LdsDataType) -
286 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
287 make_tuple(number<1>{}, number<0>{}, number<0>{})) *
288 sizeof(LdsDataType) -
293 size_per_buf + size_per_wave * get_warp_id(bool_constant<false>{});
295 __builtin_amdgcn_readfirstlane(m0_init_value));
297 using Traits =
typename Base::Traits;
299 using vector_t =
typename Traits::vector_t;
300 using SFC_Ys =
typename Traits::SFC_Ys;
302 LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
305 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
310 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
311 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
312 constexpr
auto pre_nop_ = [&]() {
313 if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
316 return bool_constant<false>{};
321 smem, bottom_tensor_thread_coord, 0, pre_nop_);
326 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
329 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
333 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
341 template <
typename LdsTileWindow_,
342 index_t i_access_unsupport_ = -1,
343 bool oob_conditional_check =
true>
348 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
349 using LdsDataType =
typename LdsTileWindow::DataType;
350 using Traits =
typename Base::Traits;
352 using vector_t =
typename Traits::vector_t;
353 using SFC_Ys =
typename Traits::SFC_Ys;
356 const auto window_origin = lds_tile.get_window_origin();
357 const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
358 const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
359 auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
361 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
365 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
366 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
369 auto lds_bottom_tensor_thread_idx =
370 window_origin + window_adaptor_thread_coord.get_bottom_index();
373 const auto lds_coord =
377 CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
382 bottom_tensor_thread_coord,
384 bool_constant<oob_conditional_check>{});
389 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
391 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
395 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
401 template <
typename Policy,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
405 auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
406 this->
template load_transpose<Policy>(
411 template <
typename Policy,
412 typename DistributedTensor,
413 index_t i_access_unsupport_ = -1,
414 bool oob_conditional_check =
true>
419 using Traits =
typename Base::Traits;
420 using vector_t =
typename Traits::vector_t;
421 using SFC_Ys =
typename Traits::SFC_Ys;
425 constexpr
auto group_func = Policy::group_func;
428 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
433 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
434 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
437 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
440 const vector_t vec_value =
442 .template get_transpose_vectorized_elements<vector_t>(
443 bottom_tensor_thread_coord, 0);
445 static_for<0, Traits::ScalarPerVector, 1>{}([&](
auto j) {
448 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
451 number<Base::NDimY>{});
453 constexpr
auto grouped_idx_ys = group_func(orig_idx_ys);
455 constexpr
index_t linear_distributed_index =
456 tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys);
458 dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
459 vec_value.template get_as<typename Base::DataType>()[j];
464 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
467 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
471 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
477 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
483 using Traits =
typename Base::Traits;
485 using vector_t =
typename Traits::vector_t;
486 using SFC_Ys =
typename Traits::SFC_Ys;
491 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
495 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
496 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
499 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
505 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
508 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
511 number<Base::NDimY>{});
514 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
517 vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
518 dstr_tensor.get_thread_buffer().template at<d>();
525 bottom_tensor_thread_coord,
528 bool_constant<oob_conditional_check>{});
533 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
536 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
540 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
546 template <
index_t i_access_unsupport_ = -1>
552 using Traits =
typename Base::Traits;
554 using vector_t =
typename Traits::vector_t;
555 using SFC_Ys =
typename Traits::SFC_Ys;
558 static constexpr
bool oob_conditional_check =
true;
561 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
566 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
567 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
570 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
574 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
577 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
580 number<Base::NDimY>{});
582 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
584 vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
590 .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
591 bottom_tensor_thread_coord, 0, vec_value);
596 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
599 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
603 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
609 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
616 using Traits =
typename Base::Traits;
618 using vector_t =
typename Traits::vector_t;
619 using SFC_Ys =
typename Traits::SFC_Ys;
624 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
629 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
630 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
633 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
638 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
641 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
644 number<Base::NDimY>{});
647 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
650 vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
656 bottom_tensor_thread_coord,
659 bool_constant<oob_conditional_check>{});
664 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
667 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
671 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
677 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true,
bool pre_nop>
683 bool_constant<pre_nop> = {})
const
685 using Traits =
typename Base::Traits;
687 using vector_t =
typename Traits::vector_t;
688 using SFC_Ys =
typename Traits::SFC_Ys;
693 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
698 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
699 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
702 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
707 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
710 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
713 number<Base::NDimY>{});
716 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
719 vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
725 bottom_tensor_thread_coord,
728 bool_constant<oob_conditional_check>{},
729 bool_constant<pre_nop>{});
734 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
737 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
741 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
767 this->
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
774 using Traits =
typename Base::Traits;
775 using SFC_Ys =
typename Traits::SFC_Ys;
778 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
779 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
781 constexpr
auto idx_diff_ys =
789 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
792 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
804 template <
typename TensorView_,
805 typename WindowLengths_,
806 typename StaticTileDistribution_,
810 const WindowLengths_& window_lengths,
811 const multi_index<TensorView_::get_num_of_dimension()>& origin,
815 return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
816 remove_cvref_t<WindowLengths_>,
817 remove_cvref_t<StaticTileDistribution_>,
819 tensor_view, window_lengths, origin, tile_distribution};
823 template <
typename TensorView_,
824 typename WindowLengths_,
825 typename StaticTileDistribution_,
829 const WindowLengths_& window_lengths,
830 const multi_index<TensorView_::get_num_of_dimension()>& origin,
834 auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
835 remove_cvref_t<WindowLengths_>,
836 remove_cvref_t<StaticTileDistribution_>,
838 tensor_view, window_lengths, origin, tile_distribution};
843 template <
typename TensorView_,
844 typename WindowLengths_,
845 typename StaticTileDistribution_,
850 StaticTileDistribution_,
854 StaticTileDistribution_,
855 NumCoord>::BottomTensorIndex& step)
868 template <
typename BottomTensorView_,
typename WindowLengths_>
870 :
public tile_window_base<tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>,
892 template <
typename TensorView_,
typename WindowLengths_>
895 const WindowLengths_& window_lengths,
896 const multi_index<TensorView_::get_num_of_dimension()>& origin)
899 "wrong! lengths should be static");
907 template <
typename TensorView,
typename WindowLengths>
910 const multi_index<TensorView::get_num_of_dimension()>& origin)
916 template <
typename TensorView,
typename WindowLengths,
typename StaticTileDistribution>
919 const multi_index<TensorView::get_num_of_dimension()>& origin,
928 template <
typename TensorView,
typename WindowLengths,
typename StaticTileDistribution>
939 template <
typename TensorView,
typename WindowLengths,
typename StaticTileDistribution>
952 template <
typename TensorView_,
typename WindowLengths_>
968 template <
typename T>
981 template <
typename BottomTensorView_,
982 typename WindowLengths_,
983 typename StaticTileDistribution_,
988 StaticTileDistribution_,
1000 template <
typename T>
1011 template <
typename T>
1022 template <
typename BottomTensorView_,
typename WindowLengths_>
1035 template <
typename T>
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr bool is_tile_window_with_static_distribution_v
Helper variable template to check if a type is a tile window with static distribution.
Definition: tile_window.hpp:1001
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
CK_TILE_DEVICE auto make_tile_window_raw(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, number< NumCoord >={})
Definition: tile_window.hpp:828
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constant< b > bool_constant
Definition: integral_constant.hpp:43
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr bool is_tile_window_with_static_lengths_v
Helper variable template to check if a type is a tile window with static lengths.
Definition: tile_window.hpp:1036
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
bool_constant< false > false_type
Definition: integral_constant.hpp:63
bool_constant< true > true_type
Definition: integral_constant.hpp:62
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Type trait to determine if a type is a tile window with static distribution.
Definition: tile_window.hpp:970
Type trait to determine if a type is a tile window with static lengths.
Definition: tile_window.hpp:1013
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
Definition: tile_distribution.hpp:72
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
This class provides description of tile windowed view on the device memory.
Definition: tile_window_base.hpp:31
BottomTensorView bottom_tensor_view_
Definition: tile_window_base.hpp:85
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_window_base.hpp:36
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
BottomTensorIndex window_origin_
Definition: tile_window_base.hpp:79
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_window_base.hpp:67
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_window_base.hpp:33
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_window_base.hpp:34
WindowLengths window_lengths_
Definition: tile_window_base.hpp:81
This class provides tile (windowed) view and access to the device memory.
Definition: tile_window.hpp:46
CK_TILE_DEVICE void store_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}) const
Definition: tile_window.hpp:548
CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex &step)
Definition: tile_window.hpp:748
CK_TILE_DEVICE auto load_transpose() const
Definition: tile_window.hpp:402
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex &)
Definition: tile_window.hpp:757
array< tuple< typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_window.hpp:800
CK_TILE_DEVICE void update_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:679
constexpr CK_TILE_DEVICE tile_window_with_static_distribution()=default
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:344
CK_TILE_DEVICE auto load_transpose(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:415
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:126
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:114
static constexpr auto I0
Definition: tile_window.hpp:56
CK_TILE_DEVICE void load_raw(DstTile &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:189
static constexpr auto I1
Definition: tile_window.hpp:57
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:261
CK_TILE_DEVICE void update(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:611
constexpr CK_TILE_DEVICE tile_window_with_static_distribution(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin, const typename Base::TileDstr &tile_distribution)
Definition: tile_window.hpp:66
static constexpr index_t NumAccessPerCoord
Definition: tile_window.hpp:62
CK_TILE_DEVICE void store(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:478
This class provides description of tile windowed view on the device memory.
Definition: tile_window.hpp:873
constexpr CK_TILE_DEVICE tile_window_with_static_lengths()=default
constexpr CK_TILE_DEVICE tile_window_with_static_lengths(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin)
Definition: tile_window.hpp:881
Definition: tile_window_base.hpp:94
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_window_base.hpp:95
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_window_base.hpp:129
TileDstr tile_dstr_
Definition: tile_window_base.hpp:253